async API nodes (#9129)

* converted API nodes to async

* converted BFL API nodes to async

* fixed client bug; converted gemini, ideogram, minimax

* fixed client bug; converted openai nodes

* fixed client bug; converted moonvalley, pika nodes

* fixed client bug; converted kling, luma nodes

* converted pixverse, rodin nodes

* converted tripo, veo2

* converted recraft nodes

* added missing log_request_response call
Alexander Piskun 2025-08-08 06:37:50 +03:00 committed by GitHub
parent 42974a448c
commit bf2a1b5b1e
18 changed files with 878 additions and 1076 deletions
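
The pattern behind all of these conversions is the same: blocking `requests` calls become awaited `aiohttp` requests inside a `ClientSession`, and `time.sleep` in polling loops becomes `asyncio.sleep` so the event loop is never blocked. A minimal sketch of that pattern (the helper names below are illustrative only, not code from this commit):

    import asyncio
    import aiohttp

    # Illustrative helpers only; not part of this commit.
    async def fetch_bytes(url: str, timeout: float = None) -> bytes:
        # requests.get(url, timeout=...) becomes an awaited session.get()
        timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
            async with session.get(url) as resp:
                resp.raise_for_status()
                return await resp.read()

    async def poll_json(url: str, interval: float = 1.0) -> dict:
        # time.sleep(interval) becomes await asyncio.sleep(interval)
        async with aiohttp.ClientSession() as session:
            while True:
                async with session.get(url) as resp:
                    result = await resp.json()
                if result.get("status") == "ready":
                    return result
                await asyncio.sleep(interval)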

==== changed file ====

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import aiohttp
 import io
 import logging
 import mimetypes
@@ -21,7 +22,6 @@ from server import PromptServer
 import numpy as np
 from PIL import Image
-import requests
 import torch
 import math
 import base64
@@ -30,7 +30,7 @@ from io import BytesIO
 import av


-def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
+async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output.

     Args:
@@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
     Returns:
         A Comfy node `VIDEO` output.
     """
-    video_io = download_url_to_bytesio(video_url, timeout)
+    video_io = await download_url_to_bytesio(video_url, timeout)
     if video_io is None:
         error_msg = f"Failed to download video from {video_url}"
         logging.error(error_msg)
@@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
     return s


-def validate_and_cast_response(
+async def validate_and_cast_response(
     response, timeout: int = None, node_id: Union[str, None] = None
 ) -> torch.Tensor:
     """Validates and casts a response to a torch.Tensor.
@@ -86,35 +86,24 @@ def validate_and_cast_response(
     image_tensors: list[torch.Tensor] = []

     # Process each image in the data array
-    for image_data in data:
-        image_url = image_data.url
-        b64_data = image_data.b64_json
-
-        if not image_url and not b64_data:
-            raise ValueError("No image was generated in the response")
-
-        if b64_data:
-            img_data = base64.b64decode(b64_data)
-            img = Image.open(io.BytesIO(img_data))
-        elif image_url:
-            if node_id:
-                PromptServer.instance.send_progress_text(
-                    f"Result URL: {image_url}", node_id
-                )
-            img_response = requests.get(image_url, timeout=timeout)
-            if img_response.status_code != 200:
-                raise ValueError("Failed to download the image")
-            img = Image.open(io.BytesIO(img_response.content))
-
-        img = img.convert("RGBA")
-
-        # Convert to numpy array, normalize to float32 between 0 and 1
-        img_array = np.array(img).astype(np.float32) / 255.0
-        img_tensor = torch.from_numpy(img_array)
-
-        # Add to list of tensors
-        image_tensors.append(img_tensor)
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
+        for img_data in data:
+            img_bytes: bytes
+            if img_data.b64_json:
+                img_bytes = base64.b64decode(img_data.b64_json)
+            elif img_data.url:
+                if node_id:
+                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
+                async with session.get(img_data.url) as resp:
+                    if resp.status != 200:
+                        raise ValueError("Failed to download generated image")
+                    img_bytes = await resp.read()
+            else:
+                raise ValueError("Invalid image payload: neither URL nor base64 data present.")
+            pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
+            arr = np.asarray(pil_img).astype(np.float32) / 255.0
+            image_tensors.append(torch.from_numpy(arr))

     return torch.stack(image_tensors, dim=0)
@@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()


-def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
+async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     """Downloads content from a URL using requests and returns it as BytesIO.

     Args:
@@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     Returns:
         BytesIO object containing the downloaded content.
     """
-    response = requests.get(url, stream=True, timeout=timeout)
-    response.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-    return BytesIO(response.content)
+    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
+    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+        async with session.get(url) as resp:
+            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
+            return BytesIO(await resp.read())


 def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
@@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
     return torch.from_numpy(image_array).unsqueeze(0)


-def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
+async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
     """Downloads an image from a URL and returns a [B, H, W, C] tensor."""
-    image_bytesio = download_url_to_bytesio(url, timeout)
+    image_bytesio = await download_url_to_bytesio(url, timeout)
     return bytesio_to_image_tensor(image_bytesio)


-def process_image_response(response: requests.Response) -> torch.Tensor:
+def process_image_response(response_content: bytes | str) -> torch.Tensor:
     """Uses content from a Response object and converts it to a torch.Tensor"""
-    return bytesio_to_image_tensor(BytesIO(response.content))
+    return bytesio_to_image_tensor(BytesIO(response_content))


 def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
@@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     return f"data:{mime_type};base64,{base64_string}"


-def upload_file_to_comfyapi(
+async def upload_file_to_comfyapi(
     file_bytes_io: BytesIO,
     filename: str,
-    upload_mime_type: str,
+    upload_mime_type: Optional[str],
     auth_kwargs: Optional[dict[str, str]] = None,
 ) -> str:
     """
@@ -354,6 +345,9 @@ def upload_file_to_comfyapi(
     Returns:
         The download URL for the uploaded file.
     """
-    request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
+    if upload_mime_type is None:
+        request_object = UploadRequest(file_name=filename)
+    else:
+        request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
     operation = SynchronousOperation(
         endpoint=ApiEndpoint(
@@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
         auth_kwargs=auth_kwargs,
     )

-    response: UploadResponse = operation.execute()
-    upload_response = ApiClient.upload_file(
-        response.upload_url, file_bytes_io, content_type=upload_mime_type
-    )
-    upload_response.raise_for_status()
+    response: UploadResponse = await operation.execute()
+    await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)

     return response.download_url
@@ -399,7 +389,7 @@ def video_to_base64_string(
     return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")


-def upload_video_to_comfyapi(
+async def upload_video_to_comfyapi(
     video: VideoInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container: VideoContainer = VideoContainer.MP4,
@@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
     video.save_to(video_bytes_io, format=container, codec=codec)
     video_bytes_io.seek(0)

-    return upload_file_to_comfyapi(
-        video_bytes_io, filename, upload_mime_type, auth_kwargs
-    )
+    return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)


 def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
@@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
     return audio_bytes_io


-def upload_audio_to_comfyapi(
+async def upload_audio_to_comfyapi(
     audio: AudioInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container_format: str = "mp4",
@@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
         audio_data_np, sample_rate, container_format, codec_name
     )

-    return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
+    return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


 def audio_to_base64_string(
@@ -544,7 +532,7 @@ def audio_to_base64_string(
     return base64.b64encode(audio_bytes).decode("utf-8")


-def upload_images_to_comfyapi(
+async def upload_images_to_comfyapi(
     image: torch.Tensor,
     max_images=8,
     auth_kwargs: Optional[dict[str, str]] = None,
@@ -561,55 +549,15 @@ def upload_images_to_comfyapi(
         mime_type: Optional MIME type for the image.
     """
     # if batch, try to upload each file if max_images is greater than 0
-    idx_image = 0
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
-    batch_length = 1
-    if is_batch:
-        batch_length = image.shape[0]
-
-    while True:
-        curr_image = image
-        if len(image.shape) > 3:
-            curr_image = image[idx_image]
-        # get BytesIO version of image
-        img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
-        # first, request upload/download urls from comfy API
-        if not mime_type:
-            request_object = UploadRequest(file_name=img_binary.name)
-        else:
-            request_object = UploadRequest(
-                file_name=img_binary.name, content_type=mime_type
-            )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/customers/storage",
-                method=HttpMethod.POST,
-                request_model=UploadRequest,
-                response_model=UploadResponse,
-            ),
-            request=request_object,
-            auth_kwargs=auth_kwargs,
-        )
-        response = operation.execute()
-
-        upload_response = ApiClient.upload_file(
-            response.upload_url, img_binary, content_type=mime_type
-        )
-        # verify success
-        try:
-            upload_response.raise_for_status()
-        except requests.exceptions.HTTPError as e:
-            raise ValueError(f"Could not upload one or more images: {e}") from e
-        # add download_url to list
-        download_urls.append(response.download_url)
-
-        idx_image += 1
-        # stop uploading additional files if done
-        if is_batch and max_images > 0:
-            if idx_image >= max_images:
-                break
-            if idx_image >= batch_length:
-                break
+    batch_len = image.shape[0] if is_batch else 1
+
+    for idx in range(min(batch_len, max_images)):
+        tensor = image[idx] if is_batch else image
+        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
+        url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
+        download_urls.append(url)

     return download_urls
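
Since per-image uploads are now delegated to `upload_file_to_comfyapi`, a caller could also fan a batch out concurrently instead of awaiting each upload in turn; a sketch under that assumption (hypothetical, not part of this commit; the committed code uploads sequentially):

    import asyncio

    # Hypothetical concurrent variant of the loop above.
    async def upload_images_concurrently(image, max_images=8, auth_kwargs=None, mime_type=None):
        tensors = list(image) if len(image.shape) > 3 else [image]
        ios = [tensor_to_bytesio(t, mime_type=mime_type) for t in tensors[:max_images]]
        return await asyncio.gather(*[
            upload_file_to_comfyapi(io_, io_.name, mime_type, auth_kwargs)
            for io_ in ios
        ])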

File diff suppressed because it is too large.

==== changed file ====

@@ -1,3 +1,4 @@
+import asyncio
 import io
 from inspect import cleandoc
 from typing import Union, Optional
@@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (
 )
 import numpy as np
 from PIL import Image
-import requests
+import aiohttp
 import torch
 import base64
 import time
@@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
     return mask


-def handle_bfl_synchronous_operation(
+async def handle_bfl_synchronous_operation(
     operation: SynchronousOperation,
     timeout_bfl_calls=360,
     node_id: Union[str, None] = None,
 ):
-    response_api: BFLFluxProGenerateResponse = operation.execute()
-    return _poll_until_generated(
+    response_api: BFLFluxProGenerateResponse = await operation.execute()
+    return await _poll_until_generated(
         response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
     )


-def _poll_until_generated(
+async def _poll_until_generated(
     polling_url: str, timeout=360, node_id: Union[str, None] = None
 ):
     # used bfl-comfy-nodes to verify code implementation:
@@ -66,7 +67,8 @@ def _poll_until_generated(
     retry_404_seconds = 2
     retry_202_seconds = 2
     retry_pending_seconds = 1
-    request = requests.Request(method=HttpMethod.GET, url=polling_url)

+    async with aiohttp.ClientSession() as session:
         # NOTE: should True loop be replaced with checking if workflow has been interrupted?
         while True:
             if node_id:
@@ -75,17 +77,17 @@ def _poll_until_generated(
                     f"Generating ({time_elapsed:.0f}s)", node_id
                 )

-            response = requests.Session().send(request.prepare())
-            if response.status_code == 200:
-                result = response.json()
+            async with session.get(polling_url) as response:
+                if response.status == 200:
+                    result = await response.json()
                     if result["status"] == BFLStatus.ready:
                         img_url = result["result"]["sample"]
                         if node_id:
                             PromptServer.instance.send_progress_text(
                                 f"Result URL: {img_url}", node_id
                             )
-                        img_response = requests.get(img_url)
-                        return process_image_response(img_response)
+                        async with session.get(img_url) as img_resp:
+                            return process_image_response(await img_resp.content.read())
                     elif result["status"] in [
                         BFLStatus.request_moderated,
                         BFLStatus.content_moderated,
@@ -97,18 +99,18 @@ def _poll_until_generated(
                     elif result["status"] == BFLStatus.error:
                         raise Exception(f"BFL API encountered an error: {result}.")
                     elif result["status"] == BFLStatus.pending:
-                        time.sleep(retry_pending_seconds)
+                        await asyncio.sleep(retry_pending_seconds)
                         continue
-                elif response.status_code == 404:
+                elif response.status == 404:
                     if retries_404 < max_retries_404:
                         retries_404 += 1
-                        time.sleep(retry_404_seconds)
+                        await asyncio.sleep(retry_404_seconds)
                         continue
                     raise Exception(
                         f"BFL API could not find task after {max_retries_404} tries."
                     )
-                elif response.status_code == 202:
-                    time.sleep(retry_202_seconds)
+                elif response.status == 202:
+                    await asyncio.sleep(retry_202_seconds)
                 elif time.time() - start_time > timeout:
                     raise Exception(
                         f"BFL API experienced a timeout; could not return request under {timeout} seconds."
@@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         aspect_ratio: str,
@@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):

     BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         aspect_ratio: str,
@@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         prompt_upsampling,
@@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         prompt: str,
@@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         mask: torch.Tensor,
@@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         control_image: torch.Tensor,
         prompt: str,
@@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         control_image: torch.Tensor,
         prompt: str,
@@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
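
A design note on the polling rewrite above: `_poll_until_generated` now opens one `aiohttp.ClientSession` and reuses it for every status check and for the final image download, rather than preparing a `requests.Request` per call. If a hard overall deadline were wanted on top of the in-loop timers, the coroutine could be wrapped in `asyncio.wait_for`; a sketch (hypothetical, not part of this commit):

    import asyncio

    # Hypothetical wrapper: bound the whole poll by a single deadline
    # instead of relying only on the in-loop time.time() check.
    async def poll_with_deadline(polling_url, timeout=360, node_id=None):
        try:
            return await asyncio.wait_for(
                _poll_until_generated(polling_url, timeout=timeout, node_id=node_id),
                timeout=timeout,
            )
        except asyncio.TimeoutError as e:
            raise Exception(f"BFL API timed out after {timeout} seconds.") from e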

==== changed file ====

@@ -303,7 +303,7 @@ class GeminiNode(ComfyNodeABC):
         """
         return GeminiPart(text=text)

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         model: GeminiModel,
@@ -332,7 +332,7 @@ class GeminiNode(ComfyNodeABC):
         parts.extend(files)

         # Create response
-        response = SynchronousOperation(
+        response = await SynchronousOperation(
             endpoint=get_gemini_endpoint(model),
             request=GeminiGenerateContentRequest(
                 contents=[

==== changed file ====

@@ -212,7 +212,7 @@ V3_RESOLUTIONS = [
     "1536x640"
 ]

-def download_and_process_images(image_urls):
+async def download_and_process_images(image_urls):
     """Helper function to download and process multiple images from URLs"""

     # Initialize list to store image tensors
@@ -220,7 +220,7 @@ def download_and_process_images(image_urls):
     for image_url in image_urls:
         # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
         img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
         image_tensors.append(img_tensor)
@@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         turbo=False,
@@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC):
             auth_kwargs=kwargs,
         )

-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)


 class IdeogramV2(ComfyNodeABC):
@@ -487,7 +487,7 @@ class IdeogramV2(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         turbo=False,
@@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC):
             auth_kwargs=kwargs,
         )

-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)


 class IdeogramV3(ComfyNodeABC):
     """
@@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         image=None,
@@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC):
         )

         # Execute the operation and process response
-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)


 NODE_CLASS_MAPPINGS = {
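
`download_and_process_images` still awaits each download sequentially; because `download_url_to_bytesio` is now a coroutine, the downloads could equally run in parallel. A sketch under that assumption (hypothetical, not part of this commit):

    import asyncio
    import torch

    # Hypothetical concurrent variant of the helper above.
    async def download_and_process_images_concurrent(image_urls):
        bytesios = await asyncio.gather(
            *[download_url_to_bytesio(url) for url in image_urls]
        )
        # bytesio_to_image_tensor returns a [1, H, W, C] tensor per image
        tensors = [bytesio_to_image_tensor(b, mode="RGB") for b in bytesios]
        return torch.cat(tensors, dim=0)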

==== changed file ====

@@ -109,7 +109,7 @@ class KlingApiError(Exception):
     pass


-def poll_until_finished(
+async def poll_until_finished(
     auth_kwargs: dict[str, str],
     api_endpoint: ApiEndpoint[Any, R],
     result_url_extractor: Optional[Callable[[R], str]] = None,
@@ -117,7 +117,7 @@ def poll_until_finished(
     node_id: Optional[str] = None,
 ) -> R:
     """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
-    return PollingOperation(
+    return await PollingOperation(
         poll_endpoint=api_endpoint,
         completed_statuses=[
             KlingTaskStatus.succeed.value,
@@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
     return None


-def video_result_to_node_output(
+async def video_result_to_node_output(
     video: KlingVideoResult,
 ) -> tuple[VideoFromFile, str, str]:
     """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
     return (
-        download_url_to_video_output(video.url),
+        await download_url_to_video_output(str(video.url)),
         str(video.id),
         str(video.duration),
     )


-def image_result_to_node_output(
+async def image_result_to_node_output(
     images: list[KlingImageResult],
 ) -> torch.Tensor:
     """
@@ -297,9 +297,9 @@ def image_result_to_node_output(
     If multiple images are returned, they will be stacked along the batch dimension.
     """
     if len(images) == 1:
-        return download_url_to_image_tensor(images[0].url)
+        return await download_url_to_image_tensor(str(images[0].url))
     else:
-        return torch.cat([download_url_to_image_tensor(image.url) for image in images])
+        return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])


 class KlingNodeBase(ComfyNodeABC):
@@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Text to Video Node"

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingText2VideoResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
@@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingCameraControlT2VNode(KlingTextToVideoNode):
@@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):

     DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             model_name=KlingVideoGenModelName.kling_v1,
             cfg_scale=cfg_scale,
             mode=KlingVideoGenMode.std,
@@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Image to Video Node"

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingImage2VideoResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
@@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         prompt: str,
@@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingCameraControlI2VNode(KlingImage2VideoNode):
@@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):

     DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         prompt: str,
@@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             model_name=KlingVideoGenModelName.kling_v1_5,
             start_frame=start_frame,
             cfg_scale=cfg_scale,
@@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):

     DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         end_frame: torch.Tensor,
@@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
         mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
             mode
         ]
-        return super().api_call(
+        return await super().api_call(
             prompt=prompt,
             negative_prompt=negative_prompt,
             model_name=model_name,
@@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVideoExtendResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIDEO_EXTEND}/{task_id}",
@@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingVideoEffectsBase(KlingNodeBase):
@@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
     RETURN_TYPES = ("VIDEO", "STRING", "STRING")
     RETURN_NAMES = ("VIDEO", "video_id", "duration")

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVideoEffectsResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
@@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         dual_character: bool,
         effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
@@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
@@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
     RETURN_TYPES = ("VIDEO", "STRING")
     RETURN_NAMES = ("VIDEO", "duration")

-    def api_call(
+    async def api_call(
         self,
         image_left: torch.Tensor,
         image_right: torch.Tensor,
@@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        video, _, duration = super().api_call(
+        video, _, duration = await super().api_call(
             dual_character=True,
             effect_scene=effect_scene,
             model_name=model_name,
@@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):

     DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         effect_scene: KlingSingleImageEffectsScene,
@@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             dual_character=False,
             effect_scene=effect_scene,
             model_name=model_name,
@@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase):
                 f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
             )

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingLipSyncResponse:
         """Polls the Kling API endpoint until the task reaches a terminal state."""
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_LIP_SYNC}/{task_id}",
@@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         audio: Optional[AudioInput] = None,
@@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase):
         self.validate_lip_sync_video(video)

         # Upload video to Comfy API and get download URL
-        video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
+        video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
         logging.info("Uploaded video to Comfy API. URL: %s", video_url)

         # Upload the audio file to Comfy API and get download URL
         if audio:
-            audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
+            audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
             logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
         else:
             audio_url = None
@@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
@@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):

     DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         audio: AudioInput,
@@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             video=video,
             audio=audio,
             voice_language=voice_language,
@@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):

     DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         text: str,
@@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
         **kwargs,
     ):
         voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
-        return super().api_call(
+        return await super().api_call(
             video=video,
             text=text,
             voice_language=voice_language,
@@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):

     DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVirtualTryOnResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
@@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         human_image: torch.Tensor,
         cloth_image: torch.Tensor,
@@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_image_result_response(final_response)

         images = get_images_from_response(final_response)
-        return (image_result_to_node_output(images),)
+        return (await image_result_to_node_output(images),)


 class KlingImageGenerationNode(KlingImageGenerationBase):
@@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):

     DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."

-    def get_response(
+    async def get_response(
         self,
         task_id: str,
         auth_kwargs: Optional[dict[str, str]],
         node_id: Optional[str] = None,
     ) -> KlingImageGenerationsResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
@@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         model_name: KlingImageGenModelName,
         prompt: str,
@@ -1714,17 +1714,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_image_result_response(final_response)

         images = get_images_from_response(final_response)
-        return (image_result_to_node_output(images),)
+        return (await image_result_to_node_output(images),)


 NODE_CLASS_MAPPINGS = {
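
Every Kling task node above now follows the same awaited flow: create the task, poll to a terminal status, then download the result. Condensed, using names from this diff (a sketch of the shape, not literal code from the commit; error handling elided):

    # Condensed shape of the Kling api_call methods above.
    async def run_kling_task(initial_operation, get_response, auth_kwargs, node_id):
        task_creation_response = await initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

        final_response = await get_response(task_id, auth_kwargs=auth_kwargs, node_id=node_id)
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
        return await video_result_to_node_output(video)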

==== changed file ====

@@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
 )
 from server import PromptServer

-import requests
+import aiohttp
 import torch
 from io import BytesIO
@@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
             },
         }

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         model: str,
@@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
         # handle image_luma_ref
         api_image_ref = None
         if image_luma_ref is not None:
-            api_image_ref = self._convert_luma_refs(
+            api_image_ref = await self._convert_luma_refs(
                 image_luma_ref, max_refs=4, auth_kwargs=kwargs,
             )
         # handle style_luma_ref
         api_style_ref = None
         if style_image is not None:
-            api_style_ref = self._convert_style_image(
+            api_style_ref = await self._convert_style_image(
                 style_image, weight=style_image_weight, auth_kwargs=kwargs,
             )
         # handle character_ref images
         character_ref = None
         if character_image is not None:
-            download_urls = upload_images_to_comfyapi(
+            download_urls = await upload_images_to_comfyapi(
                 character_image, max_images=4, auth_kwargs=kwargs,
             )
             character_ref = LumaCharacterRef(
@@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        response_api: LumaGeneration = operation.execute()
+        response_api: LumaGeneration = await operation.execute()

         operation = PollingOperation(
             poll_endpoint=ApiEndpoint(
@@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
             node_id=unique_id,
             auth_kwargs=kwargs,
         )
-        response_poll = operation.execute()
+        response_poll = await operation.execute()

-        img_response = requests.get(response_poll.assets.image)
-        img = process_image_response(img_response)
+        async with aiohttp.ClientSession() as session:
+            async with session.get(response_poll.assets.image) as img_response:
+                img = process_image_response(await img_response.content.read())

         return (img,)

-    def _convert_luma_refs(
+    async def _convert_luma_refs(
         self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str, str]] = None
     ):
         luma_urls = []
         ref_count = 0
         for ref in luma_ref.refs:
-            download_urls = upload_images_to_comfyapi(
+            download_urls = await upload_images_to_comfyapi(
                 ref.image, max_images=1, auth_kwargs=auth_kwargs
             )
             luma_urls.append(download_urls[0])
@@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
                 break
         return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

-    def _convert_style_image(
+    async def _convert_style_image(
         self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str, str]] = None
     ):
         chain = LumaReferenceChain(
             first_ref=LumaReference(image=style_image, weight=weight)
         )
-        return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
+        return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)


 class LumaImageModifyNode(ComfyNodeABC):
@@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
             },
         }

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         model: str,
@@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
         **kwargs,
     ):
         # first, upload image
-        download_urls = upload_images_to_comfyapi(
+        download_urls = await upload_images_to_comfyapi(
             image, max_images=1, auth_kwargs=kwargs,
         )
         image_url = download_urls[0]
@@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        response_api: LumaGeneration = operation.execute()
+        response_api: LumaGeneration = await operation.execute()

         operation = PollingOperation(
             poll_endpoint=ApiEndpoint(
@@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
             node_id=unique_id,
             auth_kwargs=kwargs,
         )
-        response_poll = operation.execute()
+        response_poll = await operation.execute()

-        img_response = requests.get(response_poll.assets.image)
-        img = process_image_response(img_response)
+        async with aiohttp.ClientSession() as session:
+            async with session.get(response_poll.assets.image) as img_response:
+                img = process_image_response(await img_response.content.read())

         return (img,)
@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
model: str, model: str,
@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = await operation.execute()
if unique_id: if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_T2V_AVERAGE_DURATION, estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = await operation.execute()
vid_response = requests.get(response_poll.assets.video) async with aiohttp.ClientSession() as session:
return (VideoFromFile(BytesIO(vid_response.content)),) async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
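BytesIO(await vid_response.content.read()) buffers the entire video in memory before wrapping it. For very large results, a streaming variant that spools to a temporary file would cap memory use, assuming VideoFromFile also accepts a file path, as its name suggests; a sketch:

import tempfile
import aiohttp

async def download_video_to_tempfile(url: str) -> str:
    # Streams the body in 1 MiB chunks instead of buffering it whole.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
                async for chunk in resp.content.iter_chunked(1 << 20):
                    f.write(chunk)
                return f.name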
class LumaImageToVideoGenerationNode(ComfyNodeABC): class LumaImageToVideoGenerationNode(ComfyNodeABC):
@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
model: str, model: str,
@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
raise Exception( raise Exception(
"At least one of first_image and last_image requires an input." "At least one of first_image and last_image requires an input."
) )
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = await operation.execute()
if unique_id: if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_I2V_AVERAGE_DURATION, estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = await operation.execute()
vid_response = requests.get(response_poll.assets.video) async with aiohttp.ClientSession() as session:
return (VideoFromFile(BytesIO(vid_response.content)),) async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
def _convert_to_keyframes( async def _convert_to_keyframes(
self, self,
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame0 = None frame0 = None
frame1 = None frame1 = None
if first_image is not None: if first_image is not None:
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs, first_image, max_images=1, auth_kwargs=auth_kwargs,
) )
frame0 = LumaImageReference(type="image", url=download_urls[0]) frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None: if last_image is not None:
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs, last_image, max_images=1, auth_kwargs=auth_kwargs,
) )
frame1 = LumaImageReference(type="image", url=download_urls[0]) frame1 = LumaImageReference(type="image", url=download_urls[0])
View File
@ -86,7 +86,7 @@ class MinimaxTextToVideoNode:
API_NODE = True API_NODE = True
OUTPUT_NODE = True OUTPUT_NODE = True
def generate_video( async def generate_video(
self, self,
prompt_text, prompt_text,
seed=0, seed=0,
@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in # upload image, if passed in
image_url = None image_url = None
if image is not None: if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0] image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None subject_reference = None
if subject is not None: if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0] subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)] subject_reference = [SubjectReferenceItem(image=subject_url)]
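The added parentheses are load-bearing: subscription binds tighter than await, so without them the coroutine object itself would be indexed. Spelled out with the same call:

# Equivalent two-step form of the one-liner above:
urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
image_url = urls[0]
# `await upload_images_to_comfyapi(...)[0]` would instead raise
# TypeError: 'coroutine' object is not subscriptable.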
@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response = video_generate_operation.execute() response = await video_generate_operation.execute()
task_id = response.task_id task_id = response.task_id
if not task_id: if not task_id:
@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
node_id=unique_id, node_id=unique_id,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
task_result = video_generate_operation.execute() task_result = await video_generate_operation.execute()
file_id = task_result.file_id file_id = task_result.file_id
if file_id is None: if file_id is None:
@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
request=EmptyRequest(), request=EmptyRequest(),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
file_result = file_retrieve_operation.execute() file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url file_url = file_result.file.download_url
if file_url is None: if file_url is None:
@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
message = f"Result URL: {file_url}" message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id) PromptServer.instance.send_progress_text(message, unique_id)
video_io = download_url_to_bytesio(file_url) video_io = await download_url_to_bytesio(file_url)
if video_io is None: if video_io is None:
error_msg = f"Failed to download video from {file_url}" error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg) logging.error(error_msg)
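Worth spelling out what the awaits buy: each node now yields to the event loop at every network call instead of pinning a thread. A hypothetical driver, not part of the commit, showing two generations interleaving:

import asyncio

async def run_two_jobs(node_a, node_b, prompt: str):
    # Both pipelines suspend at each await, so the loop interleaves
    # their submit/poll/download phases.
    return await asyncio.gather(
        node_a.generate_video(prompt),
        node_b.generate_video(prompt),
    )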
View File
@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]:
return None return None
def poll_until_finished( async def poll_until_finished(
auth_kwargs: dict[str, str], auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R], api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None, result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
) -> R: ) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation( return await PollingOperation(
poll_endpoint=api_endpoint, poll_endpoint=api_endpoint,
completed_statuses=[ completed_statuses=[
"completed", "completed",
@ -394,10 +394,10 @@ class BaseMoonvalleyVideoNode:
else: else:
return control_map["Motion Transfer"] return control_map["Motion Transfer"]
def get_response( async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse: ) -> MoonvalleyPromptResponse:
return poll_until_finished( return await poll_until_finished(
auth_kwargs, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}", path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
@ -507,7 +507,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
RETURN_NAMES = ("video",) RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node" DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate( async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
): ):
image = kwargs.get("image", None) image = kwargs.get("image", None)
@ -532,9 +532,9 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
# Get MIME type from tensor - assuming PNG format for image tensors # Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png" mime_type = "image/png"
image_url = upload_images_to_comfyapi( image_url = (await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)[0] ))[0]
request = MoonvalleyTextToVideoRequest( request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params image_url=image_url, prompt_text=prompt, inference_params=inference_params
@ -549,14 +549,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id task_id = task_creation_response.id
final_response = self.get_response( final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id task_id, auth_kwargs=kwargs, node_id=unique_id
) )
video = download_url_to_video_output(final_response.output_url) video = await download_url_to_video_output(final_response.output_url)
return (video,) return (video,)
@ -609,7 +609,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
RETURN_TYPES = ("VIDEO",) RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",) RETURN_NAMES = ("video",)
def generate( async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
): ):
video = kwargs.get("video") video = kwargs.get("video")
@ -620,7 +620,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
video_url = "" video_url = ""
if video: if video:
validated_video = validate_video_to_video_input(video) validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type") control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity") motion_intensity = kwargs.get("motion_intensity")
@ -658,15 +658,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id task_id = task_creation_response.id
final_response = self.get_response( final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id task_id, auth_kwargs=kwargs, node_id=unique_id
) )
video = download_url_to_video_output(final_response.output_url) video = await download_url_to_video_output(final_response.output_url)
return (video,) return (video,)
@ -688,7 +688,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
del input_types["optional"][param] del input_types["optional"][param]
return input_types return input_types
def generate( async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
): ):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@ -717,15 +717,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id task_id = task_creation_response.id
final_response = self.get_response( final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id task_id, auth_kwargs=kwargs, node_id=unique_id
) )
video = download_url_to_video_output(final_response.output_url) video = await download_url_to_video_output(final_response.output_url)
return (video,) return (video,)
View File
@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "") DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True API_NODE = True
def api_call( async def api_call(
self, self,
prompt, prompt,
seed=0, seed=0,
@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response = operation.execute() response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id) img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)
@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "") DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True API_NODE = True
def api_call( async def api_call(
self, self,
prompt, prompt,
seed=0, seed=0,
@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response = operation.execute() response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id) img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)
@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "") DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True API_NODE = True
def api_call( async def api_call(
self, self,
prompt, prompt,
seed=0, seed=0,
@ -537,9 +537,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response = operation.execute() response = await operation.execute()
img_tensor = validate_and_cast_response(response, node_id=unique_id) img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)
@ -623,7 +623,7 @@ class OpenAIChatNode(OpenAITextNode):
DESCRIPTION = "Generate text responses from an OpenAI model." DESCRIPTION = "Generate text responses from an OpenAI model."
def get_result_response( async def get_result_response(
self, self,
response_id: str, response_id: str,
include: Optional[list[Includable]] = None, include: Optional[list[Includable]] = None,
@ -639,7 +639,7 @@ class OpenAIChatNode(OpenAITextNode):
creation above for more information. creation above for more information.
""" """
return PollingOperation( return await PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}", path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
@ -784,7 +784,7 @@ class OpenAIChatNode(OpenAITextNode):
self.history[session_id] = new_history self.history[session_id] = new_history
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
persist_context: bool, persist_context: bool,
@ -815,7 +815,7 @@ class OpenAIChatNode(OpenAITextNode):
previous_response_id = None previous_response_id = None
# Create response # Create response
create_response = SynchronousOperation( create_response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT, path=RESPONSES_ENDPOINT,
method=HttpMethod.POST, method=HttpMethod.POST,
@ -848,7 +848,7 @@ class OpenAIChatNode(OpenAITextNode):
response_id = create_response.id response_id = create_response.id
# Get result output # Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs) result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response) output_text = self.parse_output_text_from_response(result_response)
# Update history # Update history
View File
@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
FUNCTION = "api_call" FUNCTION = "api_call"
RETURN_TYPES = ("VIDEO",) RETURN_TYPES = ("VIDEO",)
def poll_for_task_status( async def poll_for_task_status(
self, self,
task_id: str, task_id: str,
auth_kwargs: Optional[dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
node_id=node_id, node_id=node_id,
estimated_duration=60 estimated_duration=60
) )
return polling_operation.execute() return await polling_operation.execute()
def execute_task( async def execute_task(
self, self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse], initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
Returns: Returns:
A tuple containing the video file as a VIDEO output. A tuple containing the video file as a VIDEO output.
""" """
initial_response = initial_operation.execute() initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response): if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
logging.error(error_msg) logging.error(error_msg)
raise PikaApiError(error_msg) raise PikaApiError(error_msg)
task_id = initial_response.video_id task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_kwargs) final_response = await self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response): if not is_valid_video_response(final_response):
error_msg = ( error_msg = (
f"Pika task {task_id} succeeded but no video data found in response." f"Pika task {task_id} succeeded but no video data found in response."
@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
video_url = str(final_response.url) video_url = str(final_response.url)
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return (download_url_to_video_output(video_url),) return (await download_url_to_video_output(video_url),)
class PikaImageToVideoV2_2(PikaNodeBase): class PikaImageToVideoV2_2(PikaNodeBase):
@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
prompt_text: str, prompt_text: str,
@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaTextToVideoNodeV2_2(PikaNodeBase): class PikaTextToVideoNodeV2_2(PikaNodeBase):
@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
def api_call( async def api_call(
self, self,
prompt_text: str, prompt_text: str,
negative_prompt: str, negative_prompt: str,
@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
content_type="application/x-www-form-urlencoded", content_type="application/x-www-form-urlencoded",
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaScenesV2_2(PikaNodeBase): class PikaScenesV2_2(PikaNodeBase):
@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
def api_call( async def api_call(
self, self,
prompt_text: str, prompt_text: str,
negative_prompt: str, negative_prompt: str,
@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikAdditionsNode(PikaNodeBase): class PikAdditionsNode(PikaNodeBase):
@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
def api_call( async def api_call(
self, self,
video: VideoInput, video: VideoInput,
image: torch.Tensor, image: torch.Tensor,
@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image) image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0) image_bytes_io.seek(0)
pika_files = [ pika_files = {
("video", ("video.mp4", video_bytes_io, "video/mp4")), "video": ("video.mp4", video_bytes_io, "video/mp4"),
("image", ("image.png", image_bytes_io, "image/png")), "image": ("image.png", image_bytes_io, "image/png"),
] }
# Prepare non-file data # Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
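The pika_files container above changes from a list of tuples to a dict, presumably because that is the shape the new aiohttp-based client consumes for multipart uploads. A sketch of how such a dict maps onto aiohttp's FormData (helper name illustrative):

import io
import aiohttp

def build_form(files: dict[str, tuple[str, io.BytesIO, str]]) -> aiohttp.FormData:
    # Each entry is (filename, file object, content type), mirroring the
    # shape requests accepted for its files= argument.
    form = aiohttp.FormData()
    for field, (filename, fileobj, content_type) in files.items():
        form.add_field(field, fileobj, filename=filename, content_type=content_type)
    return form

A dict cannot carry duplicate field names, which is presumably why PikaStartEndFrameNode2_2 below keeps the list-of-tuples form for its two "keyFrames" entries.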
class PikaSwapsNode(PikaNodeBase): class PikaSwapsNode(PikaNodeBase):
@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
RETURN_TYPES = ("VIDEO",) RETURN_TYPES = ("VIDEO",)
def api_call( async def api_call(
self, self,
video: VideoInput, video: VideoInput,
image: torch.Tensor, image: torch.Tensor,
@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image) image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0) image_bytes_io.seek(0)
pika_files = [ pika_files = {
("video", ("video.mp4", video_bytes_io, "video/mp4")), "video": ("video.mp4", video_bytes_io, "video/mp4"),
("image", ("image.png", image_bytes_io, "image/png")), "image": ("image.png", image_bytes_io, "image/png"),
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
] }
# Prepare non-file data # Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaffectsNode(PikaNodeBase): class PikaffectsNode(PikaNodeBase):
@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
pikaffect: str, pikaffect: str,
@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaStartEndFrameNode2_2(PikaNodeBase): class PikaStartEndFrameNode2_2(PikaNodeBase):
@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
def api_call( async def api_call(
self, self,
image_start: torch.Tensor, image_start: torch.Tensor,
image_end: torch.Tensor, image_end: torch.Tensor,
@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
pika_files = [ pika_files = [
( ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
"keyFrames",
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
] ]
@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
View File
@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile from comfy_api.input_impl import VideoFromFile
import torch import torch
import requests import aiohttp
from io import BytesIO from io import BytesIO
@ -47,7 +47,7 @@ def get_video_url_from_response(
return str(response.Resp.url) return str(response.Resp.url)
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call # first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)} files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation( operation = SynchronousOperation(
@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth_kwargs, auth_kwargs=auth_kwargs,
) )
response_upload: PixverseImageUploadResponse = operation.execute() response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None: if response_upload.Resp is None:
raise Exception( raise Exception(
@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
aspect_ratio: str, aspect_ratio: str,
@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = operation.execute() response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url) async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(vid_response.content)),) return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseImageToVideoNode(ComfyNodeABC): class PixverseImageToVideoNode(ComfyNodeABC):
@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
prompt: str, prompt: str,
@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs) img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
) )
response_poll = operation.execute() response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url) async with aiohttp.ClientSession() as session:
return (VideoFromFile(BytesIO(vid_response.content)),) async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
class PixverseTransitionVideoNode(ComfyNodeABC): class PixverseTransitionVideoNode(ComfyNodeABC):
@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
}, },
} }
def api_call( async def api_call(
self, self,
first_frame: torch.Tensor, first_frame: torch.Tensor,
last_frame: torch.Tensor, last_frame: torch.Tensor,
@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = operation.execute() response_poll = await operation.execute()
vid_response = requests.get(response_poll.Resp.url) async with aiohttp.ClientSession() as session:
return (VideoFromFile(BytesIO(vid_response.content)),) async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
View File
@ -37,7 +37,7 @@ from io import BytesIO
from PIL import UnidentifiedImageError from PIL import UnidentifiedImageError
def handle_recraft_file_request( async def handle_recraft_file_request(
image: torch.Tensor, image: torch.Tensor,
path: str, path: str,
mask: torch.Tensor=None, mask: torch.Tensor=None,
@ -71,13 +71,13 @@ def handle_recraft_file_request(
auth_kwargs=auth_kwargs, auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser, multipart_parser=recraft_multipart_parser,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = [] all_bytesio = []
if response.image is not None: if response.image is not None:
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout)) all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else: else:
for data in response.data: for data in response.data:
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout)) all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio return all_bytesio
@ -395,7 +395,7 @@ class RecraftTextToImageNode:
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
size: str, size: str,
@ -439,7 +439,7 @@ class RecraftTextToImageNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = await operation.execute()
images = [] images = []
urls = [] urls = []
for data in response.data: for data in response.data:
@ -451,7 +451,7 @@ class RecraftTextToImageNode:
f"Result URL: {urls_string}", unique_id f"Result URL: {urls_string}", unique_id
) )
image = bytesio_to_image_tensor( image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024) await download_url_to_bytesio(data.url, timeout=1024)
) )
if len(image.shape) < 4: if len(image.shape) < 4:
image = image.unsqueeze(0) image = image.unsqueeze(0)
@ -538,7 +538,7 @@ class RecraftImageToImageNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
prompt: str, prompt: str,
@ -578,7 +578,7 @@ class RecraftImageToImageNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/imageToImage", path="/proxy/recraft/images/imageToImage",
request=request, request=request,
@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
mask: torch.Tensor, mask: torch.Tensor,
@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
mask=mask[i:i+1], mask=mask[i:i+1],
path="/proxy/recraft/images/inpaint", path="/proxy/recraft/images/inpaint",
@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
substyle: str, substyle: str,
@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = await operation.execute()
svg_data = [] svg_data = []
urls = [] urls = []
for data in response.data: for data in response.data:
@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
PromptServer.instance.send_progress_text( PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id f"Result URL: {' '.join(urls)}", unique_id
) )
svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))
return (SVG(svg_data),) return (SVG(svg_data),)
@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
**kwargs, **kwargs,
@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/vectorize", path="/proxy/recraft/images/vectorize",
auth_kwargs=kwargs, auth_kwargs=kwargs,
@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
prompt: str, prompt: str,
@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/replaceBackground", path="/proxy/recraft/images/replaceBackground",
request=request, request=request,
@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
**kwargs, **kwargs,
@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/removeBackground", path="/proxy/recraft/images/removeBackground",
auth_kwargs=kwargs, auth_kwargs=kwargs,
@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
}, },
} }
def api_call( async def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
**kwargs, **kwargs,
@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
total = image.shape[0] total = image.shape[0]
pbar = ProgressBar(total) pbar = ProgressBar(total)
for i in range(total): for i in range(total):
sub_bytes = handle_recraft_file_request( sub_bytes = await handle_recraft_file_request(
image=image[i], image=image[i],
path=self.RECRAFT_PATH, path=self.RECRAFT_PATH,
auth_kwargs=kwargs, auth_kwargs=kwargs,
View File
@ -9,11 +9,10 @@ from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths import folder_paths as comfy_paths
import requests import aiohttp
import os import os
import datetime import datetime
import shutil import asyncio
import time
import io import io
import logging import logging
import math import math
@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
return hasattr(response, "error") return hasattr(response, "error")
class Rodin3DAPI: class Rodin3DAPI:
""" """
Generate 3D Assets using Rodin API Generate 3D Assets using Rodin API
@ -123,8 +121,8 @@ class Rodin3DAPI:
else: else:
return "Generating" return "Generating"
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images == None: if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.") raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5: if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.") raise Exception("Rodin 3D generate requires up to 5 image.")
@ -155,7 +153,7 @@ class Rodin3DAPI:
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response = operation.execute() response = await operation.execute()
if create_task_error(response): if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
@ -168,7 +166,7 @@ class Rodin3DAPI:
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key return task_uuid, subscription_key
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status" path = "/proxy/rodin/api/v2/status"
@ -191,11 +189,9 @@ class Rodin3DAPI:
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return poll_operation.execute() return await poll_operation.execute()
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download" path = "/proxy/rodin/api/v2/download"
@ -212,53 +208,59 @@ class Rodin3DAPI:
auth_kwargs=kwargs auth_kwargs=kwargs
) )
return operation.execute() return await operation.execute()
def GetQualityAndMode(self, PolyCount): def get_quality_mode(self, poly_count):
if PolyCount == "200K-Triangle": if poly_count == "200K-Triangle":
mesh_mode = "Raw" mesh_mode = "Raw"
quality = "medium" quality = "medium"
else: else:
mesh_mode = "Quad" mesh_mode = "Quad"
if PolyCount == "4K-Quad": if poly_count == "4K-Quad":
quality = "extra-low" quality = "extra-low"
elif PolyCount == "8K-Quad": elif poly_count == "8K-Quad":
quality = "low" quality = "low"
elif PolyCount == "18K-Quad": elif poly_count == "18K-Quad":
quality = "medium" quality = "medium"
elif PolyCount == "50K-Quad": elif poly_count == "50K-Quad":
quality = "high" quality = "high"
else: else:
quality = "medium" quality = "medium"
return mesh_mode, quality return mesh_mode, quality
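The renamed get_quality_mode keeps the if/elif ladder; the same mapping also reads naturally as a lookup table. An equivalent sketch:

_QUAD_QUALITY = {
    "4K-Quad": "extra-low",
    "8K-Quad": "low",
    "18K-Quad": "medium",
    "50K-Quad": "high",
}

def get_quality_mode(poly_count: str) -> tuple[str, str]:
    # "200K-Triangle" is the only raw-mesh option; everything else is Quad.
    if poly_count == "200K-Triangle":
        return "Raw", "medium"
    return "Quad", _QUAD_QUALITY.get(poly_count, "medium")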
def DownLoadFiles(self, Url_List): async def download_files(self, url_list):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
model_file_path = None model_file_path = None
for Item in Url_List.list: async with aiohttp.ClientSession() as session:
url = Item.url for i in url_list.list:
file_name = Item.name url = i.url
file_path = os.path.join(Save_path, file_name) file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"): if file_path.endswith(".glb"):
model_file_path = file_path model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5 max_retries = 5
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
with requests.get(url, stream=True) as r: async with session.get(url) as resp:
r.raise_for_status() resp.raise_for_status()
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f) async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break break
except Exception as e: except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1: if attempt < max_retries - 1:
logging.info("Retrying...") logging.info("Retrying...")
time.sleep(2) await asyncio.sleep(2)
else: else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path return model_file_path
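download_files now iterates the body in 32 KiB chunks, but f.write itself is still synchronous file I/O and briefly blocks the event loop on every chunk. If that ever matters, aiofiles (a third-party package, not a dependency of this commit) keeps the disk writes async as well; a sketch:

import aiofiles  # assumption: available in the environment

async def save_response_body(resp, file_path: str) -> None:
    # resp is an open aiohttp response; writes run in aiofiles' thread pool.
    async with aiofiles.open(file_path, "wb") as f:
        async for chunk in resp.content.iter_chunked(32 * 1024):
            await f.write(chunk)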
@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
}, },
} }
def api_call( async def api_call(
self, self,
Images, Images,
Seed, Seed,
@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count) mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
self.poll_for_task_status(subscription_key, **kwargs) quality=quality, tier=tier, mesh_mode=mesh_mode,
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) **kwargs)
model = self.DownLoadFiles(Download_List) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,) return (model,)
class Rodin3D_Detail(Rodin3DAPI): class Rodin3D_Detail(Rodin3DAPI):
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
}, },
} }
def api_call( async def api_call(
self, self,
Images, Images,
Seed, Seed,
@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count) mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
self.poll_for_task_status(subscription_key, **kwargs) quality=quality, tier=tier, mesh_mode=mesh_mode,
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) **kwargs)
model = self.DownLoadFiles(Download_List) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,) return (model,)
class Rodin3D_Smooth(Rodin3DAPI): class Rodin3D_Smooth(Rodin3DAPI):
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
}, },
} }
def api_call( async def api_call(
self, self,
Images, Images,
Seed, Seed,
@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count) mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
self.poll_for_task_status(subscription_key, **kwargs) quality=quality, tier=tier, mesh_mode=mesh_mode,
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) **kwargs)
model = self.DownLoadFiles(Download_List) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,) return (model,)
class Rodin3D_Sketch(Rodin3DAPI): class Rodin3D_Sketch(Rodin3DAPI):
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
}, },
} }
def api_call( async def api_call(
self, self,
Images, Images,
Seed, Seed,
@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
material_type = "PBR" material_type = "PBR"
quality = "medium" quality = "medium"
mesh_mode = "Quad" mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) task_uuid, subscription_key = await self.create_generate_task(
self.poll_for_task_status(subscription_key, **kwargs) images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) )
model = self.DownLoadFiles(Download_List) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,) return (model,)
View File
@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool:
return image.shape[2] < 8000 and image.shape[1] < 8000 return image.shape[2] < 8000 and image.shape[1] < 8000
def poll_until_finished( async def poll_until_finished(
auth_kwargs: dict[str, str], auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse], api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None, estimated_duration: Optional[int] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
) -> TaskStatusResponse: ) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation( return await PollingOperation(
poll_endpoint=api_endpoint, poll_endpoint=api_endpoint,
completed_statuses=[ completed_statuses=[
TaskStatus.SUCCEEDED.value, TaskStatus.SUCCEEDED.value,
@ -115,7 +115,7 @@ def poll_until_finished(
TaskStatus.FAILED.value, TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value, TaskStatus.CANCELLED.value,
], ],
status_extractor=lambda response: (response.status.value), status_extractor=lambda response: response.status.value,
auth_kwargs=auth_kwargs, auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status, result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC):
) )
return True return True
def get_response( async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse: ) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response.""" """Poll the task status until it is finished then get the response."""
return poll_until_finished( return await poll_until_finished(
auth_kwargs, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}", path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC):
node_id=node_id, node_id=node_id,
) )
def generate_video( async def generate_video(
self, self,
request: RunwayImageToVideoRequest, request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str], auth_kwargs: dict[str, str],
@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC):
auth_kwargs=auth_kwargs, auth_kwargs=auth_kwargs,
) )
initial_response = initial_operation.execute() initial_response = await initial_operation.execute()
self.validate_task_created(initial_response) self.validate_task_created(initial_response)
task_id = initial_response.id task_id = initial_response.id
final_response = self.get_response(task_id, auth_kwargs, node_id) final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response) self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response) video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),) return (await download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: torch.Tensor,
@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
validate_input_image(start_frame) validate_input_image(start_frame)
# Upload image # Upload image
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
start_frame, start_frame,
max_images=1, max_images=1,
mime_type="image/png", mime_type="image/png",
@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
if len(download_urls) != 1: if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.") raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video( return await self.generate_video(
RunwayImageToVideoRequest( RunwayImageToVideoRequest(
promptText=prompt, promptText=prompt,
seed=seed, seed=seed,
@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: torch.Tensor,
@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
validate_input_image(start_frame) validate_input_image(start_frame)
# Upload image # Upload image
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
start_frame, start_frame,
max_images=1, max_images=1,
mime_type="image/png", mime_type="image/png",
@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
if len(download_urls) != 1: if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.") raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video( return await self.generate_video(
RunwayImageToVideoRequest( RunwayImageToVideoRequest(
promptText=prompt, promptText=prompt,
seed=seed, seed=seed,
@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
def get_response( async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse: ) -> RunwayImageToVideoResponse:
return poll_until_finished( return await poll_until_finished(
auth_kwargs, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}", path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
}, },
} }
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
start_frame: torch.Tensor, start_frame: torch.Tensor,
@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
# Upload images # Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
stacked_input_images, stacked_input_images,
max_images=2, max_images=2,
mime_type="image/png", mime_type="image/png",
@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
if len(download_urls) != 2: if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.") raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video( return await self.generate_video(
RunwayImageToVideoRequest( RunwayImageToVideoRequest(
promptText=prompt, promptText=prompt,
seed=seed, seed=seed,
@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC):
) )
return True return True
def get_response( async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse: ) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response.""" """Poll the task status until it is finished then get the response."""
return poll_until_finished( return await poll_until_finished(
auth_kwargs, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}", path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
node_id=node_id, node_id=node_id,
) )
def api_call( async def api_call(
self, self,
prompt: str, prompt: str,
ratio: str, ratio: str,
@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
reference_images = None reference_images = None
if reference_image is not None: if reference_image is not None:
validate_input_image(reference_image) validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
reference_image, reference_image,
max_images=1, max_images=1,
mime_type="image/png", mime_type="image/png",
@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
initial_response = initial_operation.execute() initial_response = await initial_operation.execute()
self.validate_task_created(initial_response) self.validate_task_created(initial_response)
task_id = initial_response.id task_id = initial_response.id
# Poll for completion # Poll for completion
final_response = self.get_response( final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id task_id, auth_kwargs=kwargs, node_id=unique_id
) )
self.validate_response(final_response) self.validate_response(final_response)
# Download and return image # Download and return image
image_url = get_image_url_from_task_status(final_response) image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),) return (await download_url_to_image_tensor(image_url),)
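Every Runway node above now follows the same awaited submit -> poll -> download shape. A runnable toy of that control flow, with asyncio.sleep standing in for the HTTP round trips (none of these stub names exist in the codebase):

import asyncio

async def submit(prompt: str) -> str:
    await asyncio.sleep(0.1)                  # stands in for the POST
    return "task-123"

async def poll(task_id: str) -> str:
    for _ in range(3):                        # stands in for repeated status GETs
        await asyncio.sleep(0.1)
    return "https://example.com/out.png"

async def api_call(prompt: str) -> str:
    task_id = await submit(prompt)
    return await poll(task_id)

print(asyncio.run(api_call("a red fox")))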
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -124,7 +124,7 @@ class StabilityStableImageUltraNode:
}, },
} }
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -163,7 +163,7 @@ class StabilityStableImageUltraNode:
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node:
}, },
} }
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node:
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode:
}, },
} }
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
**kwargs): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode:
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode:
}, },
} }
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
**kwargs): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode:
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
operation = PollingOperation( operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode:
status_extractor=lambda x: get_async_dummy_status(x), status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_poll: StabilityResultsGetResponse = operation.execute() response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS": if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@ -563,8 +563,7 @@ class StabilityUpscaleFastNode:
}, },
} }
def api_call(self, image: torch.Tensor, async def api_call(self, image: torch.Tensor, **kwargs):
**kwargs):
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = { files = {
@ -583,7 +582,7 @@ class StabilityUpscaleFastNode:
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")

View File

@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
) )
def upload_image_to_tripo(image, **kwargs): async def upload_image_to_tripo(image, **kwargs):
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
def get_model_url_from_response(response: TripoTaskResponse) -> str: def get_model_url_from_response(response: TripoTaskResponse) -> str:
@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
raise RuntimeError(f"Failed to get model url from response: {response}") raise RuntimeError(f"Failed to get model url from response: {response}")
def poll_until_finished( async def poll_until_finished(
kwargs: dict[str, str], kwargs: dict[str, str],
response: TripoTaskResponse, response: TripoTaskResponse,
) -> tuple[str, str]: ) -> tuple[str, str]:
@ -57,7 +57,7 @@ def poll_until_finished(
if response.code != 0: if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}") raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id task_id = response.data.task_id
response_poll = PollingOperation( response_poll = await PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
path=f"/proxy/tripo/v2/openapi/task/{task_id}", path=f"/proxy/tripo/v2/openapi/task/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
@ -80,7 +80,7 @@ def poll_until_finished(
).execute() ).execute()
if response_poll.data.status == TripoTaskStatus.SUCCESS: if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll) url = get_model_url_from_response(response_poll)
bytesio = download_url_to_bytesio(url) bytesio = await download_url_to_bytesio(url)
# Save the downloaded model file # Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb" model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f: with open(os.path.join(get_output_directory(), model_file), "wb") as f:
@ -88,6 +88,7 @@ def poll_until_finished(
return model_file, task_id return model_file, task_id
raise RuntimeError(f"Failed to generate mesh: {response_poll}") raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode: class TripoTextToModelNode:
""" """
Generates 3D models synchronously based on a text prompt using Tripo's API. Generates 3D models synchronously based on a text prompt using Tripo's API.
@ -126,11 +127,11 @@ class TripoTextToModelNode:
API_NODE = True API_NODE = True
OUTPUT_NODE = True OUTPUT_NODE = True
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style style_enum = None if style == "None" else style
if not prompt: if not prompt:
raise RuntimeError("Prompt is required") raise RuntimeError("Prompt is required")
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -155,7 +156,8 @@ class TripoTextToModelNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoImageToModelNode: class TripoImageToModelNode:
""" """
@ -195,12 +197,12 @@ class TripoImageToModelNode:
API_NODE = True API_NODE = True
OUTPUT_NODE = True OUTPUT_NODE = True
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style style_enum = None if style == "None" else style
if image is None: if image is None:
raise RuntimeError("Image is required") raise RuntimeError("Image is required")
tripo_file = upload_image_to_tripo(image, **kwargs) tripo_file = await upload_image_to_tripo(image, **kwargs)
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -225,7 +227,8 @@ class TripoImageToModelNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoMultiviewToModelNode: class TripoMultiviewToModelNode:
""" """
@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
API_NODE = True API_NODE = True
OUTPUT_NODE = True OUTPUT_NODE = True
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
if image is None: if image is None:
raise RuntimeError("front image for multiview is required") raise RuntimeError("front image for multiview is required")
images = [] images = []
@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
for image_name in ["image", "image_left", "image_back", "image_right"]: for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name] image_ = image_dict[image_name]
if image_ is not None: if image_ is not None:
tripo_file = upload_image_to_tripo(image_, **kwargs) tripo_file = await upload_image_to_tripo(image_, **kwargs)
images.append(tripo_file) images.append(tripo_file)
else: else:
images.append(TripoFileEmptyReference()) images.append(TripoFileEmptyReference())
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoTextureNode: class TripoTextureNode:
@classmethod @classmethod
@ -340,8 +344,8 @@ class TripoTextureNode:
OUTPUT_NODE = True OUTPUT_NODE = True
AVERAGE_DURATION = 80 AVERAGE_DURATION = 80
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -358,7 +362,7 @@ class TripoTextureNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoRefineNode: class TripoRefineNode:
@ -387,8 +391,8 @@ class TripoRefineNode:
OUTPUT_NODE = True OUTPUT_NODE = True
AVERAGE_DURATION = 240 AVERAGE_DURATION = 240
def generate_mesh(self, model_task_id, **kwargs): async def generate_mesh(self, model_task_id, **kwargs):
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -400,7 +404,7 @@ class TripoRefineNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoRigNode: class TripoRigNode:
@ -425,8 +429,8 @@ class TripoRigNode:
OUTPUT_NODE = True OUTPUT_NODE = True
AVERAGE_DURATION = 180 AVERAGE_DURATION = 180
def generate_mesh(self, original_model_task_id, **kwargs): async def generate_mesh(self, original_model_task_id, **kwargs):
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -440,7 +444,8 @@ class TripoRigNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoRetargetNode: class TripoRetargetNode:
@classmethod @classmethod
@ -475,8 +480,8 @@ class TripoRetargetNode:
OUTPUT_NODE = True OUTPUT_NODE = True
AVERAGE_DURATION = 30 AVERAGE_DURATION = 30
def generate_mesh(self, animation, original_model_task_id, **kwargs): async def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -491,7 +496,8 @@ class TripoRetargetNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
class TripoConversionNode: class TripoConversionNode:
@classmethod @classmethod
@ -529,10 +535,10 @@ class TripoConversionNode:
OUTPUT_NODE = True OUTPUT_NODE = True
AVERAGE_DURATION = 30 AVERAGE_DURATION = 30
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
if not original_model_task_id: if not original_model_task_id:
raise RuntimeError("original_model_task_id is required") raise RuntimeError("original_model_task_id is required")
response = SynchronousOperation( response = await SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task", path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST, method=HttpMethod.POST,
@ -549,7 +555,8 @@ class TripoConversionNode:
), ),
auth_kwargs=kwargs, auth_kwargs=kwargs,
).execute() ).execute()
return poll_until_finished(kwargs, response) return await poll_until_finished(kwargs, response)
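With every generate_mesh above converted to a coroutine, independent Tripo tasks no longer serialize on blocking I/O. A toy demonstration of the payoff, with sleeps standing in for API latency:

import asyncio

async def fake_task(name: str, seconds: float) -> str:
    await asyncio.sleep(seconds)              # stands in for an awaited API call
    return name

async def main() -> None:
    # Finishes in ~2s rather than ~6s, because the waits overlap.
    print(await asyncio.gather(
        fake_task("text_to_model", 2),
        fake_task("texture", 2),
        fake_task("rig", 2),
    ))

asyncio.run(main())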
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TripoTextToModelNode": TripoTextToModelNode, "TripoTextToModelNode": TripoTextToModelNode,

View File

@ -1,7 +1,7 @@
import io import io
import logging import logging
import base64 import base64
import requests import aiohttp
import torch import torch
from typing import Optional from typing import Optional
@ -152,7 +152,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
API_NODE = True API_NODE = True
def generate_video( async def generate_video(
self, self,
prompt, prompt,
aspect_ratio="16:9", aspect_ratio="16:9",
@ -217,7 +217,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
auth_kwargs=kwargs, auth_kwargs=kwargs,
) )
initial_response = initial_operation.execute() initial_response = await initial_operation.execute()
operation_name = initial_response.name operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}") logging.info(f"Veo generation started with operation name: {operation_name}")
@ -256,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
) )
# Execute the polling operation # Execute the polling operation
poll_response = poll_operation.execute() poll_response = await poll_operation.execute()
# Now check for errors in the final response # Now check for errors in the final response
# Check for error in poll response # Check for error in poll response
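One subtlety in the download hunk below: requests exposes the response body as bytes via .content, while aiohttp's .content is a StreamReader, so the rewrite adds an awaited .read(). Side by side, assuming nothing beyond the two libraries:

import aiohttp
import requests

def fetch_sync(url: str) -> bytes:
    return requests.get(url).content          # blocks the event loop

async def fetch_async(url: str) -> bytes:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return await resp.content.read()  # or simply: await resp.read()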
@ -281,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
raise Exception(error_message) raise Exception(error_message)
# Extract video data # Extract video data
video_data = None
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
video = poll_response.response.videos[0] video = poll_response.response.videos[0]
@ -291,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
video_data = base64.b64decode(video.bytesBase64Encoded) video_data = base64.b64decode(video.bytesBase64Encoded)
elif hasattr(video, 'gcsUri') and video.gcsUri: elif hasattr(video, 'gcsUri') and video.gcsUri:
# Download from URL # Download from URL
video_url = video.gcsUri async with aiohttp.ClientSession() as session:
video_response = requests.get(video_url) async with session.get(video.gcsUri) as video_response:
video_data = video_response.content video_data = await video_response.content.read()
else: else:
raise Exception("Video returned but no data or URL was provided") raise Exception("Video returned but no data or URL was provided")
else: else: