def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert decoded PCM audio to 32-bit float samples.

    Args:
        wav: Audio samples as a floating point, uint8, int16 or int32 tensor.

    Returns:
        A float32 tensor. Integer input is scaled by its full-scale value
        (2**7 after removing the 128 bias for uint8, 2**15 for int16,
        2**31 for int32). A float32 input is returned as-is (no copy);
        other floating point widths are cast to float32.

    Raises:
        ValueError: If the dtype is not one of the supported sample types.
    """
    if wav.dtype == torch.float32:
        return wav
    if wav.dtype.is_floating_point:
        # float64 / float16 sources must still come out as float32.
        return wav.float()
    if wav.dtype == torch.uint8:
        # Unsigned 8-bit PCM is offset-binary: silence sits at 128.
        return (wav.float() - 128.0) / (2 ** 7)
    if wav.dtype == torch.int16:
        return wav.float() / (2 ** 15)
    if wav.dtype == torch.int32:
        return wav.float() / (2 ** 31)
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


def load(filepath: str, frame_offset: int = 0, num_frames: int = -1) -> tuple[torch.Tensor, int]:
    """Decode the first audio stream of a file into a float32 waveform.

    Args:
        filepath: Path of the media file to open with PyAV.
        frame_offset: Number of samples to skip from the start of the stream.
            Values <= 0 start at the beginning.
        num_frames: Maximum number of samples to return. Values <= 0 read
            until the end of the stream.

    Returns:
        A ``(waveform, sample_rate)`` tuple where ``waveform`` is laid out as
        ``(channels, samples)`` and converted with :func:`f32_pcm`.

    Raises:
        ValueError: If the file has no audio stream, no frame could be
            decoded, or the decoded sample dtype is unsupported.
    """
    with av.open(filepath) as af:
        if not af.streams.audio:
            raise ValueError("No audio stream found in the file.")

        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        n_channels = stream.channels

        # Keep the sample counts as integers. Converting to seconds and back
        # (int(sr * (offset / sr))) can truncate to one sample short.
        sample_offset = max(int(frame_offset), 0)
        num_samples = int(num_frames) if num_frames > 0 else -1
        seek_time = sample_offset / sr

        # For MP3, seek slightly before the target so the decoder has some
        # lead-in; the surplus samples are trimmed below using each frame's
        # timestamp.
        if filepath.lower().endswith('.mp3'):
            seek_sec = max(0, seek_time - 0.1)
        else:
            seek_sec = seek_time
        af.seek(int(seek_sec / stream.time_base), stream=stream)

        frames = []
        length = 0
        # Position assumed for a frame that carries no timestamp: contiguous
        # with the previous frame, or exactly at the requested offset if it is
        # the first one (best effort; nothing is trimmed in that case).
        expected_offset = sample_offset
        for frame in af.decode(streams=stream.index):
            if frame.pts is None or frame.time_base is None:
                current_offset = expected_offset
            else:
                current_offset = int(frame.rate * frame.pts * frame.time_base)

            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != n_channels:
                # First dimension is not the channel count: treat the data as
                # interleaved samples and de-interleave to (channels, samples).
                buf = buf.view(-1, n_channels).t()
            expected_offset = current_offset + buf.shape[1]

            # Drop samples that precede the requested start. A frame lying
            # entirely before it becomes empty.
            strip = max(0, sample_offset - current_offset)
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]

            if num_samples > 0 and length >= num_samples:
                break

        if not frames:
            raise ValueError("No audio frames decoded.")

        wav = torch.cat(frames, dim=1)
        if num_samples > 0:
            # The last frame usually overshoots the requested length.
            wav = wav[:, :num_samples]

        return f32_pcm(wav), sr