V3 Nodes: Load,Save,Vae audio nodes; sort imports; ruff

bigcat88 2025-07-15 13:11:50 +03:00
parent ac05d9a5fa
commit b17cc99c1e
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
6 changed files with 376 additions and 103 deletions

comfy_api/v3/__init__.py (Normal file, 0 lines changed)
View File

View File

@@ -1,4 +1,4 @@
from typing import Optional, Callable
from typing import Callable, Optional
def first_real_override(cls: type, name: str, *, base: type) -> Optional[Callable]:
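The hunk above only reorders the typing import; the body of first_real_override is outside the hunk. As context, a minimal sketch of what a helper with this signature could do; the logic below is an assumption for illustration, not the committed implementation:

from typing import Callable, Optional

def first_real_override_sketch(cls: type, name: str, *, base: type) -> Optional[Callable]:
    # Assumed behavior, for illustration only: walk cls's MRO and return the
    # attribute when some class defines `name` itself with something other than
    # the implementation inherited from `base`; otherwise return None.
    base_impl = vars(base).get(name)
    for klass in cls.mro():
        impl = vars(klass).get(name)
        if impl is not None and impl is not base_impl:
            return getattr(cls, name)
    return None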

View File

@@ -1,26 +1,29 @@
from __future__ import annotations
from typing import Any, Literal, TypeVar, Callable, TypedDict
from typing_extensions import NotRequired
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from collections import Counter
from comfy_execution.graph import ExecutionBlocker
from comfy_api.v3.resources import Resources, ResourcesLocal
import copy
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar
# used for type hinting
import torch
from spandrel import ImageModelDescriptor
from comfy.model_patcher import ModelPatcher
from comfy.samplers import Sampler, CFGGuider
from comfy.sd import CLIP
from comfy.controlnet import ControlNet
from comfy.sd import VAE
from comfy.sd import StyleModel as StyleModel_
from typing_extensions import NotRequired
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy_api.input import VideoInput
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput
from comfy_api.v3.resources import Resources, ResourcesLocal
from comfy_execution.graph import ExecutionBlocker
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@@ -1137,7 +1140,7 @@ class ComfyNodeV3:
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# schema = cls.GET_SCHEMA()
# TODO: finish
return None
@@ -1183,84 +1186,84 @@ class ComfyNodeV3:
#--------------------------------------------
_DESCRIPTION = None
@classproperty
def DESCRIPTION(cls):
def DESCRIPTION(cls): # noqa
if cls._DESCRIPTION is None:
cls.GET_SCHEMA()
return cls._DESCRIPTION
_CATEGORY = None
@classproperty
def CATEGORY(cls):
def CATEGORY(cls): # noqa
if cls._CATEGORY is None:
cls.GET_SCHEMA()
return cls._CATEGORY
_EXPERIMENTAL = None
@classproperty
def EXPERIMENTAL(cls):
def EXPERIMENTAL(cls): # noqa
if cls._EXPERIMENTAL is None:
cls.GET_SCHEMA()
return cls._EXPERIMENTAL
_DEPRECATED = None
@classproperty
def DEPRECATED(cls):
def DEPRECATED(cls): # noqa
if cls._DEPRECATED is None:
cls.GET_SCHEMA()
return cls._DEPRECATED
_API_NODE = None
@classproperty
def API_NODE(cls):
def API_NODE(cls): # noqa
if cls._API_NODE is None:
cls.GET_SCHEMA()
return cls._API_NODE
_OUTPUT_NODE = None
@classproperty
def OUTPUT_NODE(cls):
def OUTPUT_NODE(cls): # noqa
if cls._OUTPUT_NODE is None:
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_INPUT_IS_LIST = None
@classproperty
def INPUT_IS_LIST(cls):
def INPUT_IS_LIST(cls): # noqa
if cls._INPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._INPUT_IS_LIST
_OUTPUT_IS_LIST = None
@classproperty
def OUTPUT_IS_LIST(cls):
def OUTPUT_IS_LIST(cls): # noqa
if cls._OUTPUT_IS_LIST is None:
cls.GET_SCHEMA()
return cls._OUTPUT_IS_LIST
_RETURN_TYPES = None
@classproperty
def RETURN_TYPES(cls):
def RETURN_TYPES(cls): # noqa
if cls._RETURN_TYPES is None:
cls.GET_SCHEMA()
return cls._RETURN_TYPES
_RETURN_NAMES = None
@classproperty
def RETURN_NAMES(cls):
def RETURN_NAMES(cls): # noqa
if cls._RETURN_NAMES is None:
cls.GET_SCHEMA()
return cls._RETURN_NAMES
_OUTPUT_TOOLTIPS = None
@classproperty
def OUTPUT_TOOLTIPS(cls):
def OUTPUT_TOOLTIPS(cls): # noqa
if cls._OUTPUT_TOOLTIPS is None:
cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS
_NOT_IDEMPOTENT = None
@classproperty
def NOT_IDEMPOTENT(cls):
def NOT_IDEMPOTENT(cls): # noqa
if cls._NOT_IDEMPOTENT is None:
cls.GET_SCHEMA()
return cls._NOT_IDEMPOTENT
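All of the V1-compatibility properties above share one lazy pattern: the first access calls GET_SCHEMA(), which is expected to fill the cached _DESCRIPTION, _CATEGORY, and the rest. A standalone sketch of that pattern (the classproperty descriptor and the placeholder GET_SCHEMA below are assumptions for illustration; the repo's own definitions may differ):

class classproperty:
    """Minimal descriptor so a property can be read directly off the class."""
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner):
        return self.fget(owner)


class LazyNode:
    _DESCRIPTION = None

    @classmethod
    def GET_SCHEMA(cls):
        # Stand-in: the real method builds the whole schema and populates every cache field.
        cls._DESCRIPTION = "filled in on first access"

    @classproperty
    def DESCRIPTION(cls):  # noqa
        if cls._DESCRIPTION is None:
            cls.GET_SCHEMA()
        return cls._DESCRIPTION


print(LazyNode.DESCRIPTION)  # GET_SCHEMA runs once, then the cached value is returned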
@@ -1440,36 +1443,36 @@ class TestNode(ComfyNodeV3):
def execute(cls, **kwargs):
pass
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
Int.Input("tessfes", widgetType=String.io_type),
Int.Input("my_int"),
Custom("XYZ").Input("xyz"),
Custom("MODEL_M").Input("model1"),
Image.Input("my_image"),
Float.Input("my_float"),
MultiType.Input("my_inputs", [String, Custom("MODEL_M"), Custom("XYZ")]),
]
Custom("XYZ").Input()
outputs: list[OutputV3] = [
Image.Output("image"),
Custom("XYZ").Output("xyz"),
]
for c in inputs:
if isinstance(c, MultiType):
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
print(c.get_io_type_V1())
else:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
for c in outputs:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
zz = TestNode()
print(zz.GET_NODE_INFO_V1())
# aa = NodeInfoV1()
# print(asdict(aa))
# print(as_pruned_dict(aa))
# if __name__ == "__main__":
# print("hello there")
# inputs: list[InputV3] = [
# Int.Input("tessfes", widgetType=String.io_type),
# Int.Input("my_int"),
# Custom("XYZ").Input("xyz"),
# Custom("MODEL_M").Input("model1"),
# Image.Input("my_image"),
# Float.Input("my_float"),
# MultiType.Input("my_inputs", [String, Custom("MODEL_M"), Custom("XYZ")]),
# ]
# Custom("XYZ").Input()
# outputs: list[OutputV3] = [
# Image.Output("image"),
# Custom("XYZ").Output("xyz"),
# ]
#
# for c in inputs:
# if isinstance(c, MultiType):
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
# print(c.get_io_type_V1())
# else:
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
#
# for c in outputs:
# print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
#
# zz = TestNode()
# print(zz.GET_NODE_INFO_V1())
#
# # aa = NodeInfoV1()
# # print(asdict(aa))
# # print(as_pruned_dict(aa))

View File

@@ -1,19 +1,21 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import json
import os
import random
from io import BytesIO
import av
import numpy as np
import torchaudio
from comfy_api.v3.io import Image, FolderType, _UIOutput, ComfyNodeV3
# used for image preview
from comfy.cli_args import args
import folder_paths
import random
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import os
import json
import numpy as np
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.v3.io import ComfyNodeV3, FolderType, Image, _UIOutput
class SavedResult(dict):
@@ -67,11 +69,13 @@ class PreviewImage(_UIOutput):
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNodeV3=None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)
# class UILatent(_UIOutput):
# def __init__(self, values: list[SavedResult | dict], **kwargs):
# output_dir = folder_paths.get_temp_directory()
@@ -119,21 +123,15 @@ class PreviewMask(PreviewImage):
# "latents": self.values,
# }
class PreviewAudio(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
class PreviewAudio(_UIOutput):
def __init__(self, audio, cls: ComfyNodeV3=None, **kwargs):
output_dir = folder_paths.get_temp_directory()
type = "temp"
prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
filename_prefix = "ComfyUI"
quality = "128k"
format = "flac"
filename_prefix += prefix_append
filename_prefix = "ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, output_dir
filename_prefix, folder_paths.get_temp_directory()
)
# Prepare metadata dictionary
@@ -223,7 +221,7 @@ class PreviewAudio(_UIOutput):
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append(SavedResult(file, subfolder, type))
results.append(SavedResult(file, subfolder, FolderType.temp))
counter += 1
self.values = results
@@ -231,6 +229,7 @@ class PreviewAudio(_UIOutput):
def as_dict(self):
return {"audio": self.values}
class PreviewUI3D(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
@@ -238,6 +237,7 @@ class PreviewUI3D(_UIOutput):
def as_dict(self):
return {"3d": self.values}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
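For context on how these _UIOutput classes are consumed, a hedged sketch of a node that returns one from execute. The node id and input name are made up; io.SchemaV3, io.String.Input, io.NodeOutput(ui=...) and ui.PreviewText all follow the code elsewhere in this commit:

from comfy_api.v3 import io, ui

class EchoText_V3(io.ComfyNodeV3):  # hypothetical example node
    @classmethod
    def DEFINE_SCHEMA(cls):
        return io.SchemaV3(
            node_id="EchoText_V3",
            category="utils",
            inputs=[io.String.Input("text")],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, text) -> io.NodeOutput:
        # Wrapping the value in a _UIOutput subclass lets the frontend render a preview.
        return io.NodeOutput(ui=ui.PreviewText(text))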

View File

@@ -1,38 +1,79 @@
from __future__ import annotations
import torchaudio
import folder_paths
import os
import io
import hashlib
import json
import os
from io import BytesIO
import av
import torch
import torchaudio
import comfy.model_management
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy_api.v3 import io, ui
class PreviewAudio_V3(io.ComfyNodeV3):
class ConditioningStableAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PreviewAudio_V3",
display_name="Preview Audio _V3",
category="audio",
node_id="ConditioningStableAudio_V3",
category="conditioning",
inputs=[
io.Audio.Input("audio"),
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Float.Input(id="seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
io.Float.Input(id="seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
],
outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio):
return io.NodeOutput(ui=ui.PreviewAudio(audio, cls=cls))
def execute(cls, positive, negative, seconds_start, seconds_total) -> io.NodeOutput:
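# conditioning_set_values makes a shallow copy of each conditioning entry's options dict
# and merges the given keys into it, so both returned conditionings carry the
# seconds_start/seconds_total window.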
return io.NodeOutput(
node_helpers.conditioning_set_values(
positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}
),
node_helpers.conditioning_set_values(
negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}
),
)
class EmptyLatentAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="EmptyLatentAudio_V3",
category="latent/audio",
inputs=[
io.Float.Input(id="seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
io.Int.Input(
id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[io.Latent.Output()],
)
@classmethod
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2
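# e.g. with the default seconds=47.6: 47.6 * 44100 / 2048 ≈ 1024.98, and
# round(1024.98 / 2) * 2 snaps the latent length to an even 1024 frames.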
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent, "type": "audio"})
class LoadAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadAudio_V3",
display_name="Load Audio _V3",
node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work
display_name="Load Audio _V3", # frontend ignores "display_name" for this node
category="audio",
inputs=[
io.Combo.Input("audio", upload=io.UploadType.audio, options=cls.get_files_options()),
@@ -65,14 +106,242 @@ class LoadAudio_V3(io.ComfyNodeV3):
return True
class PreviewAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PreviewAudio_V3", # frontend expects "PreviewAudio" to work
display_name="Preview Audio _V3", # frontend ignores "display_name" for this node
category="audio",
inputs=[
io.Audio.Input("audio"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio) -> io.NodeOutput:
return io.NodeOutput(ui=ui.PreviewAudio(audio, cls=cls))
class SaveAudioMP3_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveAudioMP3_V3", # frontend expects "SaveAudioMP3" to work
display_name="Save Audio(MP3) _V3", # frontend ignores "display_name" for this node
category="audio",
inputs=[
io.Audio.Input("audio"),
io.String.Input("filename_prefix", default="audio/ComfyUI"),
io.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(self, audio, filename_prefix="ComfyUI", format="mp3", quality="V0") -> io.NodeOutput:
return _save_audio(self, audio, filename_prefix, format, quality)
class SaveAudioOpus_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveAudioOpus_V3", # frontend expects "SaveAudioOpus" to work
display_name="Save Audio(Opus) _V3", # frontend ignores "display_name" for this node
category="audio",
inputs=[
io.Audio.Input("audio"),
io.String.Input("filename_prefix", default="audio/ComfyUI"),
io.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(self, audio, filename_prefix="ComfyUI", format="opus", quality="128k") -> io.NodeOutput:
return _save_audio(self, audio, filename_prefix, format, quality)
class SaveAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work
display_name="Save Audio _V3", # frontend ignores "display_name" for this node
category="audio",
inputs=[
io.Audio.Input("audio"),
io.String.Input("filename_prefix", default="audio/ComfyUI"),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> io.NodeOutput:
return _save_audio(cls, audio, filename_prefix, format)
class VAEDecodeAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="VAEDecodeAudio_V3",
category="latent/audio",
inputs=[
io.Latent.Input(id="samples"),
io.Vae.Input(id="vae"),
],
outputs=[io.Audio.Output()],
)
@classmethod
def execute(cls, vae, samples) -> io.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
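# Normalize loudness: divide by 5x the measured per-item std (clamped to at least 1.0),
# so the decoded waveform ends up with std <= 0.2 while already-quiet audio is left untouched.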
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return io.NodeOutput({"waveform": audio, "sample_rate": 44100})
class VAEEncodeAudio_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="VAEEncodeAudio_V3",
category="latent/audio",
inputs=[
io.Audio.Input(id="audio"),
io.Vae.Input(id="vae"),
],
outputs=[io.Latent.Output()],
)
@classmethod
def execute(cls, vae, audio) -> io.NodeOutput:
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
else:
waveform = audio["waveform"]
return io.NodeOutput({"samples": vae.encode(waveform.movedim(1, -1))})
def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="128k") -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory()
)
# Prepare metadata dictionary
metadata = {}
if not args.disable_metadata:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
results = []
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
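# e.g. a 44100 Hz input is not an Opus rate, so the loop above selects 48000 and
# the resample step below converts the waveform before encoding.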
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format='flt',
layout='mono' if waveform.shape[0] == 1 else 'stereo',
)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append(ui.SavedResult(file, subfolder, io.FolderType.output))
counter += 1
return io.NodeOutput(ui={"audio": results})
NODES_LIST: list[type[io.ComfyNodeV3]] = [
# EmptyLatentAudio_V3,
# VAEEncodeAudio_V3,
# VAEDecodeAudio_V3,
# SaveAudio_V3,
# SaveAudioMP3_V3,
# SaveAudioOpus_V3,
ConditioningStableAudio_V3,
EmptyLatentAudio_V3,
LoadAudio_V3,
PreviewAudio_V3,
# ConditioningStableAudio_V3,
SaveAudioMP3_V3,
SaveAudioOpus_V3,
SaveAudio_V3,
VAEDecodeAudio_V3,
VAEEncodeAudio_V3,
]

View File

@@ -2299,6 +2299,7 @@ def init_builtin_extra_nodes():
"nodes_tcfg.py",
"nodes_v3_test.py",
"nodes_v1_test.py",
"v3/nodes_audio.py",
"v3/nodes_controlnet.py",
"v3/nodes_images.py",
"v3/nodes_mask.py",