Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-08-07 23:24:02 -07:00
commit 6ea69369ce
32 changed files with 3601 additions and 1182 deletions

1
.gitattributes vendored
View File

@ -1,2 +1,3 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

View File

@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -8,7 +8,7 @@ from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@ -364,8 +364,9 @@ class QwenImageTransformer2DModel(nn.Module):
image_rotary_emb = self.pos_embeds(x, context)
orig_shape = x.shape
hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
@ -396,4 +397,4 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -293,6 +293,15 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.QwenImage):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
# Direct mapping for transformer_blocks format (QwenImage LoRA format)
key_map["{}".format(key_lora)] = k
# Support transformer prefix format
key_map["transformer.{}".format(key_lora)] = k
return key_map

View File

@ -1237,7 +1237,7 @@ class QwenImage(supported_models_base.BASE):
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
"shift": 1.15,
}
memory_usage_factor = 1.8 #TODO

View File

@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase):
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
qwen_default_lora = "{}.lora_B.default.weight".format(x)
A_name = None
if regular_lora in lora.keys():
@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase):
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
mid_name = None
elif qwen_default_lora in lora.keys():
A_name = qwen_default_lora
B_name = "{}.lora_A.default.weight".format(x)
mid_name = None
if A_name is not None:
mid = None

View File

