Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 00:06:37 +00:00)
Replace torchaudio.load with pyav. (#8989)
parent 9a470e073e
commit 54a45b9967
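The new module-level load() added below is intended as a drop-in replacement for torchaudio.load at the LoadAudio call site: it returns a float waveform tensor of shape [channels, samples] together with the sample rate. A minimal usage sketch, assuming import av and import torch are available at the top of the module (the file path here is hypothetical):

    waveform, sample_rate = load("input.flac")  # hypothetical path
    # waveform: float tensor, shape [channels, samples]
    # sample_rate: int, e.g. 44100
    audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}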
@@ -278,6 +278,62 @@ class PreviewAudio(SaveAudio):
                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                 }

+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to float 32 bits PCM format."""
+    if wav.dtype.is_floating_point:
+        return wav
+    elif wav.dtype == torch.int16:
+        return wav.float() / (2 ** 15)
+    elif wav.dtype == torch.int32:
+        return wav.float() / (2 ** 31)
+    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
+
+def load(filepath: str, frame_offset: int = 0, num_frames: int = -1) -> tuple[torch.Tensor, int]:
+    with av.open(filepath) as af:
+        if not af.streams.audio:
+            raise ValueError("No audio stream found in the file.")
+
+        stream = af.streams.audio[0]
+        sr = stream.codec_context.sample_rate
+        n_channels = stream.channels
+
+        seek_time = frame_offset / sr if frame_offset > 0 else 0.0
+        duration = num_frames / sr if num_frames > 0 else -1.0
+
+        sample_offset = int(sr * seek_time)
+        num_samples = int(sr * duration) if duration >= 0 else -1
+
+        # Small negative seek offset for MP3 artifacts. NOTE: this is LLM-generated code, so it may not actually be necessary.
+        seek_sec = max(0, seek_time - 0.1) if filepath.lower().endswith('.mp3') else seek_time
+        af.seek(int(seek_sec / stream.time_base), stream=stream)
+
+        frames = []
+        length = 0
+        for frame in af.decode(streams=stream.index):
+            current_offset = int(frame.rate * frame.pts * frame.time_base)
+            strip = max(0, sample_offset - current_offset)
+
+            buf = torch.from_numpy(frame.to_ndarray())
+            if buf.shape[0] != n_channels:
+                buf = buf.view(-1, n_channels).t()
+
+            buf = buf[:, strip:]
+            frames.append(buf)
+            length += buf.shape[1]
+
+            if num_samples > 0 and length >= num_samples:
+                break
+
+        if not frames:
+            raise ValueError("No audio frames decoded.")
+
+        wav = torch.cat(frames, dim=1)
+        if num_samples > 0:
+            wav = wav[:, :num_samples]
+
+        wav = f32_pcm(wav)
+        return wav, sr
+
 class LoadAudio:
     @classmethod
     def INPUT_TYPES(s):
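The f32_pcm helper above normalizes integer PCM samples into the [-1, 1] float range conventionally used for audio waveforms. A small illustrative check of the int16 case (the sample values are hypothetical, not from the commit):

    import torch
    x = torch.tensor([[-32768, 0, 32767]], dtype=torch.int16)  # hypothetical full-scale int16 samples
    y = x.float() / (2 ** 15)                                   # the same scaling f32_pcm applies to int16
    # y is tensor([[-1.0000, 0.0000, about 0.99997]]): full-scale int16 maps into [-1.0, 1.0)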
@@ -292,7 +348,7 @@ class LoadAudio:

     def load(self, audio):
         audio_path = folder_paths.get_annotated_filepath(audio)
-        waveform, sample_rate = torchaudio.load(audio_path)
+        waveform, sample_rate = load(audio_path)
         audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
         return (audio, )

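For reference, a sketch of how load()'s frame_offset and num_frames arguments map onto seeking and trimming, assuming a file whose sample rate is 44100 Hz (the path and values are hypothetical; LoadAudio itself calls load() with the defaults, i.e. the whole file):

    sr = 44100  # assumed sample rate of the file
    waveform, sample_rate = load("clip.mp3", frame_offset=1 * sr, num_frames=2 * sr)
    # internally: seek_time = frame_offset / sr = 1.0 s and duration = num_frames / sr = 2.0 s;
    # decoding stops once enough samples are collected, and the result is trimmed to num_frames samples per channel.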