mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
add StabilityAudio API nodes (#9749)
This commit is contained in:
@@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
|
|||||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.

    Floating-point input is returned unchanged; int16/int32 samples are
    rescaled into [-1.0, 1.0). Any other dtype raises ValueError.
    """
    if wav.dtype.is_floating_point:
        return wav
    # Integer PCM: divide by 2**(bits - 1) so full scale maps onto [-1, 1).
    full_scale = {torch.int16: 2 ** 15, torch.int32: 2 ** 31}.get(wav.dtype)
    if full_scale is None:
        raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
    return wav.float() / full_scale
|
||||||
|
|
||||||
|
|
||||||
|
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
    """
    Decode any common audio container from bytes using PyAV and return
    a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
    """
    with av.open(io.BytesIO(audio_bytes)) as container:
        if not container.streams.audio:
            raise ValueError("No audio stream found in response.")
        audio_stream = container.streams.audio[0]

        sample_rate = int(audio_stream.codec_context.sample_rate)
        channel_count = audio_stream.channels or 1

        chunks: list[torch.Tensor] = []
        for frame in container.decode(streams=audio_stream.index):
            # Decoded ndarray may come out as [C, T], [T, C] or flat [T].
            data = torch.from_numpy(frame.to_ndarray())
            if data.ndim == 1:
                data = data.unsqueeze(0)  # [T] -> [1, T]
            elif data.shape[0] != channel_count and data.shape[-1] == channel_count:
                data = data.transpose(0, 1).contiguous()  # [T, C] -> [C, T]
            elif data.shape[0] != channel_count:
                data = data.reshape(-1, channel_count).t().contiguous()  # fallback to [C, T]
            chunks.append(data)

        if not chunks:
            raise ValueError("Decoded zero audio frames.")

        waveform = f32_pcm(torch.cat(chunks, dim=1))  # [C, T]
        return {"waveform": waveform.unsqueeze(0).contiguous(), "sample_rate": sample_rate}
|
||||||
|
|
||||||
|
|
||||||
|
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
    """Encode a Comfy AUDIO dict ({"waveform": [B, C, T], "sample_rate": int})
    to an in-memory 320 kbps MP3 and return the rewound buffer.
    """
    waveform = audio["waveform"].cpu()
    sample_rate = audio["sample_rate"]

    # Channels live on dim 1 of the [B, C, T] AUDIO waveform (this module's own
    # decoder documents that shape). The previous check of shape[0] looked at the
    # batch dim (normally 1), so stereo input was always labelled mono.
    n_channels = waveform.shape[1] if waveform.ndim == 3 else waveform.shape[0]

    output_buffer = io.BytesIO()
    output_container = av.open(output_buffer, mode='w', format="mp3")

    out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
    out_stream.bit_rate = 320000

    # Packed 'flt' frames expect a single row of channel-interleaved samples:
    # [C, T] -> [T, C] -> [1, T*C]. (For mono this is identical to a plain
    # flatten, so mono output is unchanged.)
    samples = waveform.reshape(-1, waveform.shape[-1])  # collapse batch: [B*C, T]
    interleaved = samples.t().contiguous().reshape(1, -1).float().numpy()
    frame = av.AudioFrame.from_ndarray(
        interleaved,
        format='flt',
        layout='mono' if n_channels == 1 else 'stereo',
    )
    frame.sample_rate = sample_rate
    frame.pts = 0

    output_container.mux(out_stream.encode(frame))
    output_container.mux(out_stream.encode(None))  # flush delayed packets
    output_container.close()

    output_buffer.seek(0)
    return output_buffer
|
||||||
|
|
||||||
|
|
||||||
def audio_to_base64_string(
|
def audio_to_base64_string(
|
||||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
|
|||||||
|
|
||||||
class StabilityAsyncResponse(BaseModel):
    """Response for an async Stability job; carries only the polling id."""
    # Job id used to poll for results; may be absent on error responses.
    id: Optional[str] = Field(None)


class StabilityTextToAudioRequest(BaseModel):
    """Request body for Stable Audio text-to-audio generation."""
    model: str = Field(...)
    prompt: str = Field(...)
    # Length of the generated clip, in seconds.
    duration: int = Field(190, ge=1, le=190)
    seed: int = Field(0, ge=0, le=4294967294)
    # Sampling steps; the API only accepts 4..8.
    steps: int = Field(8, ge=4, le=8)
    output_format: str = Field("wav")


class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
    """Audio-to-audio request: adds how strongly generation departs from the input audio."""
    strength: float = Field(0.01, ge=0.01, le=1.0)


class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
    """Audio inpaint request: regenerates the [mask_start, mask_end] window (seconds)."""
    mask_start: int = Field(30, ge=0, le=190)
    mask_end: int = Field(190, ge=0, le=190)


class StabilityAudioResponse(BaseModel):
    """Audio generation response; `audio` holds the base64-encoded audio file."""
    audio: Optional[str] = Field(None)
|
||||||
|
@@ -2,7 +2,7 @@ from inspect import cleandoc
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
||||||
from comfy_api_nodes.apis.stability_api import (
|
from comfy_api_nodes.apis.stability_api import (
|
||||||
StabilityUpscaleConservativeRequest,
|
StabilityUpscaleConservativeRequest,
|
||||||
StabilityUpscaleCreativeRequest,
|
StabilityUpscaleCreativeRequest,
|
||||||
@@ -15,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import (
|
|||||||
Stability_SD3_5_Model,
|
Stability_SD3_5_Model,
|
||||||
Stability_SD3_5_GenerationMode,
|
Stability_SD3_5_GenerationMode,
|
||||||
get_stability_style_presets,
|
get_stability_style_presets,
|
||||||
|
StabilityTextToAudioRequest,
|
||||||
|
StabilityAudioToAudioRequest,
|
||||||
|
StabilityAudioInpaintRequest,
|
||||||
|
StabilityAudioResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@@ -27,7 +31,10 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
validate_string,
|
validate_string,
|
||||||
|
audio_bytes_to_audio_input,
|
||||||
|
audio_input_to_mp3,
|
||||||
)
|
)
|
||||||
|
from comfy_api_nodes.util.validation_utils import validate_audio_duration
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import base64
|
import base64
|
||||||
@@ -649,6 +656,306 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
|||||||
return comfy_io.NodeOutput(returned_image)
|
return comfy_io.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityTextToAudio(comfy_io.ComfyNode):
    """Generates high-quality music and sound effects from text descriptions."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, inputs (model/prompt/duration/seed/steps) and AUDIO output."""
        return comfy_io.Schema(
            node_id="StabilityTextToAudio",
            display_name="Stability AI Text To Audio",
            category="api node/audio/Stability AI",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                comfy_io.Combo.Input(
                    "model",
                    options=["stable-audio-2.5"],
                ),
                comfy_io.String.Input("prompt", multiline=True, default=""),
                comfy_io.Int.Input(
                    "duration",
                    default=190,
                    min=1,
                    max=190,
                    step=1,
                    tooltip="Controls the duration in seconds of the generated audio.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967294,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="The random seed used for generation.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "steps",
                    default=8,
                    min=4,
                    max=8,
                    step=1,
                    tooltip="Controls the number of sampling steps.",
                    optional=True,
                ),
            ],
            outputs=[
                comfy_io.Audio.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
        """POST the prompt to the Stability text-to-audio proxy endpoint and
        return the decoded result as a Comfy AUDIO dict.

        Raises ValueError when the prompt fails validation or the response
        carries no audio payload.
        """
        validate_string(prompt, max_length=10000)
        payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
                method=HttpMethod.POST,
                request_model=StabilityTextToAudioRequest,
                response_model=StabilityAudioResponse,
            ),
            request=payload,
            content_type="multipart/form-data",
            auth_kwargs={
                "auth_token": cls.hidden.auth_token_comfy_org,
                "comfy_api_key": cls.hidden.api_key_comfy_org,
            },
        )
        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        # Response audio is base64 text; decode and re-wrap as an AUDIO dict.
        return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityAudioToAudio(comfy_io.ComfyNode):
    """Transforms existing audio samples into new high-quality compositions using text instructions."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, inputs (model/prompt/audio/duration/seed/steps/strength) and AUDIO output."""
        return comfy_io.Schema(
            node_id="StabilityAudioToAudio",
            display_name="Stability AI Audio To Audio",
            category="api node/audio/Stability AI",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                comfy_io.Combo.Input(
                    "model",
                    options=["stable-audio-2.5"],
                ),
                comfy_io.String.Input("prompt", multiline=True, default=""),
                comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
                comfy_io.Int.Input(
                    "duration",
                    default=190,
                    min=1,
                    max=190,
                    step=1,
                    tooltip="Controls the duration in seconds of the generated audio.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967294,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="The random seed used for generation.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "steps",
                    default=8,
                    min=4,
                    max=8,
                    step=1,
                    tooltip="Controls the number of sampling steps.",
                    optional=True,
                ),
                comfy_io.Float.Input(
                    "strength",
                    # NOTE(review): UI default (1) differs from the request model's
                    # field default (0.01). The node always sends the value explicitly,
                    # so the model default never applies — confirm this is intended.
                    default=1,
                    min=0.01,
                    max=1.0,
                    step=0.01,
                    display_mode=comfy_io.NumberDisplay.slider,
                    tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
                    optional=True,
                ),
            ],
            outputs=[
                comfy_io.Audio.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
    ) -> comfy_io.NodeOutput:
        """Upload the source audio as MP3 together with the prompt to the
        audio-to-audio proxy endpoint and return the decoded AUDIO dict.

        Raises ValueError when the prompt or the source clip duration fails
        validation, or when the response carries no audio payload.
        """
        validate_string(prompt, max_length=10000)
        # The API only accepts source clips of 6..190 seconds.
        validate_audio_duration(audio, 6, 190)
        payload = StabilityAudioToAudioRequest(
            prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
        )
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
                method=HttpMethod.POST,
                request_model=StabilityAudioToAudioRequest,
                response_model=StabilityAudioResponse,
            ),
            request=payload,
            content_type="multipart/form-data",
            files={"audio": audio_input_to_mp3(audio)},
            auth_kwargs={
                "auth_token": cls.hidden.auth_token_comfy_org,
                "comfy_api_key": cls.hidden.api_key_comfy_org,
            },
        )
        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        # Response audio is base64 text; decode and re-wrap as an AUDIO dict.
        return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
