mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 08:16:44 +00:00
sync changes from #8989
This commit is contained in:
parent
5a8c426112
commit
4c83303801
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import av
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
@ -82,9 +83,47 @@ class LoadAudio(io.ComfyNode):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]))
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath: str) -> tuple[torch.Tensor, int]:
|
||||
with av.open(filepath) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in the file.")
|
||||
|
||||
stream = af.streams.audio[0]
|
||||
sr = stream.codec_context.sample_rate
|
||||
n_channels = stream.channels
|
||||
|
||||
frames = []
|
||||
length = 0
|
||||
for frame in af.decode(streams=stream.index):
|
||||
buf = torch.from_numpy(frame.to_ndarray())
|
||||
if buf.shape[0] != n_channels:
|
||||
buf = buf.view(-1, n_channels).t()
|
||||
|
||||
frames.append(buf)
|
||||
length += buf.shape[1]
|
||||
|
||||
if not frames:
|
||||
raise ValueError("No audio frames decoded.")
|
||||
|
||||
wav = torch.cat(frames, dim=1)
|
||||
wav = cls.f32_pcm(wav)
|
||||
return wav, sr
|
||||
|
||||
@classmethod
|
||||
def f32_pcm(cls, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> io.NodeOutput:
|
||||
waveform, sample_rate = torchaudio.load(folder_paths.get_annotated_filepath(audio))
|
||||
waveform, sample_rate = cls.load(folder_paths.get_annotated_filepath(audio))
|
||||
return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate})
|
||||
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user