From 4136502b7a530df28e489f81c0c58b30612f9ece Mon Sep 17 00:00:00 2001
From: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Date: Mon, 12 May 2025 18:10:24 -0700
Subject: [PATCH 01/31] implement APG guidance (#8081)

* first pass at implementing APG

* rename, cleanup code

* fix ruff

* fix modified cond to match ref impl better, support different cond arity

---
 comfy_extras/nodes_apg.py | 76 +++++++++++++++++++++++++++++++++++++++
 nodes.py                  |  1 +
 2 files changed, 77 insertions(+)
 create mode 100644 comfy_extras/nodes_apg.py

diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py
new file mode 100644
index 00000000..1325985b
--- /dev/null
+++ b/comfy_extras/nodes_apg.py
@@ -0,0 +1,76 @@
+import torch
+
+def project(v0, v1):
+    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    return v0_parallel, v0_orthogonal
+
+class APG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
+                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization is disabled at a setting of 0."}),
+                "momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
+            }
+        }
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    CATEGORY = "sampling/custom_sampling"
+
+    def patch(self, model, eta, norm_threshold, momentum):
+        running_avg = 0
+        prev_sigma = None
+
+        def pre_cfg_function(args):
+            nonlocal running_avg, prev_sigma
+
+            if len(args["conds_out"]) == 1: return args["conds_out"]
+
+            cond = args["conds_out"][0]
+            uncond = args["conds_out"][1]
+            sigma = args["sigma"][0]
+            cond_scale = args["cond_scale"]
+
+            if prev_sigma is not None and sigma > prev_sigma:
+                running_avg = 0
+            prev_sigma = sigma
+
+            guidance = cond - uncond
+
+            if momentum > 0:
+                if not torch.is_tensor(running_avg):
+                    running_avg = guidance
+                else:
+                    running_avg = momentum * running_avg + guidance
+                guidance = running_avg
+
+            if norm_threshold > 0:
+                guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+                scale = torch.minimum(
+                    torch.ones_like(guidance_norm),
+                    norm_threshold / guidance_norm
+                )
+                guidance = guidance * scale
+
+            guidance_parallel, guidance_orthogonal = project(guidance, cond)
+            modified_guidance = guidance_orthogonal + eta * guidance_parallel
+
+            modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
+
+            return [modified_cond, uncond] + args["conds_out"][2:]
+
+        m = model.clone()
+        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
+        return (m,)
+
+NODE_CLASS_MAPPINGS = {
+    "APG": APG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "APG": "Adaptive Projected Guidance",
+}
diff --git a/nodes.py b/nodes.py
index a26a138f..54e3886a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2261,6 +2261,7 @@ def init_builtin_extra_nodes():
         "nodes_optimalsteps.py",
         "nodes_hidream.py",
         "nodes_fresca.py",
+        "nodes_apg.py",
         "nodes_preview_any.py",
         "nodes_ace.py",
         "nodes_string.py",

From 2156ce9453a8508efec4c36c26e810a14a3ce473 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Mon, 12 May 2025 20:06:44 -0700
Subject: [PATCH 02/31] add comment about using api key in headless (#8082)

---
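Note: for headless scripts the key travels in the same JSON payload as the prompt. A minimal sketch of the wired-up version of the commented-out example below (the endpoint matches the script's default local server; queue_prompt_with_key and the key value are illustrative, not part of the patch):

import json
from urllib import request

def queue_prompt_with_key(prompt, api_key):
    # API nodes in the workflow read the Comfy API key from extra_data.
    p = {"prompt": prompt, "extra_data": {"api_key_comfy_org": api_key}}
    data = json.dumps(p).encode('utf-8')
    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
    request.urlopen(req)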
 script_examples/basic_api_example.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py
index c916e6cb..9128420c 100644
--- a/script_examples/basic_api_example.py
+++ b/script_examples/basic_api_example.py
@@ -101,6 +101,14 @@ prompt_text = """
 
 def queue_prompt(prompt):
     p = {"prompt": prompt}
+
+    # If the workflow contains API nodes, you can add a Comfy API key to the `extra_data` field of the payload.
+    # p["extra_data"] = {
+    #     "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
+    # }
+    # See: https://docs.comfy.org/tutorials/api-nodes/overview
+    # Generate a key here: https://platform.comfy.org/login
+
     data = json.dumps(p).encode('utf-8')
     req = request.Request("http://127.0.0.1:8188/prompt", data=data)
     request.urlopen(req)

From 481732a0ed609ccd86cb3d2df00e3d14c9133280 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 May 2025 04:32:16 -0700
Subject: [PATCH 03/31] Support official ACE Step loras. (#8094)

---
 comfy/lora.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/comfy/lora.py b/comfy/lora.py
index fff524be..ef110c16 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
 
+    if isinstance(model, comfy.model_base.ACEStep):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+
     return key_map

From a814f2e8ccf171c967ee2dc095dd02f2a93dccde Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 May 2025 04:54:28 -0700
Subject: [PATCH 04/31] Fix issue with old pytorch RMSNorm. (#8095)

---
 comfy/rmsnorm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index 9d82bee1..66ae8321 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -30,7 +30,7 @@ if RMSNorm is None:
         def __init__(
             self,
             normalized_shape,
-            eps=None,
+            eps=1e-6,
             elementwise_affine=True,
             device=None,
             dtype=None,

From 8a7c894d5415c499817c1fcdcba26940755c03a4 Mon Sep 17 00:00:00 2001
From: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Date: Tue, 13 May 2025 10:50:32 -0700
Subject: [PATCH 05/31] fix negative momentum (#8100)

---
 comfy_extras/nodes_apg.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py
index 1325985b..25b21b1b 100644
--- a/comfy_extras/nodes_apg.py
+++ b/comfy_extras/nodes_apg.py
@@ -14,7 +14,7 @@ class APG:
                 "model": ("MODEL",),
                 "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization is disabled at a setting of 0."}),
-                "momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
+                "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
             }
         }
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
     CATEGORY = "sampling/custom_sampling"
@@ -41,7 +41,7 @@ class APG:
 
             guidance = cond - uncond
 
-            if momentum > 0:
+            if momentum != 0:
                 if not torch.is_tensor(running_avg):
                     running_avg = guidance
                 else:

From 4a9014e201b71dfcf97d59d1086cb451cd40873f Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 May 2025 12:53:47 -0700
Subject: [PATCH 06/31] Hunyuan Custom initial untested implementation. (#8101)

---
 comfy/ldm/hunyuan_video/model.py | 21 ++++++++++++++++++---
 comfy/model_base.py              |  4 ++++
 comfy_extras/nodes_hunyuan.py    |  6 ++++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 72af3d5b..fbd8d419 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
         y: Tensor,
         guidance: Tensor = None,
         guiding_frame_index=None,
+        ref_latent=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if ref_latent is not None:
+            ref_latent_ids = self.img_ids(ref_latent)
+            ref_latent = self.img_in(ref_latent)
+            img = torch.cat([ref_latent, img], dim=-2)
+            ref_latent_ids[..., 0] = -1
+            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
             vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
             img[:, : img_len] += add
 
         img = img[:, : img_len]
+        if ref_latent is not None:
+            img = img[:, ref_latent.shape[1]:]
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
+    def img_ids(self, x):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        return repeat(img_ids, "t h w c -> b (t h w) 
c", b=bs) + + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + bs, c, t, h, w = x.shape + img_ids = self.img_ids(x) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) return out diff --git a/comfy/model_base.py b/comfy/model_base.py index 6d27930d..04786159 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel): if guiding_frame_index is not None: out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) + ref_latent = kwargs.get("ref_latent", None) + if ref_latent is not None: + out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent)) + return out def scale_latent_inpaint(self, latent_image, **kwargs): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 504010ad..d7278e7a 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -77,7 +77,7 @@ class HunyuanImageToVideo: "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "guidance_type": (["v1 (concat)", "v2 (replace)"], ) + "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], ) }, "optional": {"start_image": ("IMAGE", ), }} @@ -101,10 +101,12 @@ class HunyuanImageToVideo: if guidance_type == "v1 (concat)": cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} - else: + elif guidance_type == "v2 (replace)": cond = {'guiding_frame_index': 0} latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image out_latent["noise_mask"] = mask + elif guidance_type == "custom": + cond = {"ref_latent": concat_latent_image} positive = node_helpers.conditioning_set_values(positive, cond) From bab836d88d62ebbe3b3040f3a9fff05f79d5049d Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Tue, 13 May 2025 17:42:29 -0700 Subject: [PATCH 07/31] rework client.py to be more robust, add logging of api requests (#7988) * rework how errors are handled on the client side * add logging to /temp * fix ruff * fix rebase, stupid vscode gui --- comfy_api_nodes/apis/client.py | 543 ++++++++++++++++++++++--- comfy_api_nodes/apis/request_logger.py | 125 ++++++ 2 files changed, 622 insertions(+), 46 deletions(-) create mode 100644 comfy_api_nodes/apis/request_logger.py diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index cff52714..158d20de 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -94,15 +94,18 @@ from __future__ import annotations import logging import time import io -from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable +import socket +from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from enum import Enum import json import requests -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field 
+import uuid # For generating unique operation IDs from comfy.cli_args import args from comfy import utils +from . import request_logger T = TypeVar("T", bound=BaseModel) R = TypeVar("R", bound=BaseModel) @@ -111,6 +114,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response PROGRESS_BAR_MAX = 100 +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + pass + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + pass + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + pass + + class EmptyRequest(BaseModel): """Base class for empty request bodies. For GET requests, fields will be sent as query parameters.""" @@ -141,7 +159,7 @@ class HttpMethod(str, Enum): class ApiClient: """ - Client for making HTTP requests to an API with authentication and error handling. + Client for making HTTP requests to an API with authentication, error handling, and retry logic. """ def __init__( @@ -151,12 +169,26 @@ class ApiClient: comfy_api_key: Optional[str] = None, timeout: float = 3600.0, verify_ssl: bool = True, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, + retry_status_codes: Optional[Tuple[int, ...]] = None, ): self.base_url = base_url self.auth_token = auth_token self.comfy_api_key = comfy_api_key self.timeout = timeout self.verify_ssl = verify_ssl + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor + # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), + # 500, 502, 503, 504 (Server Errors) + self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) + + def _generate_operation_id(self, path: str) -> str: + """Generates a unique operation ID for logging.""" + return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" def _create_json_payload_args( self, @@ -211,6 +243,56 @@ class ApiClient: return headers + def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + """ + Check connectivity to determine if network issues are local or server-related. 
+ + Args: + target_url: URL to check connectivity to + + Returns: + Dictionary with connectivity status details + """ + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False + } + + # First check basic internet connectivity using a reliable external site + try: + # Use a reliable external domain for checking basic connectivity + check_response = requests.get("https://www.google.com", + timeout=5.0, + verify=self.verify_ssl) + if check_response.status_code < 500: + results["internet_accessible"] = True + except (requests.RequestException, socket.error): + results["internet_accessible"] = False + results["is_local_issue"] = True + return results + + # Now check API server connectivity + try: + # Extract domain from the target URL to do a simpler health check + parsed_url = urlparse(target_url) + api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" + + # Try to reach the API domain + api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) + if api_response.status_code < 500: + results["api_accessible"] = True + else: + results["api_accessible"] = False + results["is_api_issue"] = True + except requests.RequestException: + results["api_accessible"] = False + # If we can reach the internet but not the API, it's an API issue + results["is_api_issue"] = True + + return results + def request( self, method: str, @@ -221,9 +303,10 @@ class ApiClient: headers: Optional[Dict[str, str]] = None, content_type: str = "application/json", multipart_parser: Callable = None, + retry_count: int = 0, # Used internally for tracking retries ) -> Dict[str, Any]: """ - Make an HTTP request to the API + Make an HTTP request to the API with automatic retries for transient errors. Args: method: HTTP method (GET, POST, etc.) @@ -233,12 +316,15 @@ class ApiClient: files: Files to upload headers: Additional headers content_type: Content type of the request. Defaults to application/json. + retry_count: Internal parameter for tracking retries, do not set manually Returns: Parsed JSON response Raises: - requests.RequestException: If the request fails + LocalNetworkError: If local network connectivity issues are detected + ApiServerError: If the API server is unreachable but internet is working + Exception: For other request failures """ url = urljoin(self.base_url, path) self.check_auth(self.auth_token, self.comfy_api_key) @@ -265,6 +351,16 @@ class ApiClient: else: payload_args = self._create_json_payload_args(data, request_headers) + operation_id = self._generate_operation_id(path) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=request_headers, + request_params=params, + request_data=data if content_type == "application/json" else "[form-data or other]" + ) + try: response = requests.request( method=method, @@ -275,50 +371,228 @@ class ApiClient: **payload_args, ) + # Check if we should retry based on status code + if (response.status_code in self.retry_status_codes and + retry_count < self.max_retries): + + # Calculate delay with exponential backoff + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + + logging.warning( + f"Request failed with status {response.status_code}. 
" + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + # Raise exception for error status codes response.raise_for_status() - except requests.ConnectionError: - raise Exception( - f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available." + + # Log successful response + response_content_to_log = response.content + try: + # Attempt to parse JSON for prettier logging, fallback to raw content + response_content_to_log = response.json() + except json.JSONDecodeError: + pass # Keep as bytes/str if not JSON + + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, # Pass request details again for context in log + request_url=url, + response_status_code=response.status_code, + response_headers=dict(response.headers), + response_content=response_content_to_log ) - except requests.Timeout: - raise Exception( - f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected." + except requests.ConnectionError as e: + error_message = f"ConnectionError: {str(e)}" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + error_message=error_message ) + # Only perform connectivity check if we've exhausted all retries + if retry_count >= self.max_retries: + # Check connectivity to determine if it's a local or API issue + connectivity = self._check_connectivity(self.base_url) + + if connectivity["is_local_issue"]: + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + elif connectivity["is_api_issue"]: + raise ApiServerError( + f"The API server at {self.base_url} is currently unreachable. " + f"The service may be experiencing issues. Please try again later." + ) from e + + # If we haven't exhausted retries yet, retry the request + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"Connection error: {str(e)}. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + + # If we've exhausted retries and didn't identify the specific issue, + # raise a generic exception + final_error_message = ( + f"Unable to connect to the API server after {self.max_retries} attempts. " + f"Please check your internet connection or try again later." 
+ ) + request_logger.log_request_response( # Log final failure + operation_id=operation_id, + request_method=method, request_url=url, + error_message=final_error_message + ) + raise Exception(final_error_message) from e + + except requests.Timeout as e: + error_message = f"Timeout: {str(e)}" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, request_url=url, + error_message=error_message + ) + # Retry timeouts if we haven't exhausted retries + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"Request timed out. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + final_error_message = ( + f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " + f"The server might be experiencing high load or the operation is taking longer than expected." + ) + request_logger.log_request_response( # Log final failure + operation_id=operation_id, + request_method=method, request_url=url, + error_message=final_error_message + ) + raise Exception(final_error_message) from e except requests.HTTPError as e: status_code = e.response.status_code if hasattr(e, "response") else None - error_message = f"HTTP Error: {str(e)}" + original_error_message = f"HTTP Error: {str(e)}" + error_content_for_log = None + if hasattr(e, "response") and e.response is not None: + error_content_for_log = e.response.content + try: + error_content_for_log = e.response.json() + except json.JSONDecodeError: + pass + + + # Try to extract detailed error message from JSON response for user display + # but log the full error content. 
+ user_display_error_message = original_error_message - # Try to extract detailed error message from JSON response try: - if hasattr(e, "response") and e.response.content: + if hasattr(e, "response") and e.response is not None and e.response.content: error_json = e.response.json() if "error" in error_json and "message" in error_json["error"]: - error_message = f"API Error: {error_json['error']['message']}" + user_display_error_message = f"API Error: {error_json['error']['message']}" if "type" in error_json["error"]: - error_message += f" (Type: {error_json['error']['type']})" + user_display_error_message += f" (Type: {error_json['error']['type']})" + elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict + user_display_error_message = f"API Error: {json.dumps(error_json)}" + else: # Non-dict JSON error + user_display_error_message = f"API Error: {str(error_json)}" + except json.JSONDecodeError: + # If not JSON, use the raw content if it's not too long, or a summary + if hasattr(e, "response") and e.response is not None and e.response.content: + raw_content = e.response.content.decode(errors='ignore') + if len(raw_content) < 200: # Arbitrary limit for display + user_display_error_message = f"API Error (raw): {raw_content}" else: - error_message = f"API Error: {error_json}" - except Exception as json_error: - # If we can't parse the JSON, fall back to the original error message - logging.debug( - f"[DEBUG] Failed to parse error response: {str(json_error)}" + user_display_error_message = f"API Error (raw, status {status_code})" + + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, request_url=url, + response_status_code=status_code, + response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, + response_content=error_content_for_log, + error_message=original_error_message # Log the original exception string as error + ) + + logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") + if hasattr(e, "response") and e.response is not None and e.response.content: + logging.debug(f"[DEBUG] Response content: {e.response.content}") + + # Retry if the status code is in our retry list and we haven't exhausted retries + if (status_code in self.retry_status_codes and + retry_count < self.max_retries): + + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"HTTP error {status_code}. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, ) - logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") + # Specific error messages for common status codes for user display if status_code == 401: - error_message = "Unauthorized: Please login first to use this node." - if status_code == 402: - error_message = "Payment Required: Please add credits to your account to use this node." - if status_code == 409: - error_message = "There is a problem with your account. Please contact support@comfy.org. " - if status_code == 429: - error_message = "Rate Limit Exceeded: Please try again later." 
- raise Exception(error_message) + user_display_error_message = "Unauthorized: Please login first to use this node." + elif status_code == 402: + user_display_error_message = "Payment Required: Please add credits to your account to use this node." + elif status_code == 409: + user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." + elif status_code == 429: + user_display_error_message = "Rate Limit Exceeded: Please try again later." + # else, user_display_error_message remains as parsed from response or original HTTPError string + + raise Exception(user_display_error_message) # Raise with the user-friendly message # Parse and return JSON response if response.content: @@ -336,26 +610,126 @@ class ApiClient: upload_url: str, file: io.BytesIO | str, content_type: str | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, ): - """Upload a file to the API. Make sure the file has a filename equal to what the url expects. + """Upload a file to the API with retry logic. Args: upload_url: The URL to upload to file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - mime_type: Optional mime type to set for the upload + content_type: Optional mime type to set for the upload + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries in seconds + retry_backoff_factor: Multiplier for the delay after each retry """ headers = {} if content_type: headers["Content-Type"] = content_type + # Prepare the file data if isinstance(file, io.BytesIO): file.seek(0) # Ensure we're at the start of the file data = file.read() - return requests.put(upload_url, data=data, headers=headers) elif isinstance(file, str): with open(file, "rb") as f: data = f.read() - return requests.put(upload_url, data=data, headers=headers) + else: + raise ValueError("File must be either a BytesIO object or a file path string") + + # Try the upload with retries + last_exception = None + operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads + + # Log initial attempt (without full file data for brevity) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers, + request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" + ) + + for retry_attempt in range(max_retries + 1): + try: + response = requests.put(upload_url, data=data, headers=headers) + response.raise_for_status() + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", request_url=upload_url, # For context + response_status_code=response.status_code, + response_headers=dict(response.headers), + response_content="File uploaded successfully." 
# Or response.text if available + ) + return response + + except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: + last_exception = e + error_message_for_log = f"{type(e).__name__}: {str(e)}" + response_content_for_log = None + status_code_for_log = None + headers_for_log = None + + if hasattr(e, 'response') and e.response is not None: + status_code_for_log = e.response.status_code + headers_for_log = dict(e.response.headers) + try: + response_content_for_log = e.response.json() + except json.JSONDecodeError: + response_content_for_log = e.response.content + + + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + response_status_code=status_code_for_log, + response_headers=headers_for_log, + response_content=response_content_for_log, + error_message=error_message_for_log + ) + + if retry_attempt < max_retries: + delay = retry_delay * (retry_backoff_factor ** retry_attempt) + logging.warning( + f"File upload failed: {str(e)}. " + f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" + ) + time.sleep(delay) + else: + break # Max retries reached + + # If we've exhausted all retries, determine the final error type and raise + final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" + try: + # Check basic internet connectivity + check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired + if check_response.status_code >= 500: # Google itself has an issue (rare) + final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " + f"(status {check_response.status_code}). Original error: {str(last_exception)}") + # Not raising LocalNetworkError here as Google itself might be down. + # If Google is reachable, the issue is likely with the upload server or a more specific local problem + # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). + # The original last_exception is probably most relevant. + + except (requests.RequestException, socket.error) as conn_check_exc: + # Could not reach Google, likely a local network issue + final_error_message = (f"Failed to upload file due to network connectivity issues " + f"(cannot reach Google: {str(conn_check_exc)}). 
" + f"Original upload error: {str(last_exception)}") + request_logger.log_request_response( # Log final failure reason + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + error_message=final_error_message + ) + raise LocalNetworkError(final_error_message) from last_exception + + request_logger.log_request_response( # Log final failure reason if not LocalNetworkError + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + error_message=final_error_message + ) + raise Exception(final_error_message) from last_exception class ApiEndpoint(Generic[T, R]): @@ -403,6 +777,9 @@ class SynchronousOperation(Generic[T, R]): verify_ssl: bool = True, content_type: str = "application/json", multipart_parser: Callable = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, ): self.endpoint = endpoint self.request = request @@ -419,8 +796,12 @@ class SynchronousOperation(Generic[T, R]): self.files = files self.content_type = content_type self.multipart_parser = multipart_parser + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor + def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one""" + """Execute the API operation using the provided client or create one with retry support""" try: # Create client if not provided if client is None: @@ -430,6 +811,9 @@ class SynchronousOperation(Generic[T, R]): comfy_api_key=self.comfy_api_key, timeout=self.timeout, verify_ssl=self.verify_ssl, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) # Convert request model to dict, but use None for EmptyRequest @@ -443,11 +827,6 @@ class SynchronousOperation(Generic[T, R]): if isinstance(value, Enum): request_dict[key] = value.value - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - # Debug log for request logging.debug( f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" @@ -455,7 +834,7 @@ class SynchronousOperation(Generic[T, R]): logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - # Make the request + # Make the request with built-in retry resp = client.request( method=self.endpoint.method.value, path=self.endpoint.path, @@ -476,8 +855,18 @@ class SynchronousOperation(Generic[T, R]): # Parse and return the response return self._parse_response(resp) + except LocalNetworkError as e: + # Propagate specific network error types + logging.error(f"[ERROR] Local network error: {str(e)}") + raise + + except ApiServerError as e: + # Propagate API server errors + logging.error(f"[ERROR] API server error: {str(e)}") + raise + except Exception as e: - logging.error(f"[DEBUG] API Exception: {str(e)}") + logging.error(f"[ERROR] API Exception: {str(e)}") raise Exception(str(e)) def _parse_response(self, resp): @@ -517,6 +906,10 @@ class PollingOperation(Generic[T, R]): comfy_api_key: Optional[str] = None, auth_kwargs: Optional[Dict[str,str]] = None, poll_interval: float = 5.0, + max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) + max_retries: int = 3, # Max retries per individual API call + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, ): self.poll_endpoint = poll_endpoint self.request = request @@ 
-527,6 +920,10 @@ class PollingOperation(Generic[T, R]): self.auth_token = auth_kwargs.get("auth_token", self.auth_token) self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) self.poll_interval = poll_interval + self.max_poll_attempts = max_poll_attempts + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor # Polling configuration self.status_extractor = status_extractor or ( @@ -548,8 +945,23 @@ class PollingOperation(Generic[T, R]): base_url=self.api_base, auth_token=self.auth_token, comfy_api_key=self.comfy_api_key, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) return self._poll_until_complete(client) + except LocalNetworkError as e: + # Provide clear message for local network issues + raise Exception( + f"Polling failed due to local network issues. Please check your internet connection. " + f"Details: {str(e)}" + ) from e + except ApiServerError as e: + # Provide clear message for API server issues + raise Exception( + f"Polling failed due to API server issues. The service may be experiencing problems. " + f"Please try again later. Details: {str(e)}" + ) from e except Exception as e: raise Exception(f"Error during polling: {str(e)}") @@ -569,10 +981,13 @@ class PollingOperation(Generic[T, R]): def _poll_until_complete(self, client: ApiClient) -> R: """Poll until the task is complete""" poll_count = 0 + consecutive_errors = 0 + max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors + if self.progress_extractor: progress = utils.ProgressBar(PROGRESS_BAR_MAX) - while True: + while poll_count < self.max_poll_attempts: try: poll_count += 1 logging.debug(f"[DEBUG] Polling attempt #{poll_count}") @@ -599,8 +1014,12 @@ class PollingOperation(Generic[T, R]): data=request_dict, ) + # Successfully got a response, reset consecutive error count + consecutive_errors = 0 + # Parse response response_obj = self.poll_endpoint.response_model.model_validate(resp) + # Check if task is complete status = self._check_task_status(response_obj) logging.debug(f"[DEBUG] Task Status: {status}") @@ -630,6 +1049,38 @@ class PollingOperation(Generic[T, R]): ) time.sleep(self.poll_interval) + except (LocalNetworkError, ApiServerError) as e: + # For network-related errors, increment error count and potentially abort + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + raise Exception( + f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}" + ) from e + + # Log the error but continue polling + logging.warning( + f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " + f"Will retry in {self.poll_interval} seconds." + ) + time.sleep(self.poll_interval) + except Exception as e: + # For other errors, increment count and potentially abort + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + raise Exception( + f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" + ) from e + logging.error(f"[DEBUG] Polling error: {str(e)}") - raise Exception(f"Error while polling: {str(e)}") + logging.warning( + f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " + f"Will retry in {self.poll_interval} seconds." 
+ ) + time.sleep(self.poll_interval) + + # If we've exhausted all polling attempts + raise Exception( + f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " + f"The operation may still be running on the server but is taking longer than expected." + ) diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py new file mode 100644 index 00000000..93517ede --- /dev/null +++ b/comfy_api_nodes/apis/request_logger.py @@ -0,0 +1,125 @@ +import os +import datetime +import json +import logging +import folder_paths + +# Get the logger instance +logger = logging.getLogger(__name__) + +def get_log_directory(): + """ + Ensures the API log directory exists within ComfyUI's temp directory + and returns its path. + """ + base_temp_dir = folder_paths.get_temp_directory() + log_dir = os.path.join(base_temp_dir, "api_logs") + try: + os.makedirs(log_dir, exist_ok=True) + except Exception as e: + logger.error(f"Error creating API log directory {log_dir}: {e}") + # Fallback to base temp directory if sub-directory creation fails + return base_temp_dir + return log_dir + +def _format_data_for_logging(data): + """Helper to format data (dict, str, bytes) for logging.""" + if isinstance(data, bytes): + try: + return data.decode('utf-8') # Try to decode as text + except UnicodeDecodeError: + return f"[Binary data of length {len(data)} bytes]" + elif isinstance(data, (dict, list)): + try: + return json.dumps(data, indent=2, ensure_ascii=False) + except TypeError: + return str(data) # Fallback for non-serializable objects + return str(data) + +def log_request_response( + operation_id: str, + request_method: str, + request_url: str, + request_headers: dict | None = None, + request_params: dict | None = None, + request_data: any = None, + response_status_code: int | None = None, + response_headers: dict | None = None, + response_content: any = None, + error_message: str | None = None +): + """ + Logs API request and response details to a file in the temp/api_logs directory. 
+ """ + log_dir = get_log_directory() + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log" + filepath = os.path.join(log_dir, filename) + + log_content = [] + + log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") + log_content.append(f"Operation ID: {operation_id}") + log_content.append("-" * 30 + " REQUEST " + "-" * 30) + log_content.append(f"Method: {request_method}") + log_content.append(f"URL: {request_url}") + if request_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") + if request_params: + log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") + if request_data: + log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") + + log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) + if response_status_code is not None: + log_content.append(f"Status Code: {response_status_code}") + if response_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") + if response_content: + log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") + if error_message: + log_content.append(f"Error:\n{error_message}") + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write("\n".join(log_content)) + logger.debug(f"API log saved to: {filepath}") + except Exception as e: + logger.error(f"Error writing API log to {filepath}: {e}") + +if __name__ == '__main__': + # Example usage (for testing the logger directly) + logger.setLevel(logging.DEBUG) + # Mock folder_paths for direct execution if not running within ComfyUI full context + if not hasattr(folder_paths, 'get_temp_directory'): + class MockFolderPaths: + def get_temp_directory(self): + # Create a local temp dir for testing if needed + p = os.path.join(os.path.dirname(__file__), 'temp_test_logs') + os.makedirs(p, exist_ok=True) + return p + folder_paths = MockFolderPaths() + + log_request_response( + operation_id="test_operation_get", + request_method="GET", + request_url="https://api.example.com/test", + request_headers={"Authorization": "Bearer testtoken"}, + request_params={"param1": "value1"}, + response_status_code=200, + response_content={"message": "Success!"} + ) + log_request_response( + operation_id="test_operation_post_error", + request_method="POST", + request_url="https://api.example.com/submit", + request_data={"key": "value", "nested": {"num": 123}}, + error_message="Connection timed out" + ) + log_request_response( + operation_id="test_binary_response", + request_method="GET", + request_url="https://api.example.com/image.png", + response_status_code=200, + response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' 
# Sample binary data
    )

From 98ff01e1486353a3b0ddd8fa82fcbb25401f8364 Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Tue, 13 May 2025 21:33:18 -0700
Subject: [PATCH 08/31] Display progress and result URL directly on API nodes (#8102)

* [Luma] Print download URL of successful task result directly on nodes (#177)

[Veo] Print download URL of successful task result directly on nodes (#184)

[Recraft] Print download URL of successful task result directly on nodes (#183)

[Pixverse] Print download URL of successful task result directly on nodes (#182)

[Kling] Print download URL of successful task result directly on nodes (#181)

[MiniMax] Print progress text and download URL of successful task result directly on nodes (#179)

[Docs] Link to docs in `API_NODE` class property type annotation comment (#178)

[Ideogram] Print download URL of successful task result directly on nodes (#176)

Show output URL and progress text on Pika nodes (#168)

[BFL] Print download URL of successful task result directly on nodes (#175)

[OpenAI ] Print download URL of successful task result directly on nodes (#174)

* fix ruff errors

* fix 3.10 syntax error

---
 comfy/comfy_types/node_typing.py  |   2 +-
 comfy_api_nodes/apinode_utils.py  |  11 +-
 comfy_api_nodes/apis/client.py    |  42 +++++++-
 comfy_api_nodes/nodes_bfl.py      |  63 ++++++++----
 comfy_api_nodes/nodes_ideogram.py |  24 ++++-
 comfy_api_nodes/nodes_kling.py    | 164 ++++++++++++++++++++++++++----
 comfy_api_nodes/nodes_luma.py     |  33 ++++++
 comfy_api_nodes/nodes_minimax.py  |  27 ++++-
 comfy_api_nodes/nodes_openai.py   |  12 ++-
 comfy_api_nodes/nodes_pika.py     |  42 ++++++--
 comfy_api_nodes/nodes_pixverse.py |  99 ++++++++++++------
 comfy_api_nodes/nodes_recraft.py  |  21 ++++
 comfy_api_nodes/nodes_veo2.py     |  26 ++++-
 13 files changed, 474 insertions(+), 92 deletions(-)

diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 2ffc9c02..470eb9fd 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
     DEPRECATED: bool
     """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
     API_NODE: Optional[bool]
-    """Flags a node as an API node."""
+    """Flags a node as an API node. 
See: https://docs.comfy.org/tutorials/api-nodes/overview.""" @classmethod @abstractmethod diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index e28d7d60..87d8c3e1 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import io import logging -from typing import Optional +from typing import Optional, Union from comfy.utils import common_upscale from comfy_api.input_impl import VideoFromFile from comfy_api.util import VideoContainer, VideoCodec @@ -15,6 +15,7 @@ from comfy_api_nodes.apis.client import ( UploadRequest, UploadResponse, ) +from server import PromptServer import numpy as np @@ -60,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: return s -def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: +def validate_and_cast_response( + response, timeout: int = None, node_id: Union[str, None] = None +) -> torch.Tensor: """Validates and casts a response to a torch.Tensor. Args: @@ -94,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: img = Image.open(io.BytesIO(img_data)) elif image_url: + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {image_url}", node_id + ) img_response = requests.get(image_url, timeout=timeout) if img_response.status_code != 200: raise ValueError("Failed to download the image") diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 158d20de..838ff1e8 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -103,6 +103,7 @@ from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field import uuid # For generating unique operation IDs +from server import PromptServer from comfy.cli_args import args from comfy import utils from . 
import request_logger @@ -900,6 +901,7 @@ class PollingOperation(Generic[T, R]): failed_statuses: list, status_extractor: Callable[[R], str], progress_extractor: Callable[[R], float] = None, + result_url_extractor: Callable[[R], str] = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, @@ -910,6 +912,8 @@ class PollingOperation(Generic[T, R]): max_retries: int = 3, # Max retries per individual API call retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, + estimated_duration: Optional[float] = None, + node_id: Optional[str] = None, ): self.poll_endpoint = poll_endpoint self.request = request @@ -924,12 +928,15 @@ class PollingOperation(Generic[T, R]): self.max_retries = max_retries self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor + self.estimated_duration = estimated_duration # Polling configuration self.status_extractor = status_extractor or ( lambda x: getattr(x, "status", None) ) self.progress_extractor = progress_extractor + self.result_url_extractor = result_url_extractor + self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses @@ -965,6 +972,26 @@ class PollingOperation(Generic[T, R]): except Exception as e: raise Exception(f"Error during polling: {str(e)}") + def _display_text_on_node(self, text: str): + """Sends text to the client which will be displayed on the node in the UI""" + if not self.node_id: + return + + PromptServer.instance.send_progress_text(text, self.node_id) + + def _display_time_progress_on_node(self, time_completed: int): + if not self.node_id: + return + + if self.estimated_duration is not None: + estimated_time_remaining = max( + 0, int(self.estimated_duration) - int(time_completed) + ) + message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" + else: + message = f"Task in progress: {time_completed:.0f}s" + self._display_text_on_node(message) + def _check_task_status(self, response: R) -> TaskStatus: """Check task status using the status extractor function""" try: @@ -1031,7 +1058,15 @@ class PollingOperation(Generic[T, R]): progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) if status == TaskStatus.COMPLETED: - logging.debug("[DEBUG] Task completed successfully") + message = "Task completed successfully" + if self.result_url_extractor: + result_url = self.result_url_extractor(response_obj) + if result_url: + message = f"Result URL: {result_url}" + else: + message = "Task completed successfully!" 
+ logging.debug(f"[DEBUG] {message}") + self._display_text_on_node(message) self.final_response = response_obj if self.progress_extractor: progress.update(100) @@ -1047,7 +1082,10 @@ class PollingOperation(Generic[T, R]): logging.debug( f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" ) - time.sleep(self.poll_interval) + for i in range(int(self.poll_interval)): + time_completed = (poll_count * self.poll_interval) + i + self._display_time_progress_on_node(time_completed) + time.sleep(1) except (LocalNetworkError, ApiServerError) as e: # For network-related errors, increment error count and potentially abort diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 66ef1b39..509170b3 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,5 +1,6 @@ import io from inspect import cleandoc +from typing import Union from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api_nodes.apis.bfl_api import ( BFLStatus, @@ -30,6 +31,7 @@ import requests import torch import base64 import time +from server import PromptServer def convert_mask_to_image(mask: torch.Tensor): @@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor): def handle_bfl_synchronous_operation( - operation: SynchronousOperation, timeout_bfl_calls=360 + operation: SynchronousOperation, + timeout_bfl_calls=360, + node_id: Union[str, None] = None, ): response_api: BFLFluxProGenerateResponse = operation.execute() return _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls + response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id ) -def _poll_until_generated(polling_url: str, timeout=360): + +def _poll_until_generated( + polling_url: str, timeout=360, node_id: Union[str, None] = None +): # used bfl-comfy-nodes to verify code implementation: # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main start_time = time.time() @@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360): request = requests.Request(method=HttpMethod.GET, url=polling_url) # NOTE: should True loop be replaced with checking if workflow has been interrupted? 
while True: + if node_id: + time_elapsed = time.time() - start_time + PromptServer.instance.send_progress_text( + f"Generating ({time_elapsed:.0f}s)", node_id + ) + response = requests.Session().send(request.prepare()) if response.status_code == 200: result = response.json() if result["status"] == BFLStatus.ready: img_url = result["result"]["sample"] + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {img_url}", node_id + ) img_response = requests.get(img_url) return process_image_response(img_response) elif result["status"] in [ @@ -180,6 +197,7 @@ class FluxProUltraImageNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -212,6 +230,7 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=0, image_prompt=None, image_prompt_strength=0.1, + unique_id: Union[str, None] = None, **kwargs, ): if image_prompt is None: @@ -246,7 +265,7 @@ class FluxProUltraImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -320,6 +339,7 @@ class FluxProImageNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -338,6 +358,7 @@ class FluxProImageNode(ComfyNodeABC): seed=0, image_prompt=None, # image_prompt_strength=0.1, + unique_id: Union[str, None] = None, **kwargs, ): image_prompt = ( @@ -363,7 +384,7 @@ class FluxProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -457,11 +478,11 @@ class FluxProExpandNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -483,6 +504,7 @@ class FluxProExpandNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): image = convert_image_to_base64(image) @@ -508,7 +530,7 @@ class FluxProExpandNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -568,11 +590,11 @@ class FluxProFillNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -591,13 +613,14 @@ class FluxProFillNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): # prepare mask mask = resize_mask_to_image(mask, image) mask = convert_image_to_base64(convert_mask_to_image(mask)) # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:,:,:,:3]) + image = convert_image_to_base64(image[:, :, :, :3]) operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -617,7 +640,7 @@ class FluxProFillNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -702,11 +725,11 @@ class FluxProCannyNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": 
"AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -727,9 +750,10 @@ class FluxProCannyNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): - control_image = convert_image_to_base64(control_image[:,:,:,:3]) + control_image = convert_image_to_base64(control_image[:, :, :, :3]) preprocessed_image = None # scale canny threshold between 0-500, to match BFL's API @@ -765,7 +789,7 @@ class FluxProCannyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -830,11 +854,11 @@ class FluxProDepthNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -853,6 +877,7 @@ class FluxProDepthNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): control_image = convert_image_to_base64(control_image[:,:,:,:3]) @@ -880,7 +905,7 @@ class FluxProDepthNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d25468b1..b1cbf511 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, resize_mask_to_image, ) +from server import PromptServer V1_V1_RES_MAP = { "Auto":"AUTO", @@ -232,6 +233,19 @@ def download_and_process_images(image_urls): return stacked_tensors +def display_image_urls_on_node(image_urls, node_id): + if node_id and image_urls: + if len(image_urls) == 1: + PromptServer.instance.send_progress_text( + f"Generated Image URL:\n{image_urls[0]}", node_id + ) + else: + urls_text = "Generated Image URLs:\n" + "\n".join( + f"{i+1}. {url}" for i, url in enumerate(image_urls) + ) + PromptServer.instance.send_progress_text(urls_text, node_id) + + class IdeogramV1(ComfyNodeABC): """ Generates images using the Ideogram V1 model. 
@@ -304,6 +318,7 @@ class IdeogramV1(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -322,6 +337,7 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, + unique_id=None, **kwargs, ): # Determine the model based on turbo setting @@ -361,6 +377,7 @@ class IdeogramV1(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) @@ -460,6 +477,7 @@ class IdeogramV2(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -481,6 +499,7 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", + unique_id=None, **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) @@ -534,6 +553,7 @@ class IdeogramV2(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) class IdeogramV3(ComfyNodeABC): @@ -623,6 +643,7 @@ class IdeogramV3(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -643,6 +664,7 @@ class IdeogramV3(ComfyNodeABC): seed=0, num_images=1, rendering_speed="BALANCED", + unique_id=None, **kwargs, ): # Check if both image and mask are provided for editing mode @@ -762,6 +784,7 @@ class IdeogramV3(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) @@ -776,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = { "IdeogramV2": "Ideogram V2", "IdeogramV3": "Ideogram V3", } - diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 2d0fd888..456a8690 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere from __future__ import annotations from typing import Optional, TypeVar, Any +from collections.abc import Callable import math import logging @@ -86,6 +87,15 @@ MAX_PROMPT_LENGTH_IMAGE_GEN = 500 MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200 MAX_PROMPT_LENGTH_LIP_SYNC = 120 +# TODO: adjust based on tests +AVERAGE_DURATION_T2V = 319 # 319, +AVERAGE_DURATION_I2V = 164 # 164, +AVERAGE_DURATION_LIP_SYNC = 120 +AVERAGE_DURATION_VIRTUAL_TRY_ON = 19 # 19, +AVERAGE_DURATION_IMAGE_GEN = 32 +AVERAGE_DURATION_VIDEO_EFFECTS = 320 +AVERAGE_DURATION_VIDEO_EXTEND = 320 + R = TypeVar("R") @@ -95,7 +105,13 @@ class KlingApiError(Exception): pass -def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[Any, R]) -> R: +def poll_until_finished( + auth_kwargs: dict[str, str], + api_endpoint: ApiEndpoint[Any, R], + result_url_extractor: Optional[Callable[[R], str]] = None, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> R: """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" return PollingOperation( poll_endpoint=api_endpoint, @@ -109,6 +125,9 @@ def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[An else None ), auth_kwargs=auth_kwargs, + result_url_extractor=result_url_extractor, + 
estimated_duration=estimated_duration,
+        node_id=node_id,
     ).execute()
@@ -227,7 +246,9 @@ def get_camera_control_input_config(
 
 
 def get_video_from_response(response) -> KlingVideoResult:
-    """Returns the first video object from the Kling video generation task result."""
+    """Returns the first video object from the Kling video generation task result.
+    Will raise an error if the response is not valid.
+    """
     video = response.data.task_result.videos[0]
     logging.info(
         "Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
@@ -235,12 +256,37 @@
     )
     return video
 
 
+def get_video_url_from_response(response) -> Optional[str]:
+    """Returns the first video url from the Kling video generation task result.
+    Will not raise an error if the response is not valid.
+    """
+    if response and is_valid_video_response(response):
+        return str(get_video_from_response(response).url)
+    else:
+        return None
+
+
 def get_images_from_response(response) -> list[KlingImageResult]:
+    """Returns the list of image objects from the Kling image generation task result.
+    Will raise an error if the response is not valid.
+    """
     images = response.data.task_result.images
     logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
     return images
+
+def get_images_urls_from_response(response) -> Optional[str]:
+    """Returns the image urls from the Kling image generation task result,
+    joined by newlines into a single string. Will not raise an error if the
+    response is not valid; returns None instead.
+    """
+    if response and is_valid_image_response(response):
+        images = get_images_from_response(response)
+        image_urls = [str(image.url) for image in images]
+        return "\n".join(image_urls)
+    else:
+        return None
+
+
 def video_result_to_node_output(
     video: KlingVideoResult,
 ) -> tuple[VideoFromFile, str, str]:
@@ -312,6 +358,7 @@ class KlingCameraControls(KlingNodeBase):
     RETURN_TYPES = ("CAMERA_CONTROL",)
     RETURN_NAMES = ("camera_control",)
     FUNCTION = "main"
+    API_NODE = False  # This is just a helper node; it doesn't make an API call
 
     @classmethod
     def VALIDATE_INPUTS(
@@ -421,6 +468,7 @@ class KlingTextToVideoNode(KlingNodeBase):
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
                 "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
             },
         }
 
@@ -428,7 +476,9 @@ class KlingTextToVideoNode(KlingNodeBase):
     RETURN_NAMES = ("VIDEO", "video_id", "duration")
     DESCRIPTION = "Kling Text to Video Node"
 
-    def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingText2VideoResponse:
+    def get_response(
+        self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+    ) -> KlingText2VideoResponse:
         return poll_until_finished(
             auth_kwargs,
             ApiEndpoint(
@@ -437,6 +487,9 @@ class KlingTextToVideoNode(KlingNodeBase):
                 request_model=EmptyRequest,
                 response_model=KlingText2VideoResponse,
             ),
+            result_url_extractor=get_video_url_from_response,
+            estimated_duration=AVERAGE_DURATION_T2V,
+            node_id=node_id,
         )
 
     def api_call(
         self,
         prompt: str,
         negative_prompt: str,
         cfg_scale: float,
         mode: str,
         aspect_ratio: str,
         camera_control: Optional[KlingCameraControl] = None,
         model_name: Optional[str] = None,
         duration: Optional[str] = None,
+        unique_id: Optional[str] = None,
         **kwargs,
     ) -> tuple[VideoFromFile, str, str]:
         validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
@@ -478,7 +532,9 @@
         validate_task_creation_response(task_creation_response)
         task_id = 
task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -528,6 +584,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -540,6 +597,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): cfg_scale: float, aspect_ratio: str, camera_control: Optional[KlingCameraControl] = None, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -613,6 +671,7 @@ class KlingImage2VideoNode(KlingNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -620,7 +679,9 @@ class KlingImage2VideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Image to Video Node" - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingImage2VideoResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingImage2VideoResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -629,6 +690,9 @@ class KlingImage2VideoNode(KlingNodeBase): request_model=KlingImage2VideoRequest, response_model=KlingImage2VideoResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_I2V, + node_id=node_id, ) def api_call( @@ -643,6 +707,7 @@ class KlingImage2VideoNode(KlingNodeBase): duration: str, camera_control: Optional[KlingCameraControl] = None, end_frame: Optional[torch.Tensor] = None, + unique_id: Optional[str] = None, **kwargs, ) -> tuple[VideoFromFile]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) @@ -681,7 +746,9 @@ class KlingImage2VideoNode(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -734,6 +801,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -747,6 +815,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, camera_control: KlingCameraControl, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -759,6 +828,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, + unique_id=unique_id, **kwargs, ) @@ -830,6 +900,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -844,6 +915,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, mode: str, + unique_id: Optional[str] = None, **kwargs, ): mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ @@ -859,6 +931,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): aspect_ratio=aspect_ratio, duration=duration, end_frame=end_frame, + unique_id=unique_id, 
**kwargs, ) @@ -892,6 +965,7 @@ class KlingVideoExtendNode(KlingNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -899,7 +973,9 @@ class KlingVideoExtendNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoExtendResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingVideoExtendResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -908,6 +984,9 @@ class KlingVideoExtendNode(KlingNodeBase): request_model=EmptyRequest, response_model=KlingVideoExtendResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, + node_id=node_id, ) def api_call( @@ -916,6 +995,7 @@ class KlingVideoExtendNode(KlingNodeBase): negative_prompt: str, cfg_scale: float, video_id: str, + unique_id: Optional[str] = None, **kwargs, ) -> tuple[VideoFromFile, str, str]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) @@ -939,7 +1019,9 @@ class KlingVideoExtendNode(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -952,7 +1034,9 @@ class KlingVideoEffectsBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoEffectsResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingVideoEffectsResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -961,6 +1045,9 @@ class KlingVideoEffectsBase(KlingNodeBase): request_model=EmptyRequest, response_model=KlingVideoEffectsResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, + node_id=node_id, ) def api_call( @@ -972,6 +1059,7 @@ class KlingVideoEffectsBase(KlingNodeBase): image_1: torch.Tensor, image_2: Optional[torch.Tensor] = None, mode: Optional[KlingVideoGenMode] = None, + unique_id: Optional[str] = None, **kwargs, ): if dual_character: @@ -1009,7 +1097,9 @@ class KlingVideoEffectsBase(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -1053,6 +1143,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1068,6 +1159,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): model_name: KlingCharacterEffectModelName, mode: KlingVideoGenMode, duration: KlingVideoGenDuration, + unique_id: Optional[str] = None, **kwargs, ): video, _, duration = super().api_call( @@ 
-1078,10 +1170,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): duration=duration, image_1=image_left, image_2=image_right, + unique_id=unique_id, **kwargs, ) return video, duration + class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): """Kling Single Image Video Effect Node""" @@ -1117,6 +1211,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1128,6 +1223,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): effect_scene: KlingSingleImageEffectsScene, model_name: KlingSingleImageEffectModelName, duration: KlingVideoGenDuration, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -1136,6 +1232,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): model_name=model_name, duration=duration, image_1=image, + unique_id=unique_id, **kwargs, ) @@ -1154,7 +1251,9 @@ class KlingLipSyncBase(KlingNodeBase): f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." ) - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingLipSyncResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingLipSyncResponse: """Polls the Kling API endpoint until the task reaches a terminal state.""" return poll_until_finished( auth_kwargs, @@ -1164,6 +1263,9 @@ class KlingLipSyncBase(KlingNodeBase): request_model=EmptyRequest, response_model=KlingLipSyncResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_LIP_SYNC, + node_id=node_id, ) def api_call( @@ -1175,7 +1277,8 @@ class KlingLipSyncBase(KlingNodeBase): text: Optional[str] = None, voice_speed: Optional[float] = None, voice_id: Optional[str] = None, - **kwargs + unique_id: Optional[str] = None, + **kwargs, ) -> tuple[VideoFromFile, str, str]: if text: self.validate_text(text) @@ -1217,7 +1320,9 @@ class KlingLipSyncBase(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -1243,6 +1348,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1253,6 +1359,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): video: VideoInput, audio: AudioInput, voice_language: str, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -1260,6 +1367,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): audio=audio, voice_language=voice_language, mode="audio2video", + unique_id=unique_id, **kwargs, ) @@ -1352,6 +1460,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1363,6 +1472,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): text: str, voice: str, voice_speed: float, + unique_id: Optional[str] = None, **kwargs, ): voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] @@ -1373,6 +1483,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): voice_id=voice_id, 
voice_speed=voice_speed, mode="text2video", + unique_id=unique_id, **kwargs, ) @@ -1413,13 +1524,14 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } - DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human." + DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." def get_response( - self, task_id: str, auth_kwargs: dict[str,str] = None + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVirtualTryOnResponse: return poll_until_finished( auth_kwargs, @@ -1429,6 +1541,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): request_model=EmptyRequest, response_model=KlingVirtualTryOnResponse, ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, + node_id=node_id, ) def api_call( @@ -1436,6 +1551,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): human_image: torch.Tensor, cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, + unique_id: Optional[str] = None, **kwargs, ): initial_operation = SynchronousOperation( @@ -1457,7 +1573,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_image_result_response(final_response) images = get_images_from_response(final_response) @@ -1528,13 +1646,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." 
def get_response( - self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None + self, + task_id: str, + auth_kwargs: Optional[dict[str, str]], + node_id: Optional[str] = None, ) -> KlingImageGenerationsResponse: return poll_until_finished( auth_kwargs, @@ -1544,6 +1666,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase): request_model=EmptyRequest, response_model=KlingImageGenerationsResponse, ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_IMAGE_GEN, + node_id=node_id, ) def api_call( @@ -1557,6 +1682,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase): n: int, aspect_ratio: KlingImageGenAspectRatio, image: Optional[torch.Tensor] = None, + unique_id: Optional[str] = None, **kwargs, ): self.validate_prompt(prompt, negative_prompt) @@ -1589,7 +1715,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_image_result_response(final_response) images = get_images_from_response(final_response) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index bd33a53e..525dc38e 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -36,11 +36,20 @@ from comfy_api_nodes.apinode_utils import ( process_image_response, validate_string, ) +from server import PromptServer import requests import torch from io import BytesIO +LUMA_T2V_AVERAGE_DURATION = 105 +LUMA_I2V_AVERAGE_DURATION = 100 + +def image_result_url_extractor(response: LumaGeneration): + return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None + +def video_result_url_extractor(response: LumaGeneration): + return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None class LumaReferenceNode(ComfyNodeABC): """ @@ -204,6 +213,7 @@ class LumaImageGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -217,6 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC): image_luma_ref: LumaReferenceChain = None, style_image: torch.Tensor = None, character_image: torch.Tensor = None, + unique_id: str = None, **kwargs, ): validate_string(prompt, strip_whitespace=True, min_length=3) @@ -271,6 +282,8 @@ class LumaImageGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=image_result_url_extractor, + node_id=unique_id, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -353,6 +366,7 @@ class LumaImageModifyNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -363,6 +377,7 @@ class LumaImageModifyNode(ComfyNodeABC): image: torch.Tensor, image_weight: float, seed, + unique_id: str = None, **kwargs, ): # first, upload image @@ -399,6 +414,8 @@ class LumaImageModifyNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=image_result_url_extractor, + node_id=unique_id, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -473,6 +490,7 @@ class 
LumaTextToVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -486,6 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop: bool, seed, luma_concepts: LumaConceptChain = None, + unique_id: str = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, min_length=3) @@ -512,6 +531,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): ) response_api: LumaGeneration = operation.execute() + if unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + operation = PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/luma/generations/{response_api.id}", @@ -522,6 +544,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=video_result_url_extractor, + node_id=unique_id, + estimated_duration=LUMA_T2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -597,6 +622,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -611,6 +637,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): first_image: torch.Tensor = None, last_image: torch.Tensor = None, luma_concepts: LumaConceptChain = None, + unique_id: str = None, **kwargs, ): if first_image is None and last_image is None: @@ -642,6 +669,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): ) response_api: LumaGeneration = operation.execute() + if unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + operation = PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/luma/generations/{response_api.id}", @@ -652,6 +682,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=video_result_url_extractor, + node_id=unique_id, + estimated_duration=LUMA_I2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) response_poll = operation.execute() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index fd64aeb0..9b46636d 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,3 +1,7 @@ +from typing import Union +import logging +import torch + from comfy.comfy_types.node_typing import IO from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( @@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, validate_string, ) +from server import PromptServer -import torch -import logging +I2V_AVERAGE_DURATION = 114 +T2V_AVERAGE_DURATION = 234 class MinimaxTextToVideoNode: """ Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. 
""" + AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -68,6 +75,7 @@ class MinimaxTextToVideoNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -85,6 +93,7 @@ class MinimaxTextToVideoNode: model="T2V-01", image: torch.Tensor=None, # used for ImageToVideo subject: torch.Tensor=None, # used for SubjectToVideo + unique_id: Union[str, None]=None, **kwargs, ): ''' @@ -138,6 +147,8 @@ class MinimaxTextToVideoNode: completed_statuses=["Success"], failed_statuses=["Fail"], status_extractor=lambda x: x.status.value, + estimated_duration=self.AVERAGE_DURATION, + node_id=unique_id, auth_kwargs=kwargs, ) task_result = video_generate_operation.execute() @@ -164,6 +175,12 @@ class MinimaxTextToVideoNode: f"No video was found in the response. Full response: {file_result.model_dump()}" ) logging.info(f"Generated video URL: {file_url}") + if unique_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, unique_id) video_io = download_url_to_bytesio(file_url) if video_io is None: @@ -178,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. """ + AVERAGE_DURATION = I2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -223,6 +242,7 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -239,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" + AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -282,6 +304,7 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c63908be..ce8054af 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -96,6 +96,7 @@ class OpenAIDalle2(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -113,6 +114,7 @@ class OpenAIDalle2(ComfyNodeABC): mask=None, n=1, size="1024x1024", + unique_id=None, **kwargs ): validate_string(prompt, strip_whitespace=False) @@ -176,7 +178,7 @@ class OpenAIDalle2(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -242,6 +244,7 @@ class OpenAIDalle3(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -258,6 +261,7 @@ class OpenAIDalle3(ComfyNodeABC): style="natural", quality="standard", size="1024x1024", + unique_id=None, **kwargs ): validate_string(prompt, strip_whitespace=False) @@ -284,7 +288,7 @@ class OpenAIDalle3(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -375,6 +379,7 @@ class OpenAIGPTImage1(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -394,6 +399,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask=None, n=1, size="1024x1024", + unique_id=None, **kwargs ): validate_string(prompt, strip_whitespace=False) @@ -476,7 +482,7 @@ class OpenAIGPTImage1(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 08ec9cf0..30562790 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -121,7 +121,10 @@ class PikaNodeBase(ComfyNodeABC): RETURN_TYPES = ("VIDEO",) def poll_for_task_status( - self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None + self, + task_id: str, + auth_kwargs: Optional[dict[str, str]] = None, + node_id: Optional[str] = None, ) -> PikaGenerateResponse: polling_operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -141,13 +144,19 @@ class PikaNodeBase(ComfyNodeABC): response.progress if hasattr(response, "progress") else None ), auth_kwargs=auth_kwargs, + result_url_extractor=lambda response: ( + response.url if hasattr(response, "url") else None + ), + node_id=node_id, + estimated_duration=60 ) return polling_operation.execute() def execute_task( self, initial_operation: SynchronousOperation[R, PikaGenerateResponse], - auth_kwargs: Optional[dict[str,str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, + node_id: Optional[str] = None, ) -> tuple[VideoFromFile]: """Executes the initial operation then polls for the task status until it is completed. 
@@ -208,7 +217,8 @@ class PikaImageToVideoV2_2(PikaNodeBase): seed: int, resolution: str, duration: int, - **kwargs + unique_id: str, + **kwargs, ) -> tuple[VideoFromFile]: # Convert image to BytesIO image_bytes_io = tensor_to_bytesio(image) @@ -238,7 +248,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaTextToVideoNodeV2_2(PikaNodeBase): @@ -262,6 +272,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -275,6 +286,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): resolution: str, duration: int, aspect_ratio: float, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: initial_operation = SynchronousOperation( @@ -296,7 +308,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): content_type="application/x-www-form-urlencoded", ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaScenesV2_2(PikaNodeBase): @@ -340,6 +352,7 @@ class PikaScenesV2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -354,6 +367,7 @@ class PikaScenesV2_2(PikaNodeBase): duration: int, ingredients_mode: str, aspect_ratio: float, + unique_id: str, image_ingredient_1: Optional[torch.Tensor] = None, image_ingredient_2: Optional[torch.Tensor] = None, image_ingredient_3: Optional[torch.Tensor] = None, @@ -403,7 +417,7 @@ class PikaScenesV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikAdditionsNode(PikaNodeBase): @@ -439,6 +453,7 @@ class PikAdditionsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -451,6 +466,7 @@ class PikAdditionsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: # Convert video to BytesIO @@ -487,7 +503,7 @@ class PikAdditionsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaSwapsNode(PikaNodeBase): @@ -532,6 +548,7 @@ class PikaSwapsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -546,6 +563,7 @@ class PikaSwapsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: # Convert video to BytesIO @@ -592,7 +610,7 @@ class PikaSwapsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaffectsNode(PikaNodeBase): @@ -637,6 +655,7 @@ class PikaffectsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -649,6 +668,7 @@ class PikaffectsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> 
tuple[VideoFromFile]: @@ -670,7 +690,7 @@ class PikaffectsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaStartEndFrameNode2_2(PikaNodeBase): @@ -689,6 +709,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -703,6 +724,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): seed: int, resolution: str, duration: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: @@ -733,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 0c29e77c..ef4a9a80 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,5 +1,5 @@ from inspect import cleandoc - +from typing import Optional from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -34,11 +34,22 @@ import requests from io import BytesIO +AVERAGE_DURATION_T2V = 32 +AVERAGE_DURATION_I2V = 30 +AVERAGE_DURATION_T2T = 52 + + +def get_video_url_from_response( + response: PixverseGenerationStatusResponse, +) -> Optional[str]: + if response.Resp is None or response.Resp.url is None: + return None + return str(response.Resp.url) + + def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call - files = { - "image": tensor_to_bytesio(image) - } + files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/image/upload", @@ -54,7 +65,9 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): response_upload: PixverseImageUploadResponse = operation.execute() if response_upload.Resp is None: - raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") + raise Exception( + f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" + ) return response_upload.Resp.img_id @@ -73,7 +86,7 @@ class PixverseTemplateNode: def INPUT_TYPES(s): return { "required": { - "template": (list(pixverse_templates.keys()), ), + "template": (list(pixverse_templates.keys()),), } } @@ -87,7 +100,7 @@ class PixverseTemplateNode: class PixverseTextToVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. """ RETURN_TYPES = (IO.VIDEO,) @@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): "tooltip": "Prompt for the video generation", }, ), - "aspect_ratio": ( - [ratio.value for ratio in PixverseAspectRatio], - ), + "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],), "quality": ( [resolution.value for resolution in PixverseQuality], { @@ -143,12 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC): PixverseIO.TEMPLATE, { "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." 
- } - ) + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -160,8 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, - pixverse_template: int=None, + negative_prompt: str = None, + pixverse_template: int = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, ) response_poll = operation.execute() vid_response = requests.get(response_poll.Resp.url) + return (VideoFromFile(BytesIO(vid_response.content)),) class PixverseImageToVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. """ RETURN_TYPES = (IO.VIDEO,) @@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): def INPUT_TYPES(s): return { "required": { - "image": ( - IO.IMAGE, - ), + "image": (IO.IMAGE,), "prompt": ( IO.STRING, { @@ -273,12 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC): PixverseIO.TEMPLATE, { "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - } - ) + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -290,8 +310,9 @@ class PixverseImageToVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, - pixverse_template: int=None, + negative_prompt: str = None, + pixverse_template: int = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_I2V, ) response_poll = operation.execute() @@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): class PixverseTransitionVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. 
""" RETURN_TYPES = (IO.VIDEO,) @@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): def INPUT_TYPES(s): return { "required": { - "first_frame": ( - IO.IMAGE, - ), - "last_frame": ( - IO.IMAGE, - ), + "first_frame": (IO.IMAGE,), + "last_frame": (IO.IMAGE,), "prompt": ( IO.STRING, { @@ -408,6 +432,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -420,7 +445,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, + negative_prompt: str = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, ) response_poll = operation.execute() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 767d93e3..e369c4b7 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,5 +1,6 @@ from __future__ import annotations from inspect import cleandoc +from typing import Optional from comfy.utils import ProgressBar from comfy_extras.nodes_images import SVG # Added from comfy.comfy_types.node_typing import IO @@ -29,6 +30,8 @@ from comfy_api_nodes.apinode_utils import ( resize_mask_to_image, validate_string, ) +from server import PromptServer + import torch from io import BytesIO from PIL import UnidentifiedImageError @@ -388,6 +391,7 @@ class RecraftTextToImageNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -400,6 +404,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, max_length=1000) @@ -436,8 +441,15 @@ class RecraftTextToImageNode: ) response: RecraftImageGenerationResponse = operation.execute() images = [] + urls = [] for data in response.data: with handle_recraft_image_output(): + if unique_id and data.url: + urls.append(data.url) + urls_string = '\n'.join(urls) + PromptServer.instance.send_progress_text( + f"Result URL: {urls_string}", unique_id + ) image = bytesio_to_image_tensor( download_url_to_bytesio(data.url, timeout=1024) ) @@ -763,6 +775,7 @@ class RecraftTextToVectorNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -775,6 +788,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, max_length=1000) @@ -809,7 +823,14 @@ class RecraftTextToVectorNode: ) response: RecraftImageGenerationResponse = operation.execute() svg_data = [] + urls = [] for data in response.data: + if unique_id and data.url: + urls.append(data.url) + # 
Print result on each iteration in case of error + PromptServer.instance.send_progress_text( + f"Result URL: {' '.join(urls)}", unique_id + ) svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) return (SVG(svg_data),) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 2740179c..df846d5d 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -3,6 +3,7 @@ import logging import base64 import requests import torch +from typing import Optional from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl.video_types import VideoFromFile @@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_base64_string ) +AVERAGE_DURATION_VIDEO_GEN = 32 + def convert_image_to_base64(image: torch.Tensor): if image is None: return None @@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor): scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) return tensor_to_base64_string(scaled_image) + +def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): + video = poll_response.response.videos[0] + else: + return None + if hasattr(video, "gcsUri") and video.gcsUri: + return str(video.gcsUri) + return None + + class VeoVideoGenerationNode(ComfyNodeABC): """ Generates videos from text prompts using Google's Veo API. @@ -115,6 +134,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -134,6 +154,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): person_generation="ALLOW", seed=0, image=None, + unique_id: Optional[str] = None, **kwargs, ): # Prepare the instances for the request @@ -215,7 +236,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): operationName=operation_name ), auth_kwargs=kwargs, - poll_interval=5.0 + poll_interval=5.0, + result_url_extractor=get_video_url_from_response, + node_id=unique_id, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) # Execute the polling operation From f3ff5c40db3b68449b47112591eef31094cffe70 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Tue, 13 May 2025 22:28:30 -0700 Subject: [PATCH 09/31] don't retry if API returns task failure (#8111) --- comfy_api_nodes/apis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 838ff1e8..62866216 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -1105,7 +1105,7 @@ class PollingOperation(Generic[T, R]): except Exception as e: # For other errors, increment count and potentially abort consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: + if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: raise Exception( f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" ) from e From 08368f8e00c07d6888907557c5ef1da7a64b5e42 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 14 May 2025 14:54:50 -0700 Subject: [PATCH 10/31] Update comment on ROCm pytorch attention in README. 
(#8123)

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index deee70c6..dcdaa353 100644
--- a/README.md
+++ b/README.md
@@ -302,7 +302,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
 
 ### AMD ROCm Tips
 
-You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command; it should already be enabled by default on RDNA3. If this improves speed for you on the latest pytorch on your GPU, please report it so that I can enable it by default.
 
 ```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
 
From f1f9763b4c0e7c361ec2808b205ad32297c4777b Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Wed, 14 May 2025 21:11:41 -0700
Subject: [PATCH 11/31] Add `get_duration` method to Comfy VIDEO type (#8122)

* get duration from VIDEO type

* video get_duration unit test

* fix Windows unit test: can't delete opened temp file

---
 comfy_api/input/video_types.py                |  10 +
 comfy_api/input_impl/video_types.py           |  32 +++
 tests-unit/comfy_api_test/video_types_test.py | 239 ++++++++++++++++++
 3 files changed, 281 insertions(+)
 create mode 100644 tests-unit/comfy_api_test/video_types_test.py

diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py
index 0676e0e6..dc22d34f 100644
--- a/comfy_api/input/video_types.py
+++ b/comfy_api/input/video_types.py
@@ -43,3 +43,13 @@ class VideoInput(ABC):
         components = self.get_components()
         return components.images.shape[2], components.images.shape[1]
 
+    def get_duration(self) -> float:
+        """
+        Returns the duration of the video in seconds.
+
+        Returns:
+            Duration in seconds
+        """
+        components = self.get_components()
+        frame_count = components.images.shape[0]
+        return float(frame_count / components.frame_rate)
diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py
index ae48dbaa..197f6558 100644
--- a/comfy_api/input_impl/video_types.py
+++ b/comfy_api/input_impl/video_types.py
@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
                 return stream.width, stream.height
         raise ValueError(f"No video stream found in file '{self.__file}'")
 
+    def get_duration(self) -> float:
+        """
+        Returns the duration of the video in seconds. 
+ + Returns: + Duration in seconds + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode="r") as container: + if container.duration is not None: + return float(container.duration / av.time_base) + + # Fallback: calculate from frame count and frame rate + video_stream = next( + (s for s in container.streams if s.type == "video"), None + ) + if video_stream and video_stream.frames and video_stream.average_rate: + return float(video_stream.frames / video_stream.average_rate) + + # Last resort: decode frames to count them + if video_stream and video_stream.average_rate: + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + if frame_count > 0: + return float(frame_count / video_stream.average_rate) + + raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_components_internal(self, container: InputContainer) -> VideoComponents: # Get video frames frames = [] diff --git a/tests-unit/comfy_api_test/video_types_test.py b/tests-unit/comfy_api_test/video_types_test.py new file mode 100644 index 00000000..b25fcb1c --- /dev/null +++ b/tests-unit/comfy_api_test/video_types_test.py @@ -0,0 +1,239 @@ +import pytest +import torch +import tempfile +import os +import av +import io +from fractions import Fraction +from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents +from comfy_api.util.video_types import VideoComponents +from comfy_api.input.basic_types import AudioInput +from av.error import InvalidDataError + +EPSILON = 0.0001 + + +@pytest.fixture +def sample_images(): + """3-frame 2x2 RGB video tensor""" + return torch.rand(3, 2, 2, 3) + + +@pytest.fixture +def sample_audio(): + """Stereo audio with 44.1kHz sample rate""" + return AudioInput( + { + "waveform": torch.rand(1, 2, 1000), + "sample_rate": 44100, + } + ) + + +@pytest.fixture +def video_components(sample_images, sample_audio): + """VideoComponents with images, audio, and metadata""" + return VideoComponents( + images=sample_images, + audio=sample_audio, + frame_rate=Fraction(30), + metadata={"test": "metadata"}, + ) + + +def create_test_video(width=4, height=4, frames=3, fps=30): + """Helper to create a temporary video file""" + tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) + with av.open(tmp.name, mode="w") as container: + stream = container.add_stream("h264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + for i in range(frames): + frame = av.VideoFrame.from_ndarray( + torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85), + format="rgb24", + ) + frame = frame.reformat(format="yuv420p") + packet = stream.encode(frame) + container.mux(packet) + + # Flush + packet = stream.encode(None) + container.mux(packet) + + return tmp.name + + +@pytest.fixture +def simple_video_file(): + """4x4 video with 3 frames at 30fps""" + file_path = create_test_video() + yield file_path + os.unlink(file_path) + + +def test_video_from_components_get_duration(video_components): + """Duration calculated correctly from frame count and frame rate""" + video = VideoFromComponents(video_components) + duration = video.get_duration() + + expected_duration = 3.0 / 30.0 + assert duration == pytest.approx(expected_duration) + + +def test_video_from_components_get_duration_different_frame_rates(sample_images): + """Duration correct for different frame rates including fractional""" + # Test with 60 fps + components_60fps = 
VideoComponents(images=sample_images, frame_rate=Fraction(60)) + video_60fps = VideoFromComponents(components_60fps) + assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0) + + # Test with fractional frame rate (23.976fps) + components_frac = VideoComponents( + images=sample_images, frame_rate=Fraction(24000, 1001) + ) + video_frac = VideoFromComponents(components_frac) + expected_frac = 3.0 / (24000.0 / 1001.0) + assert video_frac.get_duration() == pytest.approx(expected_frac) + + +def test_video_from_components_get_duration_empty_video(): + """Duration is zero for empty video""" + empty_components = VideoComponents( + images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30) + ) + video = VideoFromComponents(empty_components) + assert video.get_duration() == 0.0 + + +def test_video_from_components_get_dimensions(video_components): + """Dimensions returned correctly from image tensor shape""" + video = VideoFromComponents(video_components) + width, height = video.get_dimensions() + assert width == 2 + assert height == 2 + + +def test_video_from_file_get_duration(simple_video_file): + """Duration extracted from file metadata""" + video = VideoFromFile(simple_video_file) + duration = video.get_duration() + assert duration == pytest.approx(0.1, abs=0.01) + + +def test_video_from_file_get_dimensions(simple_video_file): + """Dimensions read from stream without decoding frames""" + video = VideoFromFile(simple_video_file) + width, height = video.get_dimensions() + assert width == 4 + assert height == 4 + + +def test_video_from_file_bytesio_input(): + """VideoFromFile works with BytesIO input""" + buffer = io.BytesIO() + with av.open(buffer, mode="w", format="mp4") as container: + stream = container.add_stream("h264", rate=30) + stream.width = 2 + stream.height = 2 + stream.pix_fmt = "yuv420p" + + frame = av.VideoFrame.from_ndarray( + torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24" + ) + frame = frame.reformat(format="yuv420p") + packet = stream.encode(frame) + container.mux(packet) + packet = stream.encode(None) + container.mux(packet) + + buffer.seek(0) + video = VideoFromFile(buffer) + + assert video.get_dimensions() == (2, 2) + assert video.get_duration() == pytest.approx(1 / 30, abs=0.01) + + +def test_video_from_file_invalid_file_error(): + """InvalidDataError raised for non-video files""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp: + tmp.write(b"not a video file") + tmp.flush() + tmp_name = tmp.name + + try: + with pytest.raises(InvalidDataError): + video = VideoFromFile(tmp_name) + video.get_dimensions() + finally: + os.unlink(tmp_name) + + +def test_video_from_file_audio_only_error(): + """ValueError raised for audio-only files""" + with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp: + tmp_name = tmp.name + + try: + with av.open(tmp_name, mode="w") as container: + stream = container.add_stream("aac", rate=44100) + stream.sample_rate = 44100 + stream.format = "fltp" + + audio_data = torch.zeros(1, 1024).numpy() + audio_frame = av.AudioFrame.from_ndarray( + audio_data, format="fltp", layout="mono" + ) + audio_frame.sample_rate = 44100 + audio_frame.pts = 0 + packet = stream.encode(audio_frame) + container.mux(packet) + + for packet in stream.encode(None): + container.mux(packet) + + with pytest.raises(ValueError, match="No video stream found"): + video = VideoFromFile(tmp_name) + video.get_dimensions() + finally: + os.unlink(tmp_name) + + +def test_single_frame_video(): + """Single frame video has correct duration""" + 
components = VideoComponents( + images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1) + ) + video = VideoFromComponents(components) + assert video.get_duration() == 1.0 + + +@pytest.mark.parametrize( + "frame_rate,expected_fps", + [ + (Fraction(24000, 1001), 24000 / 1001), + (Fraction(30000, 1001), 30000 / 1001), + (Fraction(25, 1), 25.0), + (Fraction(50, 2), 25.0), + ], +) +def test_fractional_frame_rates(frame_rate, expected_fps): + """Duration calculated correctly for various fractional frame rates""" + components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate) + video = VideoFromComponents(components) + duration = video.get_duration() + expected_duration = 100.0 / expected_fps + assert duration == pytest.approx(expected_duration) + + +def test_duration_consistency(video_components): + """get_duration() consistent with manual calculation from components""" + video = VideoFromComponents(video_components) + + duration = video.get_duration() + components = video.get_components() + manual_duration = float(components.images.shape[0] / components.frame_rate) + + assert duration == pytest.approx(manual_duration) From 6a2e4bb9e00c1ef467000d880abbbaf284c26d4a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 May 2025 05:21:47 -0700 Subject: [PATCH 12/31] Remove old hack used to fix windows pytorch 2.4 on the portable. (#8139) Not necessary anymore. --- fix_torch.py | 28 ---------------------------- main.py | 7 ------- 2 files changed, 35 deletions(-) delete mode 100644 fix_torch.py diff --git a/fix_torch.py b/fix_torch.py deleted file mode 100644 index ce117b63..00000000 --- a/fix_torch.py +++ /dev/null @@ -1,28 +0,0 @@ -import importlib.util -import shutil -import os -import ctypes -import logging - - -def fix_pytorch_libomp(): - """ - Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed. 
- """ - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - lib_folder = os.path.join(folder, "lib") - test_file = os.path.join(lib_folder, "fbgemm.dll") - dest = os.path.join(lib_folder, "libomp140.x86_64.dll") - if os.path.exists(dest): - break - - with open(test_file, "rb") as f: - contents = f.read() - if b"libomp140.x86_64.dll" not in contents: - break - try: - ctypes.cdll.LoadLibrary(test_file) - except FileNotFoundError: - logging.warning("Detected pytorch version with libomp issue, patching.") - shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) diff --git a/main.py b/main.py index 221e48e4..0fde6d22 100644 --- a/main.py +++ b/main.py @@ -125,13 +125,6 @@ if __name__ == "__main__": import cuda_malloc -if args.windows_standalone_build: - try: - from fix_torch import fix_pytorch_libomp - fix_pytorch_libomp() - except: - pass - import comfy.utils import execution From c820ef950d10a6b4e4fa8ab28bc09274d563b13c Mon Sep 17 00:00:00 2001 From: George0726 <38740075+George0726@users.noreply.github.com> Date: Thu, 15 May 2025 19:00:43 -0400 Subject: [PATCH 13/31] Add Wan-FUN Camera Control models and Add WanCameraImageToVideo node (#8013) * support wan camera models * fix by ruff check * change camera_condition type; make camera_condition optional * support camera trajectory nodes * fix camera direction --------- Co-authored-by: Qirui Sun --- comfy/ldm/wan/model.py | 143 ++++++++++++++++ comfy/model_base.py | 11 ++ comfy/supported_models.py | 12 +- comfy_extras/nodes_camera_trajectory.py | 218 ++++++++++++++++++++++++ comfy_extras/nodes_wan.py | 47 +++++ nodes.py | 1 + 6 files changed, 431 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_camera_trajectory.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index fc5ff40c..a996dedf 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock): return c_skip, c +class WanCamAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}): + super(WanCamAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + +class WanCamResidualBlock(nn.Module): + def __init__(self, dim, operation_settings={}): + 
super(WanCamResidualBlock, self).__init__() + self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.relu = nn.ReLU(inplace=True) + self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -637,3 +691,92 @@ class VaceWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class CameraWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='camera', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + in_dim_control_adapter=24, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) + + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + freqs=None, + camera_conditions = None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + if self.control_adapter is not None and camera_conditions is not None: + x_camera = self.control_adapter(camera_conditions).to(x.dtype) + x = x + x_camera + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + # head + x = self.head(x, e) + + # 
unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 04786159..f475e837 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1079,6 +1079,17 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out +class WAN21_Camera(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + camera_conditions = kwargs.get("camera_conditions", None) + if camera_conditions is not None: + out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fef25eb2..667393ac 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py new file mode 100644 index 00000000..b84b4785 --- /dev/null +++ b/comfy_extras/nodes_camera_trajectory.py @@ -0,0 +1,218 @@ +import nodes +import torch +import numpy as np +from einops import rearrange +import comfy.model_management + + + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +CAMERA_DICT = { + "base_T_norm": 1.5, + "base_angle": np.pi/3, + "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]}, + "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]}, + 
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]}, + "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]}, + "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]}, + "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]}, + "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]}, + "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]}, + "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]}, +} + + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + + def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + c2w_mat = np.array(entry[7:]).reshape(4, 4) + self.c2w_mat = c2w_mat + self.w2c_mat = np.linalg.inv(c2w_mat) + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = torch.meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + indexing='ij' + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, 
:3].transpose(-1, -2)  # B, V, HW, 3
+    rays_o = c2w[..., :3, 3]  # B, V, 3
+    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
+    # c2w @ directions
+    rays_dxo = torch.cross(rays_o, rays_d)
+    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
+    # plucker = plucker.permute(0, 1, 4, 2, 3)
+    return plucker
+
+def get_camera_motion(angle, T, speed, n=81):
+    def compute_R_from_rad_angle(angles):
+        theta_x, theta_y, theta_z = angles
+        Rx = np.array([[1, 0, 0],
+                       [0, np.cos(theta_x), -np.sin(theta_x)],
+                       [0, np.sin(theta_x), np.cos(theta_x)]])
+
+        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
+                       [0, 1, 0],
+                       [-np.sin(theta_y), 0, np.cos(theta_y)]])
+
+        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
+                       [np.sin(theta_z), np.cos(theta_z), 0],
+                       [0, 0, 1]])
+
+        R = np.dot(Rz, np.dot(Ry, Rx))
+        return R
+    RT = []
+    for i in range(n):
+        _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
+        R = compute_R_from_rad_angle(_angle)
+        _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
+        _RT = np.concatenate([R,_T], axis=1)
+        RT.append(_RT)
+    RT = np.stack(RT)
+    return RT
+
+class WanCameraEmbeding:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
+                "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+                "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+                "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
+            },
+            "optional":{
+                "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
+                "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+                "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+                "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+                "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+            }
+
+        }
+
+    RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
+    RETURN_NAMES = ("camera_embedding","width","height","length")
+    FUNCTION = "run"
+    CATEGORY = "camera"
+
+    def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
+        """
+        Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021)
+        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
+        """
+        motion_list = [camera_pose]
+        speed = speed
+        angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
+        T = np.array(CAMERA_DICT[motion_list[0]]["T"])
+        RT = get_camera_motion(angle, T, speed, length)
+
+        trajs=[]
+        for cp in RT.tolist():
+            traj=[fx,fy,cx,cy,0,0]
+            traj.extend(cp[0])
+            traj.extend(cp[1])
+            traj.extend(cp[2])
+            traj.extend([0,0,0,1])
+            trajs.append(traj)
+
+        cam_params = np.array([[float(x) for x in pose] for pose in trajs])
+        cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
+        control_camera_video = process_pose_params(cam_params, width=width, height=height)
+        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
+
+        control_camera_video = torch.concat(
+            [
+                torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
+                control_camera_video[:, :, 1:]
+            ], dim=2
+        ).transpose(1, 2)
+
+        # Reshape, transpose, and view into desired shape
+        b, f, c, h, w = 
control_camera_video.shape + control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + + return (control_camera_video, width, height, length) + + +NODE_CLASS_MAPPINGS = { + "WanCameraEmbeding": WanCameraEmbeding, +} diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 9dda6459..a91b4aba 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -297,6 +297,52 @@ class TrimVideoLatent: samples_out["samples"] = s1[:, :, trim_amount:] return (samples_out,) +class WanCameraImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if camera_conditions is not None: + positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) + negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, @@ -305,4 +351,5 @@ NODE_CLASS_MAPPINGS = { "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, "TrimVideoLatent": TrimVideoLatent, + "WanCameraImageToVideo": WanCameraImageToVideo, } diff --git a/nodes.py b/nodes.py index 54e3886a..0a9db139 100644 --- a/nodes.py +++ b/nodes.py @@ -2265,6 +2265,7 @@ def init_builtin_extra_nodes(): 
"nodes_preview_any.py", "nodes_ace.py", "nodes_string.py", + "nodes_camera_trajectory.py", ] import_failed = [] From 1c2d45d2b575c40efcb303a61318d4a415242e52 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 15 May 2025 16:02:19 -0700 Subject: [PATCH 14/31] Fix typo in last PR. (#8144) More robust model detection for future proofing. --- comfy/model_detection.py | 2 ++ comfy/supported_models.py | 2 +- comfy_extras/nodes_camera_trajectory.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 28c58638..20f287df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "vace" dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') + elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 667393ac..efe2e6b8 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -995,7 +995,7 @@ class WAN21_FunControl2V(WAN21_T2V): class WAN21_Camera(WAN21_T2V): unet_config = { "image_model": "wan2.1", - "model_type": "i2v", + "model_type": "camera", "in_dim": 32, } diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index b84b4785..5e0e39f9 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -148,7 +148,7 @@ def get_camera_motion(angle, T, speed, n=81): RT = np.stack(RT) return RT -class WanCameraEmbeding: +class WanCameraEmbedding: @classmethod def INPUT_TYPES(cls): return { @@ -214,5 +214,5 @@ class WanCameraEmbeding: NODE_CLASS_MAPPINGS = { - "WanCameraEmbeding": WanCameraEmbeding, + "WanCameraEmbedding": WanCameraEmbedding, } From 7046983d9538d49d1dd286a513aa6db42b9a74fd Mon Sep 17 00:00:00 2001 From: filtered <176114999+webfiltered@users.noreply.github.com> Date: Sat, 17 May 2025 03:45:36 +1000 Subject: [PATCH 15/31] Remove Desktop versioning claim from README (#8155) --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index dcdaa353..9b5f301c 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** - Builds a new release using the latest stable core version - - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0) 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** - Weekly frontend updates are merged into the core repository From dc46db7aa48acd035dfc10730a064a9c994fb76b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 16 May 2025 12:15:55 -0700 Subject: [PATCH 16/31] Make ImagePadForOutpaint return a 3 channel mask. 
(#8157) --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 0a9db139..95e831b8 100644 --- a/nodes.py +++ b/nodes.py @@ -1940,7 +1940,7 @@ class ImagePadForOutpaint: mask[top:top + d2, left:left + d3] = t - return (new_image, mask) + return (new_image, mask.unsqueeze(0)) NODE_CLASS_MAPPINGS = { From aee2908d0395577a6e2e13d1307aaf271424108b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 17 May 2025 03:27:34 -0700 Subject: [PATCH 17/31] Remove useless log. (#8166) --- comfy/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 561e1b85..1f8d7129 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) - if "global_step" in pl_sd: - logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: From f5e4e976f43c9ed79e75b27f73719b2708f9ded9 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sun, 18 May 2025 08:59:06 +0200 Subject: [PATCH 18/31] Add missing category for T5TokenizerOption (#8177) Change it if you need to but it should at least have a category. --- comfy_extras/nodes_cond.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 57426217..58c16f62 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -31,6 +31,7 @@ class T5TokenizerOptions: } } + CATEGORY = "_for_testing/conditioning" RETURN_TYPES = ("CLIP",) FUNCTION = "set_options" From 05eb10b43a42929033f449def7cd5a8feeb84673 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 18 May 2025 01:08:47 -0700 Subject: [PATCH 19/31] Validate video inputs (#8133) * validate kling lip sync input video * add tooltips * update duration estimates * decrease epsilon * fix rebase error --- comfy_api_nodes/nodes_kling.py | 51 ++++++------ comfy_api_nodes/util/__init__.py | 0 comfy_api_nodes/util/validation_utils.py | 100 +++++++++++++++++++++++ 3 files changed, 126 insertions(+), 25 deletions(-) create mode 100644 comfy_api_nodes/util/__init__.py create mode 100644 comfy_api_nodes/util/validation_utils.py diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 456a8690..641cd635 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -65,6 +65,12 @@ from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api_nodes.util.validation_utils import ( + validate_image_dimensions, + validate_image_aspect_ratio, + validate_video_dimensions, + validate_video_duration, +) from comfy_api.input.basic_types import AudioInput from comfy_api.input.video_types import VideoInput from comfy_api.input_impl import VideoFromFile @@ -80,18 +86,16 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations" PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on" PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations" - MAX_PROMPT_LENGTH_T2V = 2500 MAX_PROMPT_LENGTH_I2V = 500 MAX_PROMPT_LENGTH_IMAGE_GEN = 500 MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200 
 MAX_PROMPT_LENGTH_LIP_SYNC = 120
 
-# TODO: adjust based on tests
-AVERAGE_DURATION_T2V = 319 # 319,
-AVERAGE_DURATION_I2V = 164 # 164,
-AVERAGE_DURATION_LIP_SYNC = 120
-AVERAGE_DURATION_VIRTUAL_TRY_ON = 19 # 19,
+AVERAGE_DURATION_T2V = 319
+AVERAGE_DURATION_I2V = 164
+AVERAGE_DURATION_LIP_SYNC = 455
+AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
 AVERAGE_DURATION_IMAGE_GEN = 32
 AVERAGE_DURATION_VIDEO_EFFECTS = 320
 AVERAGE_DURATION_VIDEO_EXTEND = 320
@@ -211,23 +215,8 @@ def validate_input_image(image: torch.Tensor) -> None:

     See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
     """
-    if len(image.shape) == 4:
-        height, width = image.shape[1], image.shape[2]
-    elif len(image.shape) == 3:
-        height, width = image.shape[0], image.shape[1]
-    else:
-        raise ValueError("Invalid image tensor shape.")
-
-    # Ensure minimum resolution is met
-    if height < 300:
-        raise ValueError("Image height must be at least 300px")
-    if width < 300:
-        raise ValueError("Image width must be at least 300px")
-
-    # Ensure aspect ratio is within acceptable range
-    aspect_ratio = width / height
-    if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5:
-        raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1")
+    validate_image_dimensions(image, min_width=300, min_height=300)
+    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)


 def get_camera_control_input_config(
@@ -1243,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase):
     RETURN_TYPES = ("VIDEO", "STRING", "STRING")
     RETURN_NAMES = ("VIDEO", "video_id", "duration")

+    def validate_lip_sync_video(self, video: VideoInput):
+        """
+        Validates the input video adheres to the expectations of the Kling Lip Sync API:
+        - Video length does not exceed 10s and is not shorter than 2s
+        - Height and width dimensions should both be between 720px and 1920px
+
+        See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
+        """
+        validate_video_dimensions(video, min_width=720, max_width=1920, min_height=720, max_height=1920)
+        validate_video_duration(video, 2, 10)
+
     def validate_text(self, text: str):
         if not text:
             raise ValueError("Text is required")
@@ -1282,6 +1282,7 @@ class KlingLipSyncBase(KlingNodeBase):
     ) -> tuple[VideoFromFile, str, str]:
         if text:
             self.validate_text(text)
+        self.validate_lip_sync_video(video)

         # Upload video to Comfy API and get download URL
         video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
@@ -1352,7 +1353,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
             },
         }

-    DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
+    DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."

     def api_call(
         self,
@@ -1464,7 +1465,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
             },
         }

-    DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
+    DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
 
     def api_call(
         self,
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py
new file mode 100644
index 00000000..031b9fbd
--- /dev/null
+++ b/comfy_api_nodes/util/validation_utils.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Optional
+
+import torch
+from comfy_api.input.video_types import VideoInput
+
+
+def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
+    if len(image.shape) == 4:
+        return image.shape[1], image.shape[2]
+    elif len(image.shape) == 3:
+        return image.shape[0], image.shape[1]
+    else:
+        raise ValueError("Invalid image tensor shape.")
+
+
+def validate_image_dimensions(
+    image: torch.Tensor,
+    min_width: Optional[int] = None,
+    max_width: Optional[int] = None,
+    min_height: Optional[int] = None,
+    max_height: Optional[int] = None,
+):
+    height, width = get_image_dimensions(image)
+
+    if min_width is not None and width < min_width:
+        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
+    if max_width is not None and width > max_width:
+        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
+    if min_height is not None and height < min_height:
+        raise ValueError(
+            f"Image height must be at least {min_height}px, got {height}px"
+        )
+    if max_height is not None and height > max_height:
+        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
+
+
+def validate_image_aspect_ratio(
+    image: torch.Tensor,
+    min_aspect_ratio: Optional[float] = None,
+    max_aspect_ratio: Optional[float] = None,
+):
+    height, width = get_image_dimensions(image)
+    aspect_ratio = width / height
+
+    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
+        raise ValueError(
+            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
+        )
+    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
+        raise ValueError(
+            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
+        )
+
+
+def validate_video_dimensions(
+    video: VideoInput,
+    min_width: Optional[int] = None,
+    max_width: Optional[int] = None,
+    min_height: Optional[int] = None,
+    max_height: Optional[int] = None,
+):
+    try:
+        width, height = video.get_dimensions()
+    except Exception as e:
+        logging.error("Error getting dimensions of video: %s", e)
+        return
+
+    if min_width is not None and width < min_width:
+        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
+    if max_width is not None and width > max_width:
+        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
+    if min_height is not None and height < min_height:
+        raise ValueError(
+            f"Video height must be at least {min_height}px, got {height}px"
+        )
+    if max_height is not None and height > max_height:
+        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
+
+
+def validate_video_duration(
+    video: VideoInput,
+    min_duration: Optional[float] = None,
+    max_duration: Optional[float] = None,
+):
+    try:
+        duration = video.get_duration()
+    except Exception as e:
+        logging.error("Error getting duration of video: %s", e)
+        return
+
+    epsilon = 0.0001
+    if min_duration is not None and min_duration - epsilon > duration:
+        raise ValueError(
+            f"Video duration must be at least {min_duration}s, got {duration}s"
+        )
+    if max_duration is not None and duration > max_duration + epsilon:
+        raise 
ValueError( + f"Video duration must be at most {max_duration}s, got {duration}s" + ) From 62690eddec9b7d715b4c37246f71abf5ca1c5844 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 18 May 2025 01:09:56 -0700 Subject: [PATCH 20/31] Node to add pixel space noise to an image. (#8182) --- comfy_extras/nodes_images.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 77c30561..29a5d5b6 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -13,6 +13,7 @@ import os import re from io import BytesIO from inspect import cleandoc +import torch from comfy.comfy_types import FileLocator @@ -74,6 +75,24 @@ class ImageFromBatch: s = s_in[batch_index:batch_index + length].clone() return (s,) + +class ImageAddNoise: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), + "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image" + + def repeat(self, image, seed, strength): + generator = torch.manual_seed(seed) + s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -295,6 +314,7 @@ NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, "ImageFromBatch": ImageFromBatch, + "ImageAddNoise": ImageAddNoise, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, "SaveSVGNode": SaveSVGNode, From 3d44a09812c4f0880c30fcd1876125b7319300b4 Mon Sep 17 00:00:00 2001 From: LaVie024 <62406970+LaVie024@users.noreply.github.com> Date: Sun, 18 May 2025 08:11:11 +0000 Subject: [PATCH 21/31] Update nodes_string.py (#8173) --- comfy_extras/nodes_string.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index a852326e..9eaa7123 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -8,7 +8,8 @@ class StringConcatenate(): return { "required": { "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}) + "string_b": (IO.STRING, {"multiline": True}), + "delimiter": (IO.STRING, {"multiline": False, "default": ", "}) } } @@ -16,8 +17,8 @@ class StringConcatenate(): FUNCTION = "execute" CATEGORY = "utils/string" - def execute(self, string_a, string_b, **kwargs): - return string_a + string_b, + def execute(self, string_a, string_b, delimiter, **kwargs): + return delimiter.join((string_a, string_b)), class StringSubstring(): @classmethod From d8e5662822168101afb5e08a8ba75b6eefff6e02 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 18 May 2025 01:12:12 -0700 Subject: [PATCH 22/31] Remove default delimiter. 
(#8183) --- comfy_extras/nodes_string.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 9eaa7123..b24222ce 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -9,7 +9,7 @@ class StringConcatenate(): "required": { "string_a": (IO.STRING, {"multiline": True}), "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ", "}) + "delimiter": (IO.STRING, {"multiline": False, "default": ""}) } } From e930a387d62cc819117502993b0b821b1e3f2687 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 19 May 2025 01:58:41 -0700 Subject: [PATCH 23/31] Update AMD instructions in README. (#8198) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9b5f301c..15157f52 100644 --- a/README.md +++ b/README.md @@ -197,11 +197,11 @@ Put your VAE in: models/vae ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` This is the command to install the nightly with ROCm 6.3 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` ### Intel GPUs (Windows and Linux) From 4f3b50ba510e02fa3fdd8c755ef9ad319b36bd61 Mon Sep 17 00:00:00 2001 From: filtered <176114999+webfiltered@users.noreply.github.com> Date: Tue, 20 May 2025 06:40:55 +1000 Subject: [PATCH 24/31] Update README ROCm text to match link (#8199) - Follow-up on #8198 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 15157f52..47514d1b 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` -This is the command to install the nightly with ROCm 6.3 which might have some performance improvements: +This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` From 7e84bf53737879ace37a68dc93e0df7704a53514 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 20 May 2025 02:29:23 -0700 Subject: [PATCH 25/31] This doesn't seem to be needed on chroma. 
(#8209)
---
 comfy/ldm/chroma/layers.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 35da91ee..18a4a9cf 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -109,9 +109,6 @@ class DoubleStreamBlock(nn.Module):
         txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
         return img, txt
 
 
@@ -163,8 +160,6 @@ class SingleStreamBlock(nn.Module):
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
 
 
From 87f91307782ce0b401786d8edddd8f618b955141 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 20 May 2025 02:39:55 -0700
Subject: [PATCH 26/31] Revert "This doesn't seem to be needed on chroma.
 (#8209)" (#8210)

This reverts commit 7e84bf53737879ace37a68dc93e0df7704a53514.

---
 comfy/ldm/chroma/layers.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 18a4a9cf..35da91ee 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -109,6 +109,9 @@ class DoubleStreamBlock(nn.Module):
         txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
+        if txt.dtype == torch.float16:
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
+
         return img, txt
 
 
@@ -160,6 +163,8 @@ class SingleStreamBlock(nn.Module):
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
 
 
From 10024a38ea8d7e8950b26500a540cd0323d0e611 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 21 May 2025 04:50:37 -0400
Subject: [PATCH 27/31] ComfyUI version v0.3.35

---
 comfyui_version.py | 2 +-
 pyproject.toml     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfyui_version.py b/comfyui_version.py
index b740b378..8db3bc80 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.34"
+__version__ = "0.3.35"
diff --git a/pyproject.toml b/pyproject.toml
index 80061b39..a33fc437 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.34"
+version = "0.3.35"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"

From 65da29aaa965afcb0811a9c8dac1cc0facb006d4 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Wed, 21 May 2025 01:56:56 -0700
Subject: [PATCH 28/31] Make torch.compile LoRA/key-compatible (#8213)

* Make torch compile node use wrapper instead of object_patch for the entire diffusion_models object, allowing key associations on diffusion_models to not break (loras, getting attributes, etc.)
* Moved torch compile code into comfy_api so it can be used by custom nodes with a degree of confidence * Refactor set_torch_compile_wrapper to support a list of keys instead of just diffusion_model, as well as additional torch.compile args * remove unused import * Moved torch compile kwargs to be stored in model_options instead of attachments; attachments are more intended for things to be 'persisted', AKA not deepcopied * Add some comments * Remove random line of code, not sure how it got there --- comfy_api/torch_helpers/__init__.py | 5 ++ comfy_api/torch_helpers/torch_compile.py | 69 ++++++++++++++++++++++++ comfy_extras/nodes_torch_compile.py | 5 +- 3 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 comfy_api/torch_helpers/__init__.py create mode 100644 comfy_api/torch_helpers/torch_compile.py diff --git a/comfy_api/torch_helpers/__init__.py b/comfy_api/torch_helpers/__init__.py new file mode 100644 index 00000000..be7ae7a6 --- /dev/null +++ b/comfy_api/torch_helpers/__init__.py @@ -0,0 +1,5 @@ +from .torch_compile import set_torch_compile_wrapper + +__all__ = [ + "set_torch_compile_wrapper", +] diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py new file mode 100644 index 00000000..9223f58d --- /dev/null +++ b/comfy_api/torch_helpers/torch_compile.py @@ -0,0 +1,69 @@ +from __future__ import annotations +import torch + +import comfy.utils +from comfy.patcher_extension import WrappersMP +from typing import TYPE_CHECKING, Callable, Optional +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.patcher_extension import WrapperExecutor + + +COMPILE_KEY = "torch.compile" +TORCH_COMPILE_KWARGS = "torch_compile_kwargs" + + +def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable: + ''' + Create a wrapper that will refer to the compiled_diffusion_model. + ''' + def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs): + try: + orig_modules = {} + for key, value in compiled_module_dict.items(): + orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) + comfy.utils.set_attr(executor.class_obj, key, value) + return executor(*args, **kwargs) + finally: + for key, value in orig_modules.items(): + comfy.utils.set_attr(executor.class_obj, key, value) + return apply_torch_compile_wrapper + + +def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None, + mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None, + keys: list[str]=["diffusion_model"], *args, **kwargs): + ''' + Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance. + + When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model. + When a list of keys is provided, it will perform torch.compile on only the selected modules. 
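+
+    Example (illustrative sketch added by the editor, not part of the original
+    patch; assumes `model` is a ModelPatcher and that the "inductor"
+    torch.compile backend is available on this install):
+        m = model.clone()
+        set_torch_compile_wrapper(m, backend="inductor", keys=["diffusion_model"])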
+ ''' + # clear out any other torch.compile wrappers + model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY) + # if no keys, default to 'diffusion_model' + if not keys: + keys = ["diffusion_model"] + # create kwargs dict that can be referenced later + compile_kwargs = { + "backend": backend, + "options": options, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + # get a dict of compiled keys + compiled_modules = {} + for key in keys: + compiled_modules[key] = torch.compile( + model=model.get_model_object(key), + **compile_kwargs, + ) + # add torch.compile wrapper + wrapper_func = apply_torch_compile_factory( + compiled_module_dict=compiled_modules, + ) + # store wrapper to run on BaseModel's apply_model function + model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func) + # keep compile kwargs for reference + model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 1fe6f42c..60553667 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,4 +1,5 @@ -import torch +from comfy_api.torch_helpers import set_torch_compile_wrapper + class TorchCompileModel: @classmethod @@ -14,7 +15,7 @@ class TorchCompileModel: def patch(self, model, backend): m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend)) + set_torch_compile_wrapper(model=m, backend=backend) return (m, ) NODE_CLASS_MAPPINGS = { From 57893c843f44ea9e8a0be79292d19e5a5e16e9e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BC=96=E7=A8=8B=E7=95=8C=E7=9A=84=E5=B0=8F=E5=AD=A6?= =?UTF-8?q?=E7=94=9F?= <15620646321@163.com> Date: Wed, 21 May 2025 16:59:42 +0800 Subject: [PATCH 29/31] Code Optimization and Issues Fixes in ComfyUI server (#8196) * Update server.py * Update server.py --- server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server.py b/server.py index cb1c6a8f..16cd88d9 100644 --- a/server.py +++ b/server.py @@ -226,7 +226,7 @@ class PromptServer(): return response @routes.get("/embeddings") - def get_embeddings(self): + def get_embeddings(request): embeddings = folder_paths.get_filename_list("embeddings") return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) @@ -282,7 +282,6 @@ class PromptServer(): a.update(f.read()) b.update(image.file.read()) image.file.seek(0) - f.close() return a.hexdigest() == b.hexdigest() return False From 8bb858e4d39f7f6a6969c584aeeaa1d606a812d6 Mon Sep 17 00:00:00 2001 From: Michael Abrahams Date: Wed, 21 May 2025 05:14:17 -0400 Subject: [PATCH 30/31] Improve performance with large number of queued prompts (#8176) * get_current_queue_volatile * restore get_current_queue method * remove extra import --- execution.py | 9 ++++++++- main.py | 3 +-- server.py | 5 +++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index e5d1c69d..15ff7567 100644 --- a/execution.py +++ b/execution.py @@ -909,7 +909,6 @@ class PromptQueue: self.currently_running = {} self.history = {} self.flags = {} - server.prompt_queue = self def put(self, item): with self.mutex: @@ -954,6 +953,7 @@ class PromptQueue: self.history[prompt[1]].update(history_result) self.server.queue_updated() + # Note: slow def get_current_queue(self): with self.mutex: out = [] @@ -961,6 +961,13 @@ class PromptQueue: out += [x] return (out, copy.deepcopy(self.queue)) + # read-safe as long as queue 
items are immutable + def get_current_queue_volatile(self): + with self.mutex: + running = [x for x in self.currently_running.values()] + queued = copy.copy(self.queue) + return (running, queued) + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/main.py b/main.py index 0fde6d22..fb1f8d20 100644 --- a/main.py +++ b/main.py @@ -260,7 +260,6 @@ def start_comfyui(asyncio_loop=None): asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) - q = execution.PromptQueue(prompt_server) hook_breaker_ac10a0.save_functions() nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes) @@ -271,7 +270,7 @@ def start_comfyui(asyncio_loop=None): prompt_server.add_routes() hijack_progress(prompt_server) - threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start() if args.quick_test_for_ci: exit(0) diff --git a/server.py b/server.py index 16cd88d9..1b0a7360 100644 --- a/server.py +++ b/server.py @@ -29,6 +29,7 @@ import comfy.model_management import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager + from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager @@ -159,7 +160,7 @@ class PromptServer(): self.custom_node_manager = CustomNodeManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] - self.prompt_queue = None + self.prompt_queue = execution.PromptQueue(self) self.loop = loop self.messages = asyncio.Queue() self.client_session:Optional[aiohttp.ClientSession] = None @@ -620,7 +621,7 @@ class PromptServer(): @routes.get("/queue") async def get_queue(request): queue_info = {} - current_queue = self.prompt_queue.get_current_queue() + current_queue = self.prompt_queue.get_current_queue_volatile() queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] return web.json_response(queue_info) From ded60c33a0c0231c7109f30072c25e64e360e636 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 22 May 2025 02:40:08 +0800 Subject: [PATCH 31/31] Update templates to 0.1.18 (#8224) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8f7a7898..858e6343 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.19.9 -comfyui-workflow-templates==0.1.14 +comfyui-workflow-templates==0.1.18 torch torchsde torchvision
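
A minimal sketch (not part of the patches above) of why the shallow copy in `get_current_queue_volatile` from PATCH 30/31 is read-safe: the snapshot shares the immutable queue items but is detached from later mutations of the live queue, whereas `copy.deepcopy` also cloned every nested prompt structure on each `/queue` poll, which is what made large queues slow. The queue item shape below is hypothetical.

```python
import copy

# Hypothetical stand-in for PromptQueue.queue: a list of immutable-ish items.
queue = [(0, "prompt-a", {"nodes": 40}), (1, "prompt-b", {"nodes": 12})]

snapshot = copy.copy(queue)  # shallow copy: only the list of references is copied

queue.append((2, "prompt-c", {"nodes": 7}))  # the live queue keeps changing

assert len(snapshot) == 2  # the snapshot is unaffected by later appends
# As long as the items themselves are never mutated in place, readers of the
# snapshot see a consistent view without paying deepcopy cost per request.
```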