mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-29 17:26:34 +00:00
sync changes from #8989
This commit is contained in:
parent
5a8c426112
commit
4c83303801
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import av
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
@ -82,9 +83,47 @@ class LoadAudio(io.ComfyNode):
|
|||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]))
|
return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]))
|
||||||
|
|
||||||
|
@classmethod
def load(cls, filepath: str) -> tuple[torch.Tensor, int]:
    """Decode the first audio stream of a media file using PyAV.

    Args:
        filepath: Path of the file handed to ``av.open``.

    Returns:
        A ``(waveform, sample_rate)`` pair. The waveform is the decoded
        frames concatenated along dimension 1 and passed through
        ``cls.f32_pcm``; the sample rate is read from the stream's codec
        context.

    Raises:
        ValueError: If the container exposes no audio stream, or if
            decoding the stream yields no frames.
    """
    with av.open(filepath) as container:
        if not container.streams.audio:
            raise ValueError("No audio stream found in the file.")

        stream = container.streams.audio[0]
        sample_rate = stream.codec_context.sample_rate
        n_channels = stream.channels

        chunks = []
        for frame in container.decode(streams=stream.index):
            chunk = torch.from_numpy(frame.to_ndarray())
            # When the first dimension is not the channel count the samples
            # are presumably interleaved (packed format -- confirm against
            # PyAV); regroup them so every chunk has one row per channel.
            if chunk.shape[0] != n_channels:
                chunk = chunk.view(-1, n_channels).t()
            chunks.append(chunk)

        if not chunks:
            raise ValueError("No audio frames decoded.")

        # Join the per-frame chunks along the sample axis, then normalise
        # the sample dtype.
        return cls.f32_pcm(torch.cat(chunks, dim=1)), sample_rate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def f32_pcm(cls, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Convert audio to float 32 bits PCM format."""
|
||||||
|
if wav.dtype.is_floating_point:
|
||||||
|
return wav
|
||||||
|
elif wav.dtype == torch.int16:
|
||||||
|
return wav.float() / (2 ** 15)
|
||||||
|
elif wav.dtype == torch.int32:
|
||||||
|
return wav.float() / (2 ** 31)
|
||||||
|
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, audio) -> io.NodeOutput:
|
def execute(cls, audio) -> io.NodeOutput:
|
||||||
waveform, sample_rate = torchaudio.load(folder_paths.get_annotated_filepath(audio))
|
waveform, sample_rate = cls.load(folder_paths.get_annotated_filepath(audio))
|
||||||
return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate})
|
return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
x
Reference in New Issue
Block a user