|
class StabilityAudioInpaint(comfy_io.ComfyNode):
    """Transforms part of existing audio sample using text instructions."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, inputs (model/prompt/audio/duration/seed/steps/mask window) and AUDIO output."""
        return comfy_io.Schema(
            node_id="StabilityAudioInpaint",
            display_name="Stability AI Audio Inpaint",
            category="api node/audio/Stability AI",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                comfy_io.Combo.Input(
                    "model",
                    options=["stable-audio-2.5"],
                ),
                comfy_io.String.Input("prompt", multiline=True, default=""),
                comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
                comfy_io.Int.Input(
                    "duration",
                    default=190,
                    min=1,
                    max=190,
                    step=1,
                    tooltip="Controls the duration in seconds of the generated audio.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967294,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="The random seed used for generation.",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "steps",
                    default=8,
                    min=4,
                    max=8,
                    step=1,
                    tooltip="Controls the number of sampling steps.",
                    optional=True,
                ),
                # Start/end of the region to regenerate, in seconds.
                comfy_io.Int.Input(
                    "mask_start",
                    default=30,
                    min=0,
                    max=190,
                    step=1,
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "mask_end",
                    default=190,
                    min=0,
                    max=190,
                    step=1,
                    optional=True,
                ),
            ],
            outputs=[
                comfy_io.Audio.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        audio: Input.Audio,
        duration: int,
        seed: int,
        steps: int,
        mask_start: int,
        mask_end: int,
    ) -> comfy_io.NodeOutput:
        """Upload the source audio as MP3 with the prompt and mask window to the
        inpaint proxy endpoint and return the decoded AUDIO dict.

        Raises ValueError for an invalid prompt, a non-positive mask window,
        a source clip outside 6..190 seconds, or an empty audio response.
        """
        validate_string(prompt, max_length=10000)
        if mask_end <= mask_start:
            # Fixed message grammar: "greater then" -> "greater than".
            raise ValueError(f"Value of mask_end({mask_end}) should be greater than mask_start({mask_start})")
        validate_audio_duration(audio, 6, 190)

        payload = StabilityAudioInpaintRequest(
            prompt=prompt,
            model=model,
            duration=duration,
            seed=seed,
            steps=steps,
            mask_start=mask_start,
            mask_end=mask_end,
        )
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
                method=HttpMethod.POST,
                request_model=StabilityAudioInpaintRequest,
                response_model=StabilityAudioResponse,
            ),
            request=payload,
            content_type="multipart/form-data",
            files={"audio": audio_input_to_mp3(audio)},
            auth_kwargs={
                "auth_token": cls.hidden.auth_token_comfy_org,
                "comfy_api_key": cls.hidden.api_key_comfy_org,
            },
        )
        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        # Response audio is base64 text; decode and re-wrap as an AUDIO dict.
        return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
