mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 19:46:38 +00:00
add StabilityAudio API nodes (#9749)
This commit is contained in:
@@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
|
@@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityTextToAudioRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: int = Field(190, ge=1, le=190)
|
||||
seed: int = Field(0, ge=0, le=4294967294)
|
||||
steps: int = Field(8, ge=4, le=8)
|
||||
output_format: str = Field("wav")
|
||||
|
||||
|
||||
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
|
||||
strength: float = Field(0.01, ge=0.01, le=1.0)
|
||||
|
||||
|
||||
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
|
||||
mask_start: int = Field(30, ge=0, le=190)
|
||||
mask_end: int = Field(190, ge=0, le=190)
|
||||
|
||||
|
||||
class StabilityAudioResponse(BaseModel):
|
||||
audio: Optional[str] = Field(None)
|
||||
|
@@ -2,7 +2,7 @@ from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
@@ -15,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import (
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
StabilityTextToAudioRequest,
|
||||
StabilityAudioToAudioRequest,
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -27,7 +31,10 @@ from comfy_api_nodes.apinode_utils import (
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import validate_audio_duration
|
||||
|
||||
import torch
|
||||
import base64
|
||||
@@ -649,6 +656,306 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
||||
return comfy_io.NodeOutput(returned_image)
|
||||
|
||||
|
||||
class StabilityTextToAudio(comfy_io.ComfyNode):
|
||||
"""Generates high-quality music and sound effects from text descriptions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityTextToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioToAudio(comfy_io.ComfyNode):
|
||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Float.Input(
|
||||
"strength",
|
||||
default=1,
|
||||
min=0.01,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=comfy_io.NumberDisplay.slider,
|
||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityAudioInpaint(comfy_io.ComfyNode):
|
||||
"""Transforms part of existing audio sample using text instructions."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="api node/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["stable-audio-2.5"],
|
||||
),
|
||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=190,
|
||||
min=1,
|
||||
max=190,
|
||||
step=1,
|
||||
tooltip="Controls the duration in seconds of the generated audio.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967294,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for generation.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"steps",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=1,
|
||||
tooltip="Controls the number of sampling steps.",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"mask_start",
|
||||
default=30,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"mask_end",
|
||||
default=190,
|
||||
min=0,
|
||||
max=190,
|
||||
step=1,
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Audio.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
audio: Input.Audio,
|
||||
duration: int,
|
||||
seed: int,
|
||||
steps: int,
|
||||
mask_start: int,
|
||||
mask_end: int,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
if mask_end <= mask_start:
|
||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||
validate_audio_duration(audio, 6, 190)
|
||||
|
||||
payload = StabilityAudioInpaintRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioInpaintRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
|
||||
class StabilityExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
@@ -658,6 +965,9 @@ class StabilityExtension(ComfyExtension):
|
||||
StabilityUpscaleConservativeNode,
|
||||
StabilityUpscaleCreativeNode,
|
||||
StabilityUpscaleFastNode,
|
||||
StabilityTextToAudio,
|
||||
StabilityAudioToAudio,
|
||||
StabilityAudioInpaint,
|
||||
]
|
||||
|
||||
|
||||
|
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||
@@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
@@ -126,7 +126,7 @@ def validate_video_dimensions(
|
||||
|
||||
|
||||
def validate_video_duration(
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
):
|
||||
@@ -151,3 +151,17 @@ def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
return len(images)
|
||||
|
||||
|
||||
def validate_audio_duration(
|
||||
audio: Input.Audio,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> None:
|
||||
sr = int(audio["sample_rate"])
|
||||
dur = int(audio["waveform"].shape[-1]) / sr
|
||||
eps = 1.0 / sr
|
||||
if min_duration is not None and dur + eps < min_duration:
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
Reference in New Issue
Block a user