@ -9,7 +9,11 @@ from typing import Type
import av
import numpy as np
import torch
import torchaudio
try:
import torchaudio
TORCH_AUDIO_AVAILABLE = True
except ImportError:
TORCH_AUDIO_AVAILABLE = False
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
@ -302,6 +306,8 @@ class AudioSaveHelper:
# Resample if necessary
if sample_rate != audio["sample_rate"]:
if not TORCH_AUDIO_AVAILABLE:
raise Exception("torchaudio is not available; cannot resample audio.")
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
@ -21,7 +22,6 @@ from server import PromptServer
import numpy as np
from PIL import Image
import requests
import torch
import math
import base64
@ -30,7 +30,7 @@ from io import BytesIO
import av
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
Returns:
A Comfy node `VIDEO` output.
"""
video_io = download_url_to_bytesio(video_url, timeout)
video_io = await download_url_to_bytesio(video_url, timeout)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s
def validate_and_cast_response(
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
@ -86,35 +86,24 @@ def validate_and_cast_response(
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
for image_data in data:
image_url = image_data.url
b64_data = image_data.b64_json
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
if not image_url and not b64_data:
raise ValueError("No image was generated in the response")
if b64_data:
img_data = base64.b64decode(b64_data)
img = Image.open(io.BytesIO(img_data))
elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
img = Image.open(io.BytesIO(img_response.content))
img = img.convert("RGBA")
# Convert to numpy array, normalize to float32 between 0 and 1
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)
# Add to list of tensors
image_tensors.append(img_tensor)
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
Returns:
BytesIO object containing the downloaded content.
"""
response = requests.get(url, stream=True, timeout=timeout)
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(response.content)
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
return torch.from_numpy(image_array).unsqueeze(0)
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = download_url_to_bytesio(url, timeout)
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response: requests.Response) -> torch.Tensor:
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
return f"data:{mime_type};base64,{base64_string}"
def upload_file_to_comfyapi(
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
Returns:
The download URL for the uploaded file.
"""
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
auth_kwargs=auth_kwargs,
)
response: UploadResponse = operation.execute()
upload_response = ApiClient.upload_file(
response.upload_url, file_bytes_io, content_type=upload_mime_type
)
upload_response.raise_for_status()
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
@ -399,7 +389,7 @@ def video_to_base64_string(
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
def upload_video_to_comfyapi(
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return upload_file_to_comfyapi(
video_bytes_io, filename, upload_mime_type, auth_kwargs
)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
return audio_bytes_io
def upload_audio_to_comfyapi(
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
audio_data_np, sample_rate, container_format, codec_name
)
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def audio_to_base64_string(
@ -544,7 +532,7 @@ def audio_to_base64_string(
return base64.b64encode(audio_bytes).decode("utf-8")
def upload_images_to_comfyapi(
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
@ -561,55 +549,15 @@ def upload_images_to_comfyapi(
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
idx_image = 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_length = 1
if is_batch:
batch_length = image.shape[0]
while True:
curr_image = image
if len(image.shape) > 3:
curr_image = image[idx_image]
# get BytesIO version of image
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
# first, request upload/download urls from comfy API
if not mime_type:
request_object = UploadRequest(file_name=img_binary.name)
else:
request_object = UploadRequest(
file_name=img_binary.name, content_type=mime_type
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response = operation.execute()
batch_len = image.shape[0] if is_batch else 1
upload_response = ApiClient.upload_file(
response.upload_url, img_binary, content_type=mime_type
)
# verify success
try:
upload_response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ValueError(f"Could not upload one or more images: {e}") from e
# add download_url to list
download_urls.append(response.download_url)
idx_image += 1
# stop uploading additional files if done
if is_batch and max_images > 0:
if idx_image >= max_images:
break
if idx_image >= batch_length:
break
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')

View File

@ -1,3 +1,4 @@
import asyncio
import io
from inspect import cleandoc
from typing import Union, Optional
@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (
import numpy as np
from PIL import Image
import requests
import aiohttp
import torch
import base64
import time
@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
return mask
def handle_bfl_synchronous_operation(
async def handle_bfl_synchronous_operation(
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated(
response_api: BFLFluxProGenerateResponse = await operation.execute()
return await _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)
def _poll_until_generated(
async def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
@ -66,55 +67,56 @@ def _poll_until_generated(
retry_404_seconds = 2
retry_202_seconds = 2
retry_pending_seconds = 1
request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
response = requests.Session().send(request.prepare())
if response.status_code == 200:
result = response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
img_response = requests.get(img_url)
return process_image_response(img_response)
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
async with aiohttp.ClientSession() as session:
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
time.sleep(retry_pending_seconds)
continue
elif response.status_code == 404:
if retries_404 < max_retries_404:
retries_404 += 1
time.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status_code == 202:
time.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")
async with session.get(polling_url) as response:
if response.status == 200:
result = await response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
async with session.get(img_url) as img_resp:
return process_image_response(await img_resp.content.read())
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
await asyncio.sleep(retry_pending_seconds)
continue
elif response.status == 404:
if retries_404 < max_retries_404:
retries_404 += 1
await asyncio.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status == 202:
await asyncio.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")
def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
prompt: str,
aspect_ratio: str,
@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
def api_call(
async def api_call(
self,
prompt: str,
aspect_ratio: str,
@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
prompt: str,
prompt_upsampling,
@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
image: torch.Tensor,
mask: torch.Tensor,
@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
control_image: torch.Tensor,
prompt: str,
@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
async def api_call(
self,
control_image: torch.Tensor,
prompt: str,
@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)

View File

@ -303,7 +303,7 @@ class GeminiNode(ComfyNodeABC):
"""
return GeminiPart(text=text)
def api_call(
async def api_call(
self,
prompt: str,
model: GeminiModel,
@ -332,7 +332,7 @@ class GeminiNode(ComfyNodeABC):
parts.extend(files)
# Create response
response = SynchronousOperation(
response = await SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
contents=[

View File

@ -212,7 +212,7 @@ V3_RESOLUTIONS= [
"1536x640"
]
def download_and_process_images(image_urls):
async def download_and_process_images(image_urls):
"""Helper function to download and process multiple images from URLs"""
# Initialize list to store image tensors
@ -220,7 +220,7 @@ def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
turbo=False,
@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC):
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC):
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
return (await download_and_process_images(image_urls),)
class IdeogramV2(ComfyNodeABC):
@ -487,7 +487,7 @@ class IdeogramV2(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
turbo=False,
@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC):
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC):
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
return (await download_and_process_images(image_urls),)
class IdeogramV3(ComfyNodeABC):
"""
@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
image=None,
@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC):
)
# Execute the operation and process response
response = operation.execute()
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC):
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
return (await download_and_process_images(image_urls),)
NODE_CLASS_MAPPINGS = {

View File

@ -109,7 +109,7 @@ class KlingApiError(Exception):
pass
def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
@ -117,7 +117,7 @@ def poll_until_finished(
node_id: Optional[str] = None,
) -> R:
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
KlingTaskStatus.succeed.value,
@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
return None
def video_result_to_node_output(
async def video_result_to_node_output(
video: KlingVideoResult,
) -> tuple[VideoFromFile, str, str]:
"""Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
return (
download_url_to_video_output(video.url),
await download_url_to_video_output(str(video.url)),
str(video.id),
str(video.duration),
)
def image_result_to_node_output(
async def image_result_to_node_output(
images: list[KlingImageResult],
) -> torch.Tensor:
"""
@ -297,9 +297,9 @@ def image_result_to_node_output(
If multiple images are returned, they will be stacked along the batch dimension.
"""
if len(images) == 1:
return download_url_to_image_tensor(images[0].url)
return await download_url_to_image_tensor(str(images[0].url))
else:
return torch.cat([download_url_to_image_tensor(image.url) for image in images])
return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])
class KlingNodeBase(ComfyNodeABC):
@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Text to Video Node"
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingText2VideoResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
prompt: str,
negative_prompt: str,
@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return video_result_to_node_output(video)
return await video_result_to_node_output(video)
class KlingCameraControlT2VNode(KlingTextToVideoNode):
@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
def api_call(
async def api_call(
self,
prompt: str,
negative_prompt: str,
@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
return await super().api_call(
model_name=KlingVideoGenModelName.kling_v1,
cfg_scale=cfg_scale,
mode=KlingVideoGenMode.std,
@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Image to Video Node"
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingImage2VideoResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
start_frame: torch.Tensor,
prompt: str,
@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return video_result_to_node_output(video)
return await video_result_to_node_output(video)
class KlingCameraControlI2VNode(KlingImage2VideoNode):
@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
def api_call(
async def api_call(
self,
start_frame: torch.Tensor,
prompt: str,
@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
return await super().api_call(
model_name=KlingVideoGenModelName.kling_v1_5,
start_frame=start_frame,
cfg_scale=cfg_scale,
@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
def api_call(
async def api_call(
self,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
mode
]
return super().api_call(
return await super().api_call(
prompt=prompt,
negative_prompt=negative_prompt,
model_name=model_name,
@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoExtendResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
prompt: str,
negative_prompt: str,
@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return video_result_to_node_output(video)
return await video_result_to_node_output(video)
class KlingVideoEffectsBase(KlingNodeBase):
@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoEffectsResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
dual_character: bool,
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return video_result_to_node_output(video)
return await video_result_to_node_output(video)
class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
RETURN_TYPES = ("VIDEO", "STRING")
RETURN_NAMES = ("VIDEO", "duration")
def api_call(
async def api_call(
self,
image_left: torch.Tensor,
image_right: torch.Tensor,
@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
unique_id: Optional[str] = None,
**kwargs,
):
video, _, duration = super().api_call(
video, _, duration = await super().api_call(
dual_character=True,
effect_scene=effect_scene,
model_name=model_name,
@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
def api_call(
async def api_call(
self,
image: torch.Tensor,
effect_scene: KlingSingleImageEffectsScene,
@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
return await super().api_call(
dual_character=False,
effect_scene=effect_scene,
model_name=model_name,
@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase):
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
)
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingLipSyncResponse:
"""Polls the Kling API endpoint until the task reaches a terminal state."""
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_LIP_SYNC}/{task_id}",
@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
video: VideoInput,
audio: Optional[AudioInput] = None,
@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase):
self.validate_lip_sync_video(video)
# Upload video to Comfy API and get download URL
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
audio_url = None
@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
return video_result_to_node_output(video)
return await video_result_to_node_output(video)
class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
async def api_call(
self,
video: VideoInput,
audio: AudioInput,
@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
return await super().api_call(
video=video,
audio=audio,
voice_language=voice_language,
@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
async def api_call(
self,
video: VideoInput,
text: str,
@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
**kwargs,
):
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
return super().api_call(
return await super().api_call(
video=video,
text=text,
voice_language=voice_language,
@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVirtualTryOnResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
human_image: torch.Tensor,
cloth_image: torch.Tensor,
@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)
return (image_result_to_node_output(images),)
return (await image_result_to_node_output(images),)
class KlingImageGenerationNode(KlingImageGenerationBase):
@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
def get_response(
async def get_response(
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]],
node_id: Optional[str] = None,
) -> KlingImageGenerationsResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
node_id=node_id,
)
def api_call(
async def api_call(
self,
model_name: KlingImageGenModelName,
prompt: str,
@ -1714,17 +1714,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)
return (image_result_to_node_output(images),)
return (await image_result_to_node_output(images),)
NODE_CLASS_MAPPINGS = {

View File

@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
)
from server import PromptServer
import requests
import aiohttp
import torch
from io import BytesIO
@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = self._convert_luma_refs(
api_image_ref = await self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = self._convert_style_image(
api_style_ref = await self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs,
)
character_ref = LumaCharacterRef(
@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()
img_response = requests.get(response_poll.assets.image)
img = process_image_response(img_response)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
def _convert_luma_refs(
async def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
def _convert_style_image(
async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
**kwargs,
):
# first, upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()
img_response = requests.get(response_poll.assets.image)
img = process_image_response(img_response)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()
vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class LumaImageToVideoGenerationNode(ComfyNodeABC):
@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()
vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
def _convert_to_keyframes(
async def _convert_to_keyframes(
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame0 = None
frame1 = None
if first_image is not None:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])

View File

@ -86,7 +86,7 @@ class MinimaxTextToVideoNode:
API_NODE = True
OUTPUT_NODE = True
def generate_video(
async def generate_video(
self,
prompt_text,
seed=0,
@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
),
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = download_url_to_bytesio(file_url)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)

View File

@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]:
return None
def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
@ -394,10 +394,10 @@ class BaseMoonvalleyVideoNode:
else:
return control_map["Motion Transfer"]
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
@ -507,7 +507,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None)
@ -532,9 +532,9 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = upload_images_to_comfyapi(
image_url = (await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)[0]
))[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
@ -549,14 +549,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
@ -609,7 +609,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
@ -620,7 +620,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
@ -658,15 +658,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
@ -688,7 +688,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
del input_types["optional"][param]
return input_types
def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@ -717,15 +717,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)
return (video,)

View File

@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC):
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC):
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -537,9 +537,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@ -623,7 +623,7 @@ class OpenAIChatNode(OpenAITextNode):
DESCRIPTION = "Generate text responses from an OpenAI model."
def get_result_response(
async def get_result_response(
self,
response_id: str,
include: Optional[list[Includable]] = None,
@ -639,7 +639,7 @@ class OpenAIChatNode(OpenAITextNode):
creation above for more information.
"""
return PollingOperation(
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET,
@ -784,7 +784,7 @@ class OpenAIChatNode(OpenAITextNode):
self.history[session_id] = new_history
def api_call(
async def api_call(
self,
prompt: str,
persist_context: bool,
@ -815,7 +815,7 @@ class OpenAIChatNode(OpenAITextNode):
previous_response_id = None
# Create response
create_response = SynchronousOperation(
create_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT,
method=HttpMethod.POST,
@ -848,7 +848,7 @@ class OpenAIChatNode(OpenAITextNode):
response_id = create_response.id
# Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response)
# Update history

View File

@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
FUNCTION = "api_call"
RETURN_TYPES = ("VIDEO",)
def poll_for_task_status(
async def poll_for_task_status(
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
node_id=node_id,
estimated_duration=60
)
return polling_operation.execute()
return await polling_operation.execute()
def execute_task(
async def execute_task(
self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
logging.error(error_msg)
raise PikaApiError(error_msg)
task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_kwargs)
final_response = await self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
video_url = str(final_response.url)
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return (download_url_to_video_output(video_url),)
return (await download_url_to_video_output(video_url),)
class PikaImageToVideoV2_2(PikaNodeBase):
@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt_text: str,
@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaTextToVideoNodeV2_2(PikaNodeBase):
@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
def api_call(
async def api_call(
self,
prompt_text: str,
negative_prompt: str,
@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
content_type="application/x-www-form-urlencoded",
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaScenesV2_2(PikaNodeBase):
@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
def api_call(
async def api_call(
self,
prompt_text: str,
negative_prompt: str,
@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikAdditionsNode(PikaNodeBase):
@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
def api_call(
async def api_call(
self,
video: VideoInput,
image: torch.Tensor,
@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
]
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaSwapsNode(PikaNodeBase):
@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
RETURN_TYPES = ("VIDEO",)
def api_call(
async def api_call(
self,
video: VideoInput,
image: torch.Tensor,
@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
]
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
}
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaffectsNode(PikaNodeBase):
@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
def api_call(
async def api_call(
self,
image: torch.Tensor,
pikaffect: str,
@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaStartEndFrameNode2_2(PikaNodeBase):
@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
def api_call(
async def api_call(
self,
image_start: torch.Tensor,
image_end: torch.Tensor,
@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
) -> tuple[VideoFromFile]:
pika_files = [
(
"keyFrames",
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
),
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
NODE_CLASS_MAPPINGS = {

View File

@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
import torch
import requests
import aiohttp
from io import BytesIO
@ -47,7 +47,7 @@ def get_video_url_from_response(
return str(response.Resp.url)
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = operation.execute()
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None:
raise Exception(
@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
prompt: str,
aspect_ratio: str,
@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseImageToVideoNode(ComfyNodeABC):
@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseTransitionVideoNode(ComfyNodeABC):
@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
},
}
def api_call(
async def api_call(
self,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
NODE_CLASS_MAPPINGS = {

View File

@ -37,7 +37,7 @@ from io import BytesIO
from PIL import UnidentifiedImageError
def handle_recraft_file_request(
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
@ -71,13 +71,13 @@ def handle_recraft_file_request(
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
@ -395,7 +395,7 @@ class RecraftTextToImageNode:
},
}
def api_call(
async def api_call(
self,
prompt: str,
size: str,
@ -439,7 +439,7 @@ class RecraftTextToImageNode:
),
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
images = []
urls = []
for data in response.data:
@ -451,7 +451,7 @@ class RecraftTextToImageNode:
f"Result URL: {urls_string}", unique_id
)
image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024)
await download_url_to_bytesio(data.url, timeout=1024)
)
if len(image.shape) < 4:
image = image.unsqueeze(0)
@ -538,7 +538,7 @@ class RecraftImageToImageNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -578,7 +578,7 @@ class RecraftImageToImageNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/imageToImage",
request=request,
@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
mask: torch.Tensor,
@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
mask=mask[i:i+1],
path="/proxy/recraft/images/inpaint",
@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
},
}
def api_call(
async def api_call(
self,
prompt: str,
substyle: str,
@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
),
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
svg_data = []
urls = []
for data in response.data:
@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id
)
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))
return (SVG(svg_data),)
@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/vectorize",
auth_kwargs=kwargs,
@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/replaceBackground",
request=request,
@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/removeBackground",
auth_kwargs=kwargs,
@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
},
}
def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path=self.RECRAFT_PATH,
auth_kwargs=kwargs,

View File

@ -9,11 +9,10 @@ from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import aiohttp
import os
import datetime
import shutil
import time
import asyncio
import io
import logging
import math
@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
return hasattr(response, "error")
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
@ -123,8 +121,8 @@ class Rodin3DAPI:
else:
return "Generating"
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images == None:
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
@ -155,7 +153,7 @@ class Rodin3DAPI:
auth_kwargs=kwargs,
)
response = operation.execute()
response = await operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
@ -168,7 +166,7 @@ class Rodin3DAPI:
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
@ -191,11 +189,9 @@ class Rodin3DAPI:
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return poll_operation.execute()
return await poll_operation.execute()
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
@ -212,53 +208,59 @@ class Rodin3DAPI:
auth_kwargs=kwargs
)
return operation.execute()
return await operation.execute()
def GetQualityAndMode(self, PolyCount):
if PolyCount == "200K-Triangle":
def get_quality_mode(self, poly_count):
if poly_count == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if PolyCount == "4K-Quad":
if poly_count == "4K-Quad":
quality = "extra-low"
elif PolyCount == "8K-Quad":
elif poly_count == "8K-Quad":
quality = "low"
elif PolyCount == "18K-Quad":
elif poly_count == "18K-Quad":
quality = "medium"
elif PolyCount == "50K-Quad":
elif poly_count == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
},
}
def api_call(
async def api_call(
self,
Images,
Seed,
@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
},
}
def api_call(
async def api_call(
self,
Images,
Seed,
@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
},
}
def api_call(
async def api_call(
self,
Images,
Seed,
@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
},
}
def api_call(
async def api_call(
self,
Images,
Seed,
@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
material_type = "PBR"
quality = "medium"
mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
task_uuid, subscription_key = await self.create_generate_task(
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)

View File

@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool:
return image.shape[2] < 8000 and image.shape[1] < 8000
def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
@ -115,7 +115,7 @@ def poll_until_finished(
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: (response.status.value),
status_extractor=lambda response: response.status.value,
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC):
)
return True
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC):
node_id=node_id,
)
def generate_video(
async def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC):
auth_kwargs=auth_kwargs,
)
initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = self.get_response(task_id, auth_kwargs, node_id)
final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),)
return (await download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
},
}
def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
},
}
def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
},
}
def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC):
)
return True
def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
node_id=node_id,
)
def api_call(
async def api_call(
self,
prompt: str,
ratio: str,
@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC):
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
self.validate_response(final_response)
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),)
return (await download_url_to_image_tensor(image_url),)
NODE_CLASS_MAPPINGS = {

View File

@ -124,7 +124,7 @@ class StabilityStableImageUltraNode:
},
}
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
@ -163,7 +163,7 @@ class StabilityStableImageUltraNode:
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node:
},
}
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node:
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode:
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode:
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode:
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode:
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode:
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=kwargs,
)
response_poll: StabilityResultsGetResponse = operation.execute()
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@ -563,8 +563,7 @@ class StabilityUpscaleFastNode:
},
}
def api_call(self, image: torch.Tensor,
**kwargs):
async def api_call(self, image: torch.Tensor, **kwargs):
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
@ -583,7 +582,7 @@ class StabilityUpscaleFastNode:
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")

View File

@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
)
def upload_image_to_tripo(image, **kwargs):
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
async def upload_image_to_tripo(image, **kwargs):
urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
def get_model_url_from_response(response: TripoTaskResponse) -> str:
@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
raise RuntimeError(f"Failed to get model url from response: {response}")
def poll_until_finished(
async def poll_until_finished(
kwargs: dict[str, str],
response: TripoTaskResponse,
) -> tuple[str, str]:
@ -57,7 +57,7 @@ def poll_until_finished(
if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id
response_poll = PollingOperation(
response_poll = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
method=HttpMethod.GET,
@ -80,7 +80,7 @@ def poll_until_finished(
).execute()
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
bytesio = download_url_to_bytesio(url)
bytesio = await download_url_to_bytesio(url)
# Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
@ -88,6 +88,7 @@ def poll_until_finished(
return model_file, task_id
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode:
"""
Generates 3D models synchronously based on a text prompt using Tripo's API.
@ -126,11 +127,11 @@ class TripoTextToModelNode:
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if not prompt:
raise RuntimeError("Prompt is required")
response = SynchronousOperation(
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -155,7 +156,8 @@ class TripoTextToModelNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoImageToModelNode:
"""
@ -195,12 +197,12 @@ class TripoImageToModelNode:
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if image is None:
raise RuntimeError("Image is required")
tripo_file = upload_image_to_tripo(image, **kwargs)
response = SynchronousOperation(
tripo_file = await upload_image_to_tripo(image, **kwargs)
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -225,7 +227,8 @@ class TripoImageToModelNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoMultiviewToModelNode:
"""
@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
if image is None:
raise RuntimeError("front image for multiview is required")
images = []
@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name]
if image_ is not None:
tripo_file = upload_image_to_tripo(image_, **kwargs)
tripo_file = await upload_image_to_tripo(image_, **kwargs)
images.append(tripo_file)
else:
images.append(TripoFileEmptyReference())
response = SynchronousOperation(
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoTextureNode:
@classmethod
@ -340,8 +344,8 @@ class TripoTextureNode:
OUTPUT_NODE = True
AVERAGE_DURATION = 80
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = SynchronousOperation(
async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -358,7 +362,7 @@ class TripoTextureNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoRefineNode:
@ -387,8 +391,8 @@ class TripoRefineNode:
OUTPUT_NODE = True
AVERAGE_DURATION = 240
def generate_mesh(self, model_task_id, **kwargs):
response = SynchronousOperation(
async def generate_mesh(self, model_task_id, **kwargs):
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -400,7 +404,7 @@ class TripoRefineNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoRigNode:
@ -425,8 +429,8 @@ class TripoRigNode:
OUTPUT_NODE = True
AVERAGE_DURATION = 180
def generate_mesh(self, original_model_task_id, **kwargs):
response = SynchronousOperation(
async def generate_mesh(self, original_model_task_id, **kwargs):
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -440,7 +444,8 @@ class TripoRigNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoRetargetNode:
@classmethod
@ -475,8 +480,8 @@ class TripoRetargetNode:
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = SynchronousOperation(
async def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -491,7 +496,8 @@ class TripoRetargetNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
class TripoConversionNode:
@classmethod
@ -529,10 +535,10 @@ class TripoConversionNode:
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
response = SynchronousOperation(
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
@ -549,7 +555,8 @@ class TripoConversionNode:
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
return await poll_until_finished(kwargs, response)
NODE_CLASS_MAPPINGS = {
"TripoTextToModelNode": TripoTextToModelNode,

View File

@ -1,17 +1,17 @@
import io
import logging
import base64
import requests
import aiohttp
import torch
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
Veo2GenVidRequest,
Veo2GenVidResponse,
Veo2GenVidPollRequest,
Veo2GenVidPollResponse
VeoGenVidRequest,
VeoGenVidResponse,
VeoGenVidPollRequest,
VeoGenVidPollResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
"default": None,
"tooltip": "Optional reference image to guide video generation",
}),
"model": (
IO.COMBO,
{
"options": ["veo-2.0-generate-001"],
"default": "veo-2.0-generate-001",
"tooltip": "Veo 2 model to use for video generation",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -141,10 +149,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "generate_video"
CATEGORY = "api node/video/Veo"
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
API_NODE = True
def generate_video(
async def generate_video(
self,
prompt,
aspect_ratio="16:9",
@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW",
seed=0,
image=None,
model="veo-2.0-generate-001",
generate_audio=False,
unique_id: Optional[str] = None,
**kwargs,
):
@ -188,23 +198,26 @@ class VeoVideoGenerationNode(ComfyNodeABC):
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
# Only add generateAudio for Veo 3 models
if "veo-3.0" in model:
parameters["generateAudio"] = generate_audio
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/veo/generate",
path=f"/proxy/veo/{model}/generate",
method=HttpMethod.POST,
request_model=Veo2GenVidRequest,
response_model=Veo2GenVidResponse
request_model=VeoGenVidRequest,
response_model=VeoGenVidResponse
),
request=Veo2GenVidRequest(
request=VeoGenVidRequest(
instances=instances,
parameters=parameters
),
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/veo/poll",
path=f"/proxy/veo/{model}/poll",
method=HttpMethod.POST,
request_model=Veo2GenVidPollRequest,
response_model=Veo2GenVidPollResponse
request_model=VeoGenVidPollRequest,
response_model=VeoGenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=Veo2GenVidPollRequest(
request=VeoGenVidPollRequest(
operationName=operation_name
),
auth_kwargs=kwargs,
@ -243,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
)
# Execute the polling operation
poll_response = poll_operation.execute()
poll_response = await poll_operation.execute()
# Now check for errors in the final response
# Check for error in poll response
@ -268,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
raise Exception(error_message)
# Extract video data
video_data = None
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
video = poll_response.response.videos[0]
@ -278,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
video_data = base64.b64decode(video.bytesBase64Encoded)
elif hasattr(video, 'gcsUri') and video.gcsUri:
# Download from URL
video_url = video.gcsUri
video_response = requests.get(video_url)
video_data = video_response.content
async with aiohttp.ClientSession() as session:
async with session.get(video.gcsUri) as video_response:
video_data = await video_response.content.read()
else:
raise Exception("Video returned but no data or URL was provided")
else:
@ -298,11 +310,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
return (VideoFromFile(video_io),)
# Register the node
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
@classmethod
def INPUT_TYPES(s):
parent_input = super().INPUT_TYPES()
# Update model options for Veo 3
parent_input["optional"]["model"] = (
IO.COMBO,
{
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
"default": "veo-3.0-generate-001",
"tooltip": "Veo 3 model to use for video generation",
},
)
# Add generateAudio parameter
parent_input["optional"]["generate_audio"] = (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
}
)
# Update duration constraints for Veo 3 (only 8 seconds supported)
parent_input["optional"]["duration_seconds"] = (
IO.INT,
{
"default": 8,
"min": 8,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
},
)
return parent_input
# Register the nodes
NODE_CLASS_MAPPINGS = {
"VeoVideoGenerationNode": VeoVideoGenerationNode,
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
}

View File

@ -314,6 +314,29 @@ class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBl
return {"required": arg_dict}
class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_embeds."] = argument
arg_dict["img_in."] = argument
arg_dict["txt_norm."] = argument
arg_dict["txt_in."] = argument
arg_dict["time_text_embed."] = argument
for i in range(60):
arg_dict["transformer_blocks.{}.".format(i)] = argument
arg_dict["proj_out."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -329,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeWAN2_1": ModelMergeWAN2_1,
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
"ModelMergeQwenImage": ModelMergeQwenImage,
}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.48"
__version__ = "0.3.49"

View File

@ -1229,12 +1229,12 @@ class RepeatLatentBatch:
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1)))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.48"
version = "0.3.49"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.47
comfyui-frontend-package==1.24.4
comfyui-workflow-templates==0.1.52
comfyui-embedded-docs==0.2.4
torch
torchsde