class StabilityExtension(ComfyExtension):
|
class StabilityExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
@@ -658,6 +965,9 @@ class StabilityExtension(ComfyExtension):
|
|||||||
StabilityUpscaleConservativeNode,
|
StabilityUpscaleConservativeNode,
|
||||||
StabilityUpscaleCreativeNode,
|
StabilityUpscaleCreativeNode,
|
||||||
StabilityUpscaleFastNode,
|
StabilityUpscaleFastNode,
|
||||||
|
StabilityTextToAudio,
|
||||||
|
StabilityAudioToAudio,
|
||||||
|
StabilityAudioInpaint,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,7 +2,7 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from comfy_api.input.video_types import VideoInput
|
from comfy_api.latest import Input
|
||||||
|
|
||||||
|
|
||||||
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||||
@@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
|
|||||||
|
|
||||||
|
|
||||||
def validate_video_dimensions(
|
def validate_video_dimensions(
|
||||||
video: VideoInput,
|
video: Input.Video,
|
||||||
min_width: Optional[int] = None,
|
min_width: Optional[int] = None,
|
||||||
max_width: Optional[int] = None,
|
max_width: Optional[int] = None,
|
||||||
min_height: Optional[int] = None,
|
min_height: Optional[int] = None,
|
||||||
@@ -126,7 +126,7 @@ def validate_video_dimensions(
|
|||||||
|
|
||||||
|
|
||||||
def validate_video_duration(
|
def validate_video_duration(
|
||||||
video: VideoInput,
|
video: Input.Video,
|
||||||
min_duration: Optional[float] = None,
|
min_duration: Optional[float] = None,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: Optional[float] = None,
|
||||||
):
|
):
|
||||||
@@ -151,3 +151,17 @@ def get_number_of_images(images):
|
|||||||
if isinstance(images, torch.Tensor):
|
if isinstance(images, torch.Tensor):
|
||||||
return images.shape[0] if images.ndim >= 4 else 1
|
return images.shape[0] if images.ndim >= 4 else 1
|
||||||
return len(images)
|
return len(images)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_audio_duration(
    audio: Input.Audio,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
) -> None:
    """Raise ValueError when the clip length falls outside [min_duration, max_duration].

    Duration is computed from the waveform's last dimension and sample rate.
    One sample of tolerance is granted on each bound, so clips within a single
    sample of a limit still pass.
    """
    sample_rate = int(audio["sample_rate"])
    seconds = int(audio["waveform"].shape[-1]) / sample_rate
    tolerance = 1.0 / sample_rate  # one-sample slack

    too_short = min_duration is not None and seconds + tolerance < min_duration
    if too_short:
        raise ValueError(f"Audio duration must be at least {min_duration}s, got {seconds + tolerance:.2f}s")

    too_long = max_duration is not None and seconds - tolerance > max_duration
    if too_long:
        raise ValueError(f"Audio duration must be at most {max_duration}s, got {seconds - tolerance:.2f}s")
|
||||||
|
Reference in New Issue
Block a user