Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-13 08:36:40 +00:00)

commit 6ea69369ce
Merge branch 'master' into worksplit-multigpu

.gitattributes (vendored)
@@ -1,2 +1,3 @@
 /web/assets/** linguist-generated
 /web/** linguist-vendored
+comfy_api_nodes/apis/__init__.py linguist-generated
@@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
 - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
+- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -8,7 +8,7 @@ from einops import repeat
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
-
+import comfy.ldm.common_dit

 class GELU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -364,8 +364,9 @@ class QwenImageTransformer2DModel(nn.Module):

         image_rotary_emb = self.pos_embeds(x, context)

-        orig_shape = x.shape
-        hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
         hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
         hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)

@@ -396,4 +397,4 @@ class QwenImageTransformer2DModel(nn.Module):

         hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
-        return hidden_states.reshape(orig_shape)
+        return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
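
The hunks above change QwenImage patchification: the latent is padded up to a multiple of the patch size before the 2x2 packing, and the final reshape is cropped back to the input's spatial extent, so resolutions that are not patch-aligned survive the round trip. A minimal standalone sketch of the same packing (dummy tensor, not the real model):

import torch

def patchify(x: torch.Tensor):
    # x: [B, C, H, W], with H and W already padded to multiples of 2
    b, c, h, w = x.shape
    s = x.view(b, c, h // 2, 2, w // 2, 2)
    s = s.permute(0, 2, 4, 1, 3, 5)  # [B, H/2, W/2, C, 2, 2]
    return s.reshape(b, (h // 2) * (w // 2), c * 4), x.shape

def unpatchify(tokens: torch.Tensor, orig_shape) -> torch.Tensor:
    b, c, h, w = orig_shape
    s = tokens.view(b, h // 2, w // 2, c, 2, 2)
    return s.permute(0, 3, 1, 4, 2, 5).reshape(orig_shape)

x = torch.randn(1, 16, 6, 8)
tokens, shape = patchify(x)
assert torch.equal(unpatchify(tokens, shape), x)  # lossless round trip
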
@@ -293,6 +293,15 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["{}".format(key_lora)] = k

+    if isinstance(model, comfy.model_base.QwenImage):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                # Direct mapping for transformer_blocks format (QwenImage LoRA format)
+                key_map["{}".format(key_lora)] = k
+                # Support transformer prefix format
+                key_map["transformer.{}".format(key_lora)] = k
+
     return key_map
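
The effect of the new mapping is that a QwenImage LoRA key resolves whether or not it carries the `transformer.` prefix; a small illustration with one hypothetical key:

key_map = {}
k = "diffusion_model.transformer_blocks.0.attn.to_q.weight"
key_lora = k[len("diffusion_model."):-len(".weight")]  # "transformer_blocks.0.attn.to_q"
key_map[key_lora] = k                                  # bare QwenImage LoRA format
key_map["transformer.{}".format(key_lora)] = k         # diffusers-style prefix

assert key_map["transformer_blocks.0.attn.to_q"] == k
assert key_map["transformer.transformer_blocks.0.attn.to_q"] == k
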
@@ -1237,7 +1237,7 @@ class QwenImage(supported_models_base.BASE):

     sampling_settings = {
         "multiplier": 1.0,
-        "shift": 2.6,
+        "shift": 1.15,
     }

     memory_usage_factor = 1.8 #TODO
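
Lowering `shift` from 2.6 to 1.15 moves sampling effort toward lower-noise timesteps. For intuition, flow models commonly warp the normalized timestep with the transform below (a sketch of the usual shift formula, not necessarily the exact code path ComfyUI takes for this model):

def time_shift(t: float, shift: float) -> float:
    # Warp a normalized timestep t in [0, 1]; larger shift keeps t high (noisier) longer.
    return shift * t / (1 + (shift - 1) * t)

print(time_shift(0.5, 2.6))   # ~0.72
print(time_shift(0.5, 1.15))  # ~0.53, much closer to the unshifted schedule
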
@@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase):
         diffusers3_lora = "{}.lora.up.weight".format(x)
         mochi_lora = "{}.lora_B".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+        qwen_default_lora = "{}.lora_B.default.weight".format(x)
         A_name = None

         if regular_lora in lora.keys():
@@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase):
             A_name = transformers_lora
             B_name = "{}.lora_linear_layer.down.weight".format(x)
             mid_name = None
+        elif qwen_default_lora in lora.keys():
+            A_name = qwen_default_lora
+            B_name = "{}.lora_A.default.weight".format(x)
+            mid_name = None

         if A_name is not None:
             mid = None
@@ -9,7 +9,11 @@ from typing import Type
 import av
 import numpy as np
 import torch
-import torchaudio
+try:
+    import torchaudio
+    TORCH_AUDIO_AVAILABLE = True
+except ImportError:
+    TORCH_AUDIO_AVAILABLE = False
 from PIL import Image as PILImage
 from PIL.PngImagePlugin import PngInfo

@@ -302,6 +306,8 @@ class AudioSaveHelper:

         # Resample if necessary
         if sample_rate != audio["sample_rate"]:
+            if not TORCH_AUDIO_AVAILABLE:
+                raise Exception("torchaudio is not available; cannot resample audio.")
             waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)

         # Create output with specified format
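
The guard defers the failure from import time to first use: ComfyUI still loads without torchaudio, and only the resample path raises. The same pattern in isolation (helper name hypothetical):

try:
    import torchaudio
    TORCH_AUDIO_AVAILABLE = True
except ImportError:
    TORCH_AUDIO_AVAILABLE = False

def resample_or_fail(waveform, src_rate: int, dst_rate: int):
    # Hypothetical helper: only touches the optional dependency when actually needed.
    if src_rate == dst_rate:
        return waveform
    if not TORCH_AUDIO_AVAILABLE:
        raise Exception("torchaudio is not available; cannot resample audio.")
    return torchaudio.functional.resample(waveform, src_rate, dst_rate)
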
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import aiohttp
 import io
 import logging
 import mimetypes
@@ -21,7 +22,6 @@ from server import PromptServer

 import numpy as np
 from PIL import Image
-import requests
 import torch
 import math
 import base64
@@ -30,7 +30,7 @@ from io import BytesIO
 import av


-def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
+async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output.

     Args:
@@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
     Returns:
         A Comfy node `VIDEO` output.
     """
-    video_io = download_url_to_bytesio(video_url, timeout)
+    video_io = await download_url_to_bytesio(video_url, timeout)
     if video_io is None:
         error_msg = f"Failed to download video from {video_url}"
         logging.error(error_msg)
@@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
     return s


-def validate_and_cast_response(
+async def validate_and_cast_response(
     response, timeout: int = None, node_id: Union[str, None] = None
 ) -> torch.Tensor:
     """Validates and casts a response to a torch.Tensor.
@@ -86,35 +86,24 @@ def validate_and_cast_response(
     image_tensors: list[torch.Tensor] = []

     # Process each image in the data array
-    for image_data in data:
-        image_url = image_data.url
-        b64_data = image_data.b64_json
-
-        if not image_url and not b64_data:
-            raise ValueError("No image was generated in the response")
-
-        if b64_data:
-            img_data = base64.b64decode(b64_data)
-            img = Image.open(io.BytesIO(img_data))
-
-        elif image_url:
-            if node_id:
-                PromptServer.instance.send_progress_text(
-                    f"Result URL: {image_url}", node_id
-                )
-            img_response = requests.get(image_url, timeout=timeout)
-            if img_response.status_code != 200:
-                raise ValueError("Failed to download the image")
-            img = Image.open(io.BytesIO(img_response.content))
-
-        img = img.convert("RGBA")
-
-        # Convert to numpy array, normalize to float32 between 0 and 1
-        img_array = np.array(img).astype(np.float32) / 255.0
-        img_tensor = torch.from_numpy(img_array)
-
-        # Add to list of tensors
-        image_tensors.append(img_tensor)
+    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
+        for img_data in data:
+            img_bytes: bytes
+            if img_data.b64_json:
+                img_bytes = base64.b64decode(img_data.b64_json)
+            elif img_data.url:
+                if node_id:
+                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
+                async with session.get(img_data.url) as resp:
+                    if resp.status != 200:
+                        raise ValueError("Failed to download generated image")
+                    img_bytes = await resp.read()
+            else:
+                raise ValueError("Invalid image payload – neither URL nor base64 data present.")
+
+            pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
+            arr = np.asarray(pil_img).astype(np.float32) / 255.0
+            image_tensors.append(torch.from_numpy(arr))

     return torch.stack(image_tensors, dim=0)
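
The rewritten loop prefers the inline `b64_json` payload and only falls back to fetching the URL. Base64 payloads decode straight to image bytes; a quick sanity check:

import base64

img_bytes = base64.b64decode("iVBORw0KGgo=")  # base64 of the PNG signature, for illustration
assert img_bytes[:4] == b"\x89PNG"
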
@@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()


-def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
+async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     """Downloads content from a URL using requests and returns it as BytesIO.

     Args:
@@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
     Returns:
         BytesIO object containing the downloaded content.
     """
-    response = requests.get(url, stream=True, timeout=timeout)
-    response.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-    return BytesIO(response.content)
+    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
+    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+        async with session.get(url) as resp:
+            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
+            return BytesIO(await resp.read())


 def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
@@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
     return torch.from_numpy(image_array).unsqueeze(0)


-def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
+async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
     """Downloads an image from a URL and returns a [B, H, W, C] tensor."""
-    image_bytesio = download_url_to_bytesio(url, timeout)
+    image_bytesio = await download_url_to_bytesio(url, timeout)
     return bytesio_to_image_tensor(image_bytesio)


-def process_image_response(response: requests.Response) -> torch.Tensor:
+def process_image_response(response_content: bytes | str) -> torch.Tensor:
     """Uses content from a Response object and converts it to a torch.Tensor"""
-    return bytesio_to_image_tensor(BytesIO(response.content))
+    return bytesio_to_image_tensor(BytesIO(response_content))


 def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
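
With these helpers now coroutines, plain-script callers have to drive them on an event loop; a hedged usage sketch (URLs are placeholders, helpers assumed imported from this module):

import asyncio

async def main():
    image = await download_url_to_image_tensor("https://example.com/img.png", timeout=60)
    video = await download_url_to_video_output("https://example.com/clip.mp4", timeout=60)
    print(image.shape, type(video))  # [B, H, W, C] tensor, VideoFromFile

asyncio.run(main())
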
@@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     return f"data:{mime_type};base64,{base64_string}"


-def upload_file_to_comfyapi(
+async def upload_file_to_comfyapi(
     file_bytes_io: BytesIO,
     filename: str,
-    upload_mime_type: str,
+    upload_mime_type: Optional[str],
     auth_kwargs: Optional[dict[str, str]] = None,
 ) -> str:
     """
@@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
     Returns:
         The download URL for the uploaded file.
     """
-    request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
+    if upload_mime_type is None:
+        request_object = UploadRequest(file_name=filename)
+    else:
+        request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
     operation = SynchronousOperation(
         endpoint=ApiEndpoint(
             path="/customers/storage",
@@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
         auth_kwargs=auth_kwargs,
     )

-    response: UploadResponse = operation.execute()
-    upload_response = ApiClient.upload_file(
-        response.upload_url, file_bytes_io, content_type=upload_mime_type
-    )
-    upload_response.raise_for_status()
-
+    response: UploadResponse = await operation.execute()
+    await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
     return response.download_url
@@ -399,7 +389,7 @@ def video_to_base64_string(
     return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")


-def upload_video_to_comfyapi(
+async def upload_video_to_comfyapi(
     video: VideoInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container: VideoContainer = VideoContainer.MP4,
@@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
     video.save_to(video_bytes_io, format=container, codec=codec)
     video_bytes_io.seek(0)

-    return upload_file_to_comfyapi(
-        video_bytes_io, filename, upload_mime_type, auth_kwargs
-    )
+    return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)


 def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
@@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
     return audio_bytes_io


-def upload_audio_to_comfyapi(
+async def upload_audio_to_comfyapi(
     audio: AudioInput,
     auth_kwargs: Optional[dict[str, str]] = None,
     container_format: str = "mp4",
@@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
         audio_data_np, sample_rate, container_format, codec_name
     )

-    return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
+    return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


 def audio_to_base64_string(
@@ -544,7 +532,7 @@ def audio_to_base64_string(
     return base64.b64encode(audio_bytes).decode("utf-8")


-def upload_images_to_comfyapi(
+async def upload_images_to_comfyapi(
     image: torch.Tensor,
     max_images=8,
     auth_kwargs: Optional[dict[str, str]] = None,
@@ -561,55 +549,15 @@ def upload_images_to_comfyapi(
         mime_type: Optional MIME type for the image.
     """
     # if batch, try to upload each file if max_images is greater than 0
-    idx_image = 0
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
-    batch_length = 1
-    if is_batch:
-        batch_length = image.shape[0]
-    while True:
-        curr_image = image
-        if len(image.shape) > 3:
-            curr_image = image[idx_image]
-        # get BytesIO version of image
-        img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
-        # first, request upload/download urls from comfy API
-        if not mime_type:
-            request_object = UploadRequest(file_name=img_binary.name)
-        else:
-            request_object = UploadRequest(
-                file_name=img_binary.name, content_type=mime_type
-            )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/customers/storage",
-                method=HttpMethod.POST,
-                request_model=UploadRequest,
-                response_model=UploadResponse,
-            ),
-            request=request_object,
-            auth_kwargs=auth_kwargs,
-        )
-        response = operation.execute()
-
-        upload_response = ApiClient.upload_file(
-            response.upload_url, img_binary, content_type=mime_type
-        )
-        # verify success
-        try:
-            upload_response.raise_for_status()
-        except requests.exceptions.HTTPError as e:
-            raise ValueError(f"Could not upload one or more images: {e}") from e
-        # add download_url to list
-        download_urls.append(response.download_url)
-
-        idx_image += 1
-        # stop uploading additional files if done
-        if is_batch and max_images > 0:
-            if idx_image >= max_images:
-                break
-        if idx_image >= batch_length:
-            break
+    batch_len = image.shape[0] if is_batch else 1
+
+    for idx in range(min(batch_len, max_images)):
+        tensor = image[idx] if is_batch else image
+        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
+        url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
+        download_urls.append(url)
     return download_urls
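
The rewrite collapses the old while-loop into a bounded for-loop over `min(batch_len, max_images)` and funnels every image through `upload_file_to_comfyapi`, which now owns the request/upload steps. Rough usage (auth contents are placeholders):

import asyncio
import torch

batch = torch.rand(6, 512, 512, 3)  # [B, H, W, C] images in 0..1

async def main():
    urls = await upload_images_to_comfyapi(batch, max_images=4, auth_kwargs={"auth_token": "..."})
    assert len(urls) == 4  # min(batch of 6, max_images=4)

asyncio.run(main())
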
comfy_api_nodes/apis/__init__.py (generated, 2656 lines changed)
File diff suppressed because it is too large
@@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
     type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
     prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
     negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
-    model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
+    model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
     face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
     texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
     pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
@@ -1,3 +1,4 @@
+import asyncio
 import io
 from inspect import cleandoc
 from typing import Union, Optional
@@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (

 import numpy as np
 from PIL import Image
-import requests
+import aiohttp
 import torch
 import base64
 import time
@@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
     return mask


-def handle_bfl_synchronous_operation(
+async def handle_bfl_synchronous_operation(
     operation: SynchronousOperation,
     timeout_bfl_calls=360,
     node_id: Union[str, None] = None,
 ):
-    response_api: BFLFluxProGenerateResponse = operation.execute()
-    return _poll_until_generated(
+    response_api: BFLFluxProGenerateResponse = await operation.execute()
+    return await _poll_until_generated(
         response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
     )


-def _poll_until_generated(
+async def _poll_until_generated(
     polling_url: str, timeout=360, node_id: Union[str, None] = None
 ):
     # used bfl-comfy-nodes to verify code implementation:
@@ -66,55 +67,56 @@ def _poll_until_generated(
     retry_404_seconds = 2
     retry_202_seconds = 2
     retry_pending_seconds = 1
-    request = requests.Request(method=HttpMethod.GET, url=polling_url)
-    # NOTE: should True loop be replaced with checking if workflow has been interrupted?
-    while True:
-        if node_id:
-            time_elapsed = time.time() - start_time
-            PromptServer.instance.send_progress_text(
-                f"Generating ({time_elapsed:.0f}s)", node_id
-            )
-
-        response = requests.Session().send(request.prepare())
-        if response.status_code == 200:
-            result = response.json()
-            if result["status"] == BFLStatus.ready:
-                img_url = result["result"]["sample"]
-                if node_id:
-                    PromptServer.instance.send_progress_text(
-                        f"Result URL: {img_url}", node_id
-                    )
-                img_response = requests.get(img_url)
-                return process_image_response(img_response)
-            elif result["status"] in [
-                BFLStatus.request_moderated,
-                BFLStatus.content_moderated,
-            ]:
-                status = result["status"]
-                raise Exception(
-                    f"BFL API did not return an image due to: {status}."
-                )
-            elif result["status"] == BFLStatus.error:
-                raise Exception(f"BFL API encountered an error: {result}.")
-            elif result["status"] == BFLStatus.pending:
-                time.sleep(retry_pending_seconds)
-                continue
-        elif response.status_code == 404:
-            if retries_404 < max_retries_404:
-                retries_404 += 1
-                time.sleep(retry_404_seconds)
-                continue
-            raise Exception(
-                f"BFL API could not find task after {max_retries_404} tries."
-            )
-        elif response.status_code == 202:
-            time.sleep(retry_202_seconds)
-        elif time.time() - start_time > timeout:
-            raise Exception(
-                f"BFL API experienced a timeout; could not return request under {timeout} seconds."
-            )
-        else:
-            raise Exception(f"BFL API encountered an error: {response.json()}")
+    async with aiohttp.ClientSession() as session:
+        # NOTE: should True loop be replaced with checking if workflow has been interrupted?
+        while True:
+            if node_id:
+                time_elapsed = time.time() - start_time
+                PromptServer.instance.send_progress_text(
+                    f"Generating ({time_elapsed:.0f}s)", node_id
+                )
+
+            async with session.get(polling_url) as response:
+                if response.status == 200:
+                    result = await response.json()
+                    if result["status"] == BFLStatus.ready:
+                        img_url = result["result"]["sample"]
+                        if node_id:
+                            PromptServer.instance.send_progress_text(
+                                f"Result URL: {img_url}", node_id
+                            )
+                        async with session.get(img_url) as img_resp:
+                            return process_image_response(await img_resp.content.read())
+                    elif result["status"] in [
+                        BFLStatus.request_moderated,
+                        BFLStatus.content_moderated,
+                    ]:
+                        status = result["status"]
+                        raise Exception(
+                            f"BFL API did not return an image due to: {status}."
+                        )
+                    elif result["status"] == BFLStatus.error:
+                        raise Exception(f"BFL API encountered an error: {result}.")
+                    elif result["status"] == BFLStatus.pending:
+                        await asyncio.sleep(retry_pending_seconds)
+                        continue
+                elif response.status == 404:
+                    if retries_404 < max_retries_404:
+                        retries_404 += 1
+                        await asyncio.sleep(retry_404_seconds)
+                        continue
+                    raise Exception(
+                        f"BFL API could not find task after {max_retries_404} tries."
+                    )
+                elif response.status == 202:
+                    await asyncio.sleep(retry_202_seconds)
+                elif time.time() - start_time > timeout:
+                    raise Exception(
+                        f"BFL API experienced a timeout; could not return request under {timeout} seconds."
+                    )
+                else:
+                    raise Exception(f"BFL API encountered an error: {response.json()}")

 def convert_image_to_base64(image: torch.Tensor):
     scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
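
Distilled, the new loop is the standard async polling shape: one shared aiohttp session, the HTTP status mapped to retry or raise, and `asyncio.sleep` instead of `time.sleep` so other tasks keep running. A minimal sketch of that shape (endpoint and payload hypothetical):

import asyncio
import aiohttp

async def poll(url: str, interval: float = 1.0, timeout: float = 360.0):
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    async with aiohttp.ClientSession() as session:
        while loop.time() < deadline:
            async with session.get(url) as resp:
                if resp.status == 200:
                    body = await resp.json()
                    if body.get("status") == "Ready":
                        return body["result"]
            await asyncio.sleep(interval)  # yield to the event loop between polls
    raise TimeoutError(f"no result from {url} within {timeout}s")
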
@@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         aspect_ratio: str,
@@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):

     BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         aspect_ratio: str,
@@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         prompt_upsampling,
@@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         prompt: str,
@@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         mask: torch.Tensor,
@@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         control_image: torch.Tensor,
         prompt: str,
@@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)


@@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
     API_NODE = True
     CATEGORY = "api node/image/BFL"

-    def api_call(
+    async def api_call(
         self,
         control_image: torch.Tensor,
         prompt: str,
@@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
             ),
             auth_kwargs=kwargs,
         )
-        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+        output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
         return (output_image,)
@@ -303,7 +303,7 @@ class GeminiNode(ComfyNodeABC):
         """
         return GeminiPart(text=text)

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         model: GeminiModel,
@@ -332,7 +332,7 @@ class GeminiNode(ComfyNodeABC):
         parts.extend(files)

         # Create response
-        response = SynchronousOperation(
+        response = await SynchronousOperation(
             endpoint=get_gemini_endpoint(model),
             request=GeminiGenerateContentRequest(
                 contents=[
@@ -212,7 +212,7 @@ V3_RESOLUTIONS= [
     "1536x640"
 ]

-def download_and_process_images(image_urls):
+async def download_and_process_images(image_urls):
     """Helper function to download and process multiple images from URLs"""

     # Initialize list to store image tensors
@@ -220,7 +220,7 @@ def download_and_process_images(image_urls):

     for image_url in image_urls:
         # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
         img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
         image_tensors.append(img_tensor)

@@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         turbo=False,
@@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC):
             auth_kwargs=kwargs,
         )

-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)


 class IdeogramV2(ComfyNodeABC):
@@ -487,7 +487,7 @@ class IdeogramV2(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         turbo=False,
@@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC):
             auth_kwargs=kwargs,
         )

-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)

 class IdeogramV3(ComfyNodeABC):
     """
@@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC):
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True

-    def api_call(
+    async def api_call(
         self,
         prompt,
         image=None,
@@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC):
         )

         # Execute the operation and process response
-        response = operation.execute()
+        response = await operation.execute()

         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")
@@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC):
             raise Exception("No image URLs were generated in the response")

         display_image_urls_on_node(image_urls, unique_id)
-        return (download_and_process_images(image_urls),)
+        return (await download_and_process_images(image_urls),)


 NODE_CLASS_MAPPINGS = {
@@ -109,7 +109,7 @@ class KlingApiError(Exception):
     pass


-def poll_until_finished(
+async def poll_until_finished(
     auth_kwargs: dict[str, str],
     api_endpoint: ApiEndpoint[Any, R],
     result_url_extractor: Optional[Callable[[R], str]] = None,
@@ -117,7 +117,7 @@ def poll_until_finished(
     node_id: Optional[str] = None,
 ) -> R:
     """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
-    return PollingOperation(
+    return await PollingOperation(
         poll_endpoint=api_endpoint,
         completed_statuses=[
             KlingTaskStatus.succeed.value,
@@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
     return None


-def video_result_to_node_output(
+async def video_result_to_node_output(
     video: KlingVideoResult,
 ) -> tuple[VideoFromFile, str, str]:
     """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
     return (
-        download_url_to_video_output(video.url),
+        await download_url_to_video_output(str(video.url)),
         str(video.id),
         str(video.duration),
     )


-def image_result_to_node_output(
+async def image_result_to_node_output(
     images: list[KlingImageResult],
 ) -> torch.Tensor:
     """
@@ -297,9 +297,9 @@ def image_result_to_node_output(
     If multiple images are returned, they will be stacked along the batch dimension.
     """
     if len(images) == 1:
-        return download_url_to_image_tensor(images[0].url)
+        return await download_url_to_image_tensor(str(images[0].url))
     else:
-        return torch.cat([download_url_to_image_tensor(image.url) for image in images])
+        return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])


 class KlingNodeBase(ComfyNodeABC):
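
Note that the list comprehension above awaits each download in turn. Where ordering allows it, a gather-based variant (an alternative sketch, not what the node currently does) would fetch the images concurrently:

import asyncio
import torch

async def images_to_tensor_concurrent(images) -> torch.Tensor:
    # Hypothetical variant of image_result_to_node_output using asyncio.gather.
    tensors = await asyncio.gather(
        *(download_url_to_image_tensor(str(image.url)) for image in images)
    )
    return torch.cat(tensors)
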
@@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Text to Video Node"

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingText2VideoResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
@@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)

         task_id = task_creation_response.data.task_id
-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingCameraControlT2VNode(KlingTextToVideoNode):
@@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):

     DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             model_name=KlingVideoGenModelName.kling_v1,
             cfg_scale=cfg_scale,
             mode=KlingVideoGenMode.std,
@@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Image to Video Node"

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingImage2VideoResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
@@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         prompt: str,
@@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingCameraControlI2VNode(KlingImage2VideoNode):
@@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):

     DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         prompt: str,
@@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             model_name=KlingVideoGenModelName.kling_v1_5,
             start_frame=start_frame,
             cfg_scale=cfg_scale,
@@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):

     DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."

-    def api_call(
+    async def api_call(
         self,
         start_frame: torch.Tensor,
         end_frame: torch.Tensor,
@@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
         mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
             mode
         ]
-        return super().api_call(
+        return await super().api_call(
             prompt=prompt,
             negative_prompt=negative_prompt,
             model_name=model_name,
@@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVideoExtendResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIDEO_EXTEND}/{task_id}",
@@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         prompt: str,
         negative_prompt: str,
@@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingVideoEffectsBase(KlingNodeBase):
@@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
     RETURN_TYPES = ("VIDEO", "STRING", "STRING")
     RETURN_NAMES = ("VIDEO", "video_id", "duration")

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVideoEffectsResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
@@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         dual_character: bool,
         effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
@@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
@@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
     RETURN_TYPES = ("VIDEO", "STRING")
     RETURN_NAMES = ("VIDEO", "duration")

-    def api_call(
+    async def api_call(
         self,
         image_left: torch.Tensor,
         image_right: torch.Tensor,
@@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        video, _, duration = super().api_call(
+        video, _, duration = await super().api_call(
             dual_character=True,
             effect_scene=effect_scene,
             model_name=model_name,
@@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):

     DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."

-    def api_call(
+    async def api_call(
         self,
         image: torch.Tensor,
         effect_scene: KlingSingleImageEffectsScene,
@@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             dual_character=False,
             effect_scene=effect_scene,
             model_name=model_name,
@@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase):
                 f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
             )

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingLipSyncResponse:
         """Polls the Kling API endpoint until the task reaches a terminal state."""
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_LIP_SYNC}/{task_id}",
@@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         audio: Optional[AudioInput] = None,
@@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase):
         self.validate_lip_sync_video(video)

         # Upload video to Comfy API and get download URL
-        video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
+        video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
         logging.info("Uploaded video to Comfy API. URL: %s", video_url)

         # Upload the audio file to Comfy API and get download URL
         if audio:
-            audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
+            audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
             logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
         else:
             audio_url = None
@@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_video_result_response(final_response)

         video = get_video_from_response(final_response)
-        return video_result_to_node_output(video)
+        return await video_result_to_node_output(video)


 class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
@@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):

     DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         audio: AudioInput,
@@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
         unique_id: Optional[str] = None,
         **kwargs,
     ):
-        return super().api_call(
+        return await super().api_call(
             video=video,
             audio=audio,
             voice_language=voice_language,
@@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):

     DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

-    def api_call(
+    async def api_call(
         self,
         video: VideoInput,
         text: str,
@@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
         **kwargs,
     ):
         voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
-        return super().api_call(
+        return await super().api_call(
             video=video,
             text=text,
             voice_language=voice_language,
@@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):

     DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."

-    def get_response(
+    async def get_response(
         self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
     ) -> KlingVirtualTryOnResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
@@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         human_image: torch.Tensor,
         cloth_image: torch.Tensor,
@@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_image_result_response(final_response)

         images = get_images_from_response(final_response)
-        return (image_result_to_node_output(images),)
+        return (await image_result_to_node_output(images),)


 class KlingImageGenerationNode(KlingImageGenerationBase):
@@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):

     DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."

-    def get_response(
+    async def get_response(
         self,
         task_id: str,
         auth_kwargs: Optional[dict[str, str]],
         node_id: Optional[str] = None,
     ) -> KlingImageGenerationsResponse:
-        return poll_until_finished(
+        return await poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
                 path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
@@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
             node_id=node_id,
         )

-    def api_call(
+    async def api_call(
         self,
         model_name: KlingImageGenModelName,
         prompt: str,
@@ -1714,17 +1714,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
             auth_kwargs=kwargs,
         )

-        task_creation_response = initial_operation.execute()
+        task_creation_response = await initial_operation.execute()
         validate_task_creation_response(task_creation_response)
         task_id = task_creation_response.data.task_id

-        final_response = self.get_response(
+        final_response = await self.get_response(
             task_id, auth_kwargs=kwargs, node_id=unique_id
         )
         validate_image_result_response(final_response)

         images = get_images_from_response(final_response)
-        return (image_result_to_node_output(images),)
+        return (await image_result_to_node_output(images),)


 NODE_CLASS_MAPPINGS = {
@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import requests
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = self._convert_luma_refs(
|
||||
api_image_ref = await self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = self._convert_style_image(
|
||||
api_style_ref = await self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
response_poll = await operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return (img,)
|
||||
|
||||
def _convert_luma_refs(
|
||||
async def _convert_luma_refs(
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
download_urls = await upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

def _convert_style_image(
async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)


class LumaImageModifyNode(ComfyNodeABC):
@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
**kwargs,
):
# first, upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()

operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()

img_response = requests.get(response_poll.assets.image)
img = process_image_response(img_response)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)


@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()

if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()

vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)


class LumaImageToVideoGenerationNode(ComfyNodeABC):
@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
prompt: str,
model: str,
@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None

@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
response_api: LumaGeneration = await operation.execute()

if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
response_poll = await operation.execute()

vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)

def _convert_to_keyframes(
async def _convert_to_keyframes(
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame0 = None
frame1 = None
if first_image is not None:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])
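The Luma hunks above swap blocking requests.get calls for aiohttp sessions. As a minimal sketch of that download pattern, assuming only aiohttp and the standard library (fetch_bytes is a hypothetical helper name, not part of this commit):

import aiohttp
from io import BytesIO

async def fetch_bytes(url: str) -> BytesIO:
    # One throwaway session keeps the sketch self-contained; production code
    # may prefer a shared session so connections are reused.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()
            return BytesIO(await resp.content.read())
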
@ -86,7 +86,7 @@ class MinimaxTextToVideoNode:
API_NODE = True
OUTPUT_NODE = True

def generate_video(
async def generate_video(
self,
prompt_text,
seed=0,
@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]

# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]


@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
),
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()
response = await video_generate_operation.execute()

task_id = response.task_id
if not task_id:
@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
task_result = await video_generate_operation.execute()

file_id = task_result.file_id
if file_id is None:
@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()
file_result = await file_retrieve_operation.execute()

file_url = file_result.file.download_url
if file_url is None:
@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)

video_io = download_url_to_bytesio(file_url)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
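Note the parentheses added in (await upload_images_to_comfyapi(...))[0]: subscripting binds tighter than await, so without them the [0] would be applied to the coroutine object itself. A toy illustration (first_url is a made-up stand-in, not this codebase's API):

import asyncio

async def first_url() -> list[str]:
    return ["https://example.com/a.png"]

async def main():
    url = (await first_url())[0]    # await first, then index the resulting list
    # url = await first_url()[0]    # TypeError: 'coroutine' object is not subscriptable
    print(url)

asyncio.run(main())
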
@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]:
return None


def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
@ -394,10 +394,10 @@ class BaseMoonvalleyVideoNode:
else:
return control_map["Motion Transfer"]

def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
@ -507,7 +507,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"

def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None)
@ -532,9 +532,9 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"

image_url = upload_images_to_comfyapi(
image_url = (await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)[0]
))[0]

request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
@ -549,14 +549,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id

final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)
return (video,)


@ -609,7 +609,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)

def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
@ -620,7 +620,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)

control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
@ -658,15 +658,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id

final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)

video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)

return (video,)

@ -688,7 +688,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
del input_types["optional"][param]
return input_types

def generate(
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@ -717,15 +717,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id

final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)

video = download_url_to_video_output(final_response.output_url)
video = await download_url_to_video_output(final_response.output_url)
return (video,)

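The Moonvalley changes repeat the submit-then-poll shape seen across these nodes. Reduced to its skeleton, with submit and poll_status as placeholder callables rather than this codebase's API, the pattern looks roughly like:

import asyncio

async def run_task(submit, poll_status, interval: float = 2.0):
    # submit() resolves to a task id; poll_status(task_id) resolves to a status string.
    task_id = await submit()
    while True:
        status = await poll_status(task_id)
        if status in ("completed", "failed", "cancelled"):
            return task_id, status
        await asyncio.sleep(interval)  # yield to the event loop between polls
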
@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True

def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC):
auth_kwargs=kwargs,
)

response = operation.execute()
response = await operation.execute()

img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)


@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True

def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC):
auth_kwargs=kwargs,
)

response = operation.execute()
response = await operation.execute()

img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)


@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True

def api_call(
async def api_call(
self,
prompt,
seed=0,
@ -537,9 +537,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
auth_kwargs=kwargs,
)

response = operation.execute()
response = await operation.execute()

img_tensor = validate_and_cast_response(response, node_id=unique_id)
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)


@ -623,7 +623,7 @@ class OpenAIChatNode(OpenAITextNode):

DESCRIPTION = "Generate text responses from an OpenAI model."

def get_result_response(
async def get_result_response(
self,
response_id: str,
include: Optional[list[Includable]] = None,
@ -639,7 +639,7 @@ class OpenAIChatNode(OpenAITextNode):
creation above for more information.

"""
return PollingOperation(
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET,
@ -784,7 +784,7 @@ class OpenAIChatNode(OpenAITextNode):

self.history[session_id] = new_history

def api_call(
async def api_call(
self,
prompt: str,
persist_context: bool,
@ -815,7 +815,7 @@ class OpenAIChatNode(OpenAITextNode):
previous_response_id = None

# Create response
create_response = SynchronousOperation(
create_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT,
method=HttpMethod.POST,
@ -848,7 +848,7 @@ class OpenAIChatNode(OpenAITextNode):
response_id = create_response.id

# Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response)

# Update history
@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
FUNCTION = "api_call"
RETURN_TYPES = ("VIDEO",)

def poll_for_task_status(
async def poll_for_task_status(
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
node_id=node_id,
estimated_duration=60
)
return polling_operation.execute()
return await polling_operation.execute()

def execute_task(
async def execute_task(
self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
logging.error(error_msg)
raise PikaApiError(error_msg)

task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_kwargs)
final_response = await self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
video_url = str(final_response.url)
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)

return (download_url_to_video_output(video_url),)
return (await download_url_to_video_output(video_url),)


class PikaImageToVideoV2_2(PikaNodeBase):
@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):

DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."

def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt_text: str,
@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaTextToVideoNodeV2_2(PikaNodeBase):
@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):

DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."

def api_call(
async def api_call(
self,
prompt_text: str,
negative_prompt: str,
@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
content_type="application/x-www-form-urlencoded",
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaScenesV2_2(PikaNodeBase):
@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):

DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."

def api_call(
async def api_call(
self,
prompt_text: str,
negative_prompt: str,
@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikAdditionsNode(PikaNodeBase):
@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):

DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."

def api_call(
async def api_call(
self,
video: VideoInput,
image: torch.Tensor,
@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)

pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
]
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}

# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaSwapsNode(PikaNodeBase):
@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
RETURN_TYPES = ("VIDEO",)

def api_call(
async def api_call(
self,
video: VideoInput,
image: torch.Tensor,
@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)

pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
]
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
}

# Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaffectsNode(PikaNodeBase):
@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):

DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"

def api_call(
async def api_call(
self,
image: torch.Tensor,
pikaffect: str,
@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


class PikaStartEndFrameNode2_2(PikaNodeBase):
@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):

DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."

def api_call(
async def api_call(
self,
image_start: torch.Tensor,
image_end: torch.Tensor,
@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
) -> tuple[VideoFromFile]:

pika_files = [
(
"keyFrames",
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
),
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]

@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
auth_kwargs=kwargs,
)

return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)


NODE_CLASS_MAPPINGS = {
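One detail worth flagging in the Pika hunks: the multipart file fields move from a list of tuples to a dict, except in PikaStartEndFrameNode2_2, which keeps the list because the keyFrames field appears twice and a dict cannot hold duplicate keys. A plain-Python illustration:

# A dict keeps one value per field name; a list of pairs preserves duplicates.
as_dict = {"keyFrames": "image_start.png"}
as_dict["keyFrames"] = "image_end.png"      # overwrites: only the end frame remains

as_list = [("keyFrames", "image_start.png"),
           ("keyFrames", "image_end.png")]
assert len(as_list) == 2                    # both frames survive
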
@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile

import torch
import requests
import aiohttp
from io import BytesIO


@ -47,7 +47,7 @@ def get_video_url_from_response(
return str(response.Resp.url)


def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = operation.execute()
response_upload: PixverseImageUploadResponse = await operation.execute()

if response_upload.Resp is None:
raise Exception(
@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
prompt: str,
aspect_ratio: str,
@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()

if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()

vid_response = requests.get(response_poll.Resp.url)

return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)


class PixverseImageToVideoNode(ComfyNodeABC):
@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)

# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()

if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()

vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)


class PixverseTransitionVideoNode(ComfyNodeABC):
@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
},
}

def api_call(
async def api_call(
self,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)

# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
response_api = await operation.execute()

if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
response_poll = await operation.execute()

vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)


NODE_CLASS_MAPPINGS = {
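Converting helpers such as upload_image_to_pixverse to async makes every call site that forgets await a silent bug: the variable ends up holding a coroutine instead of the result. A toy reproduction (upload is a made-up stand-in):

import asyncio

async def upload() -> str:
    return "img_123"

async def main():
    img_id = upload()                 # missing await: img_id is a coroutine object
    print(type(img_id).__name__)      # "coroutine" (Python also warns it was never awaited)
    img_id = await upload()           # correct: img_id is the string result
    print(img_id)

asyncio.run(main())
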
@ -37,7 +37,7 @@ from io import BytesIO
from PIL import UnidentifiedImageError


def handle_recraft_file_request(
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
@ -71,13 +71,13 @@ def handle_recraft_file_request(
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))

return all_bytesio

@ -395,7 +395,7 @@ class RecraftTextToImageNode:
},
}

def api_call(
async def api_call(
self,
prompt: str,
size: str,
@ -439,7 +439,7 @@ class RecraftTextToImageNode:
),
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
images = []
urls = []
for data in response.data:
@ -451,7 +451,7 @@ class RecraftTextToImageNode:
f"Result URL: {urls_string}", unique_id
)
image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024)
await download_url_to_bytesio(data.url, timeout=1024)
)
if len(image.shape) < 4:
image = image.unsqueeze(0)
@ -538,7 +538,7 @@ class RecraftImageToImageNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -578,7 +578,7 @@ class RecraftImageToImageNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/imageToImage",
request=request,
@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
mask: torch.Tensor,
@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
mask=mask[i:i+1],
path="/proxy/recraft/images/inpaint",
@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
},
}

def api_call(
async def api_call(
self,
prompt: str,
substyle: str,
@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
),
auth_kwargs=kwargs,
)
response: RecraftImageGenerationResponse = operation.execute()
response: RecraftImageGenerationResponse = await operation.execute()
svg_data = []
urls = []
for data in response.data:
@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id
)
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))

return (SVG(svg_data),)

@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/vectorize",
auth_kwargs=kwargs,
@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
prompt: str,
@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/replaceBackground",
request=request,
@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path="/proxy/recraft/images/removeBackground",
auth_kwargs=kwargs,
@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
},
}

def api_call(
async def api_call(
self,
image: torch.Tensor,
**kwargs,
@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
total = image.shape[0]
pbar = ProgressBar(total)
for i in range(total):
sub_bytes = handle_recraft_file_request(
sub_bytes = await handle_recraft_file_request(
image=image[i],
path=self.RECRAFT_PATH,
auth_kwargs=kwargs,
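The Recraft loops above await one image at a time, which keeps the per-image ProgressBar updates in order. Where ordering didn't matter, a concurrent variant with asyncio.gather would also be possible; a rough sketch under that assumption (process_one stands in for the per-image request, it is not this file's API):

import asyncio

async def process_all(items, process_one):
    # Fire all requests at once and await them together. Results keep the
    # order of `items`, but per-step progress reporting is lost.
    return await asyncio.gather(*(process_one(item) for item in items))
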
@ -9,11 +9,10 @@ from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import aiohttp
import os
import datetime
import shutil
import time
import asyncio
import io
import logging
import math
@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
return hasattr(response, "error")



class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
@ -123,8 +121,8 @@ class Rodin3DAPI:
else:
return "Generating"

def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images == None:
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
@ -155,7 +153,7 @@ class Rodin3DAPI:
auth_kwargs=kwargs,
)

response = operation.execute()
response = await operation.execute()

if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
@ -168,7 +166,7 @@ class Rodin3DAPI:
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key

def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:

path = "/proxy/rodin/api/v2/status"

@ -191,11 +189,9 @@ class Rodin3DAPI:

logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")

return poll_operation.execute()
return await poll_operation.execute()



def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")

path = "/proxy/rodin/api/v2/download"
@ -212,53 +208,59 @@ class Rodin3DAPI:
auth_kwargs=kwargs
)

return operation.execute()
return await operation.execute()

def GetQualityAndMode(self, PolyCount):
if PolyCount == "200K-Triangle":
def get_quality_mode(self, poly_count):
if poly_count == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if PolyCount == "4K-Quad":
if poly_count == "4K-Quad":
quality = "extra-low"
elif PolyCount == "8K-Quad":
elif poly_count == "8K-Quad":
quality = "low"
elif PolyCount == "18K-Quad":
elif poly_count == "18K-Quad":
quality = "medium"
elif PolyCount == "50K-Quad":
elif poly_count == "50K-Quad":
quality = "high"
else:
quality = "medium"

return mesh_mode, quality

def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)

return model_file_path

@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
},
}

def api_call(
async def api_call(
self,
Images,
Seed,
@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)

return (model,)


class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
},
}

def api_call(
async def api_call(
self,
Images,
Seed,
@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)

return (model,)


class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
},
}

def api_call(
async def api_call(
self,
Images,
Seed,
@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)

return (model,)


class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
},
}

def api_call(
async def api_call(
self,
Images,
Seed,
@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
material_type = "PBR"
quality = "medium"
mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
task_uuid, subscription_key = await self.create_generate_task(
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)

return (model,)

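In the Rodin retry loop, time.sleep(2) becomes await asyncio.sleep(2). The difference matters: a blocking sleep stalls the entire event loop, while the awaited form suspends only the current coroutine. A toy comparison:

import asyncio
import time

async def bad_wait():
    time.sleep(2)            # blocks the whole event loop; nothing else runs for 2 s

async def good_wait():
    await asyncio.sleep(2)   # suspends only this coroutine; other tasks keep running
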
@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool:
return image.shape[2] < 8000 and image.shape[1] < 8000


def poll_until_finished(
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
@ -115,7 +115,7 @@ def poll_until_finished(
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: (response.status.value),
status_extractor=lambda response: response.status.value,
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC):
)
return True

def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC):
node_id=node_id,
)

def generate_video(
async def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC):
auth_kwargs=auth_kwargs,
)

initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id

final_response = self.get_response(task_id, auth_kwargs, node_id)
final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)

video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),)
return (await download_url_to_video_output(video_url),)


class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
},
}

def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
validate_input_image(start_frame)

# Upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")

return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
},
}

def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
validate_input_image(start_frame)

# Upload image
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")

return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):

DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."

def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
},
}

def api_call(
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):

# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")

return self.generate_video(
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC):
)
return True

def get_response(
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
node_id=node_id,
)

def api_call(
async def api_call(
self,
prompt: str,
ratio: str,
@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi(
download_urls = await upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC):
auth_kwargs=kwargs,
)

initial_response = initial_operation.execute()
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id

# Poll for completion
final_response = self.get_response(
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
self.validate_response(final_response)

# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),)
return (await download_url_to_image_tensor(image_url),)


NODE_CLASS_MAPPINGS = {
@ -124,7 +124,7 @@ class StabilityStableImageUltraNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@ -163,7 +163,7 @@ class StabilityStableImageUltraNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode:
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
@ -563,8 +563,7 @@ class StabilityUpscaleFastNode:
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor,
|
||||
**kwargs):
|
||||
async def api_call(self, image: torch.Tensor, **kwargs):
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
@ -583,7 +582,7 @@ class StabilityUpscaleFastNode:
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
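Every Stability node above follows the same two-line migration: the entry point becomes `async def`, and each `operation.execute()` gains an `await`. A minimal, runnable sketch of that shape, using a hypothetical `FakeOperation` stand-in rather than the real `SynchronousOperation`/`PollingOperation` clients:

import asyncio

class FakeOperation:
    """Hypothetical stand-in for the real API operation classes."""
    async def execute(self):
        await asyncio.sleep(0)  # yields to the event loop, like a real HTTP call
        return "SUCCESS"

async def api_call(**kwargs):
    # After the migration the node method is a coroutine: it awaits the
    # operation instead of blocking a worker thread on the HTTP round trip.
    operation = FakeOperation()
    response = await operation.execute()
    return (response,)

print(asyncio.run(api_call()))  # ('SUCCESS',)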
@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
)


def upload_image_to_tripo(image, **kwargs):
    urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
async def upload_image_to_tripo(image, **kwargs):
    urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
    return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))

def get_model_url_from_response(response: TripoTaskResponse) -> str:
@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
    raise RuntimeError(f"Failed to get model url from response: {response}")


def poll_until_finished(
async def poll_until_finished(
    kwargs: dict[str, str],
    response: TripoTaskResponse,
) -> tuple[str, str]:
@ -57,7 +57,7 @@ def poll_until_finished(
    if response.code != 0:
        raise RuntimeError(f"Failed to generate mesh: {response.error}")
    task_id = response.data.task_id
    response_poll = PollingOperation(
    response_poll = await PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/tripo/v2/openapi/task/{task_id}",
            method=HttpMethod.GET,
@ -80,7 +80,7 @@ def poll_until_finished(
    ).execute()
    if response_poll.data.status == TripoTaskStatus.SUCCESS:
        url = get_model_url_from_response(response_poll)
        bytesio = download_url_to_bytesio(url)
        bytesio = await download_url_to_bytesio(url)
        # Save the downloaded model file
        model_file = f"tripo_model_{task_id}.glb"
        with open(os.path.join(get_output_directory(), model_file), "wb") as f:
@ -88,6 +88,7 @@ def poll_until_finished(
        return model_file, task_id
    raise RuntimeError(f"Failed to generate mesh: {response_poll}")

class TripoTextToModelNode:
    """
    Generates 3D models synchronously based on a text prompt using Tripo's API.
@ -126,11 +127,11 @@ class TripoTextToModelNode:
    API_NODE = True
    OUTPUT_NODE = True

    def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
    async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
        style_enum = None if style == "None" else style
        if not prompt:
            raise RuntimeError("Prompt is required")
        response = SynchronousOperation(
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -155,7 +156,8 @@ class TripoTextToModelNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoImageToModelNode:
    """
@ -195,12 +197,12 @@ class TripoImageToModelNode:
    API_NODE = True
    OUTPUT_NODE = True

    def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
    async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
        style_enum = None if style == "None" else style
        if image is None:
            raise RuntimeError("Image is required")
        tripo_file = upload_image_to_tripo(image, **kwargs)
        response = SynchronousOperation(
        tripo_file = await upload_image_to_tripo(image, **kwargs)
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -225,7 +227,8 @@ class TripoImageToModelNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoMultiviewToModelNode:
    """
@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
    API_NODE = True
    OUTPUT_NODE = True

    def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
    async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
        if image is None:
            raise RuntimeError("front image for multiview is required")
        images = []
@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
        for image_name in ["image", "image_left", "image_back", "image_right"]:
            image_ = image_dict[image_name]
            if image_ is not None:
                tripo_file = upload_image_to_tripo(image_, **kwargs)
                tripo_file = await upload_image_to_tripo(image_, **kwargs)
                images.append(tripo_file)
            else:
                images.append(TripoFileEmptyReference())
        response = SynchronousOperation(
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoTextureNode:
    @classmethod
@ -340,8 +344,8 @@ class TripoTextureNode:
    OUTPUT_NODE = True
    AVERAGE_DURATION = 80

    def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
        response = SynchronousOperation(
    async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -358,7 +362,7 @@ class TripoTextureNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoRefineNode:
@ -387,8 +391,8 @@ class TripoRefineNode:
    OUTPUT_NODE = True
    AVERAGE_DURATION = 240

    def generate_mesh(self, model_task_id, **kwargs):
        response = SynchronousOperation(
    async def generate_mesh(self, model_task_id, **kwargs):
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -400,7 +404,7 @@ class TripoRefineNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoRigNode:
@ -425,8 +429,8 @@ class TripoRigNode:
    OUTPUT_NODE = True
    AVERAGE_DURATION = 180

    def generate_mesh(self, original_model_task_id, **kwargs):
        response = SynchronousOperation(
    async def generate_mesh(self, original_model_task_id, **kwargs):
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -440,7 +444,8 @@ class TripoRigNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoRetargetNode:
    @classmethod
@ -475,8 +480,8 @@ class TripoRetargetNode:
    OUTPUT_NODE = True
    AVERAGE_DURATION = 30

    def generate_mesh(self, animation, original_model_task_id, **kwargs):
        response = SynchronousOperation(
    async def generate_mesh(self, animation, original_model_task_id, **kwargs):
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -491,7 +496,8 @@ class TripoRetargetNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


class TripoConversionNode:
    @classmethod
@ -529,10 +535,10 @@ class TripoConversionNode:
    OUTPUT_NODE = True
    AVERAGE_DURATION = 30

    def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
    async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
        if not original_model_task_id:
            raise RuntimeError("original_model_task_id is required")
        response = SynchronousOperation(
        response = await SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/tripo/v2/openapi/task",
                method=HttpMethod.POST,
@ -549,7 +555,8 @@ class TripoConversionNode:
            ),
            auth_kwargs=kwargs,
        ).execute()
        return poll_until_finished(kwargs, response)
        return await poll_until_finished(kwargs, response)


NODE_CLASS_MAPPINGS = {
    "TripoTextToModelNode": TripoTextToModelNode,
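All of the Tripo nodes above reduce to the same submit-then-poll flow: POST a task, then await poll_until_finished for the terminal status and the downloadable result. A simplified, self-contained sketch of that control flow (the helpers and values here are placeholders, not the real ComfyUI client):

import asyncio

async def submit_task() -> str:
    await asyncio.sleep(0)   # placeholder for POST /proxy/tripo/v2/openapi/task
    return "task-123"        # the API responds with a task id

async def poll_until_finished(task_id: str) -> str:
    while True:
        await asyncio.sleep(0)  # placeholder for GET .../task/{task_id}
        status = "success"      # a real poller would read response.data.status
        if status == "success":
            return f"tripo_model_{task_id}.glb"
        if status in ("failed", "cancelled"):
            raise RuntimeError("Failed to generate mesh")

async def generate_mesh() -> str:
    task_id = await submit_task()
    return await poll_until_finished(task_id)

print(asyncio.run(generate_mesh()))  # tripo_model_task-123.glb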
@ -1,17 +1,17 @@
import io
import logging
import base64
import requests
import aiohttp
import torch
from typing import Optional

from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
    Veo2GenVidRequest,
    Veo2GenVidResponse,
    Veo2GenVidPollRequest,
    Veo2GenVidPollResponse
    VeoGenVidRequest,
    VeoGenVidResponse,
    VeoGenVidPollRequest,
    VeoGenVidPollResponse
)
from comfy_api_nodes.apis.client import (
    ApiEndpoint,
@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
    return tensor_to_base64_string(scaled_image)


def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
    if (
        poll_response.response
        and hasattr(poll_response.response, "videos")
@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
                    "default": None,
                    "tooltip": "Optional reference image to guide video generation",
                }),
                "model": (
                    IO.COMBO,
                    {
                        "options": ["veo-2.0-generate-001"],
                        "default": "veo-2.0-generate-001",
                        "tooltip": "Veo 2 model to use for video generation",
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -141,10 +149,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
    RETURN_TYPES = (IO.VIDEO,)
    FUNCTION = "generate_video"
    CATEGORY = "api node/video/Veo"
    DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
    DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
    API_NODE = True

    def generate_video(
    async def generate_video(
        self,
        prompt,
        aspect_ratio="16:9",
@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
        person_generation="ALLOW",
        seed=0,
        image=None,
        model="veo-2.0-generate-001",
        generate_audio=False,
        unique_id: Optional[str] = None,
        **kwargs,
    ):
@ -188,23 +198,26 @@ class VeoVideoGenerationNode(ComfyNodeABC):
            parameters["negativePrompt"] = negative_prompt
        if seed > 0:
            parameters["seed"] = seed
        # Only add generateAudio for Veo 3 models
        if "veo-3.0" in model:
            parameters["generateAudio"] = generate_audio

        # Initial request to start video generation
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/veo/generate",
                path=f"/proxy/veo/{model}/generate",
                method=HttpMethod.POST,
                request_model=Veo2GenVidRequest,
                response_model=Veo2GenVidResponse
                request_model=VeoGenVidRequest,
                response_model=VeoGenVidResponse
            ),
            request=Veo2GenVidRequest(
            request=VeoGenVidRequest(
                instances=instances,
                parameters=parameters
            ),
            auth_kwargs=kwargs,
        )

        initial_response = initial_operation.execute()
        initial_response = await initial_operation.execute()
        operation_name = initial_response.name

        logging.info(f"Veo generation started with operation name: {operation_name}")
@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
        # Define the polling operation
        poll_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path="/proxy/veo/poll",
                path=f"/proxy/veo/{model}/poll",
                method=HttpMethod.POST,
                request_model=Veo2GenVidPollRequest,
                response_model=Veo2GenVidPollResponse
                request_model=VeoGenVidPollRequest,
                response_model=VeoGenVidPollResponse
            ),
            completed_statuses=["completed"],
            failed_statuses=[], # No failed statuses, we'll handle errors after polling
            status_extractor=status_extractor,
            progress_extractor=progress_extractor,
            request=Veo2GenVidPollRequest(
            request=VeoGenVidPollRequest(
                operationName=operation_name
            ),
            auth_kwargs=kwargs,
@ -243,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
        )

        # Execute the polling operation
        poll_response = poll_operation.execute()
        poll_response = await poll_operation.execute()

        # Now check for errors in the final response
        # Check for error in poll response
@ -268,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
            raise Exception(error_message)

        # Extract video data
        video_data = None
        if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
            video = poll_response.response.videos[0]

@ -278,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
                video_data = base64.b64decode(video.bytesBase64Encoded)
            elif hasattr(video, 'gcsUri') and video.gcsUri:
                # Download from URL
                video_url = video.gcsUri
                video_response = requests.get(video_url)
                video_data = video_response.content
                async with aiohttp.ClientSession() as session:
                    async with session.get(video.gcsUri) as video_response:
                        video_data = await video_response.content.read()
            else:
                raise Exception("Video returned but no data or URL was provided")
        else:
@ -298,11 +310,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
        return (VideoFromFile(video_io),)


# Register the node
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
    """
    Generates videos from text prompts using Google's Veo 3 API.

    Supported models:
    - veo-3.0-generate-001
    - veo-3.0-fast-generate-001

    This node extends the base Veo node with Veo 3 specific features including
    audio generation and fixed 8-second duration.
    """

    @classmethod
    def INPUT_TYPES(s):
        parent_input = super().INPUT_TYPES()

        # Update model options for Veo 3
        parent_input["optional"]["model"] = (
            IO.COMBO,
            {
                "options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
                "default": "veo-3.0-generate-001",
                "tooltip": "Veo 3 model to use for video generation",
            },
        )

        # Add generateAudio parameter
        parent_input["optional"]["generate_audio"] = (
            IO.BOOLEAN,
            {
                "default": False,
                "tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
            }
        )

        # Update duration constraints for Veo 3 (only 8 seconds supported)
        parent_input["optional"]["duration_seconds"] = (
            IO.INT,
            {
                "default": 8,
                "min": 8,
                "max": 8,
                "step": 1,
                "display": "number",
                "tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
            },
        )

        return parent_input


# Register the nodes
NODE_CLASS_MAPPINGS = {
    "VeoVideoGenerationNode": VeoVideoGenerationNode,
    "Veo3VideoGenerationNode": Veo3VideoGenerationNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VeoVideoGenerationNode": "Google Veo2 Video Generation",
    "VeoVideoGenerationNode": "Google Veo 2 Video Generation",
    "Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
}
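The only blocking I/O left in the old Veo node was the requests.get download of the finished video; the diff swaps it for aiohttp so the download also yields to the event loop. A minimal sketch of that pattern in isolation (the URL is a placeholder, and the raise_for_status() check is an illustrative addition, not part of the node):

import asyncio
import aiohttp

async def download_video(url: str) -> bytes:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            response.raise_for_status()           # fail fast on non-2xx replies
            return await response.content.read()  # read the body without blocking the loop

# video_data = asyncio.run(download_video("https://example.com/video.mp4"))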
@ -314,6 +314,29 @@ class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):

        return {"required": arg_dict}

class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}

        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict["pos_embeds."] = argument
        arg_dict["img_in."] = argument
        arg_dict["txt_norm."] = argument
        arg_dict["txt_in."] = argument
        arg_dict["time_text_embed."] = argument

        for i in range(60):
            arg_dict["transformer_blocks.{}.".format(i)] = argument

        arg_dict["proj_out."] = argument

        return {"required": arg_dict}

NODE_CLASS_MAPPINGS = {
    "ModelMergeSD1": ModelMergeSD1,
    "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -329,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
    "ModelMergeWAN2_1": ModelMergeWAN2_1,
    "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
    "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
    "ModelMergeQwenImage": ModelMergeQwenImage,
}
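ModelMergeQwenImage only declares the merge knobs: one FLOAT ratio per key prefix (pos_embeds., txt_in., transformer_blocks.N., and so on), which the inherited ModelMergeBlocks logic applies to matching weights. As a rough illustration of how such prefix ratios could drive a merge (a simplified sketch assuming longest-prefix matching and r as the weight of model2; the real behavior lives in comfy_extras/nodes_model_merging.py):

import torch

def merge_state_dicts(sd1, sd2, ratios, default=1.0):
    merged = {}
    for key, w1 in sd1.items():
        matches = [p for p in ratios if key.startswith(p)]
        r = ratios[max(matches, key=len)] if matches else default  # most specific prefix wins
        merged[key] = w1 * (1.0 - r) + sd2[key] * r                # linear interpolation per tensor
    return merged

sd1 = {"transformer_blocks.0.attn.to_q.weight": torch.zeros(2)}
sd2 = {"transformer_blocks.0.attn.to_q.weight": torch.ones(2)}
print(merge_state_dicts(sd1, sd2, {"transformer_blocks.0.": 0.25}))
# {'transformer_blocks.0.attn.to_q.weight': tensor([0.2500, 0.2500])}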
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.48"
__version__ = "0.3.49"
6
nodes.py
6
nodes.py
@ -1229,12 +1229,12 @@ class RepeatLatentBatch:
        s = samples.copy()
        s_in = samples["samples"]

        s["samples"] = s_in.repeat((amount, 1,1,1))
        s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
        if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
            masks = samples["noise_mask"]
            if masks.shape[0] < s_in.shape[0]:
                masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
            s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
                masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]]
            s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1)))
        if "batch_index" in s:
            offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
            s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
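The RepeatLatentBatch change replaces the hardcoded (amount, 1, 1, 1) repeat, which only worked for 4-D image latents, with a tuple built from ndim, so 5-D video latents repeat correctly too. A quick check of the shapes:

import torch

amount = 3
latent_4d = torch.randn(2, 4, 64, 64)      # image latent: [B, C, H, W]
latent_5d = torch.randn(2, 16, 9, 64, 64)  # video latent: [B, C, T, H, W]

for s_in in (latent_4d, latent_5d):
    repeated = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1)))
    print(tuple(s_in.shape), "->", tuple(repeated.shape))  # only the batch dim grows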
@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.48"
version = "0.3.49"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.47
comfyui-frontend-package==1.24.4
comfyui-workflow-templates==0.1.52
comfyui-embedded-docs==0.2.4
torch
torchsde