From 4c8330380193fbcaf2aeb6e3cafdffa760a60d39 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 25 Jul 2025 14:48:39 +0300 Subject: [PATCH] sync changes from #8989 --- comfy_extras/v3/nodes_audio.py | 41 +++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/comfy_extras/v3/nodes_audio.py b/comfy_extras/v3/nodes_audio.py index ff7a1f89b..3fb2fcc7d 100644 --- a/comfy_extras/v3/nodes_audio.py +++ b/comfy_extras/v3/nodes_audio.py @@ -3,6 +3,7 @@ from __future__ import annotations import hashlib import os +import av import torch import torchaudio @@ -82,9 +83,47 @@ class LoadAudio(io.ComfyNode): input_dir = folder_paths.get_input_directory() return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])) + @classmethod + def load(cls, filepath: str) -> tuple[torch.Tensor, int]: + with av.open(filepath) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in the file.") + + stream = af.streams.audio[0] + sr = stream.codec_context.sample_rate + n_channels = stream.channels + + frames = [] + length = 0 + for frame in af.decode(streams=stream.index): + buf = torch.from_numpy(frame.to_ndarray()) + if buf.shape[0] != n_channels: + buf = buf.view(-1, n_channels).t() + + frames.append(buf) + length += buf.shape[1] + + if not frames: + raise ValueError("No audio frames decoded.") + + wav = torch.cat(frames, dim=1) + wav = cls.f32_pcm(wav) + return wav, sr + + @classmethod + def f32_pcm(cls, wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + @classmethod def execute(cls, audio) -> io.NodeOutput: - waveform, sample_rate = torchaudio.load(folder_paths.get_annotated_filepath(audio)) + waveform, sample_rate = cls.load(folder_paths.get_annotated_filepath(audio)) return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}) @classmethod