Merge branch 'master' into v3-definition - async v3 nodes do not currently work, but I will fix that in the next v3 PR

This commit is contained in:
Jedrzej Kosinski 2025-07-18 14:14:02 -07:00
commit fd9c34a3eb
37 changed files with 2437 additions and 480 deletions

View File

@ -0,0 +1,40 @@
name: Check for Windows Line Endings

on:
  pull_request:
    branches: ['*']  # Trigger on all pull requests to any branch

jobs:
  check-line-endings:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Fetch all history to compare changes

      - name: Check for Windows line endings (CRLF)
        run: |
          # Get the list of changed files in the PR
          CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)

          # Flag to track if CRLF is found
          CRLF_FOUND=false

          # Loop through each changed file
          for FILE in $CHANGED_FILES; do
            # Check if the file exists and is a text file
            if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
              # Check for CRLF line endings
              if grep -UP '\r$' "$FILE"; then
                echo "Error: Windows line endings (CRLF) detected in $FILE"
                CRLF_FOUND=true
              fi
            fi
          done

          # Exit with error if CRLF was found
          if [ "$CRLF_FOUND" = true ]; then
            exit 1
          fi
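If the check fails, offending files can be normalized locally before pushing. A minimal sketch of a hypothetical helper script (not part of this PR) that rewrites CRLF endings as LF for the files the workflow reports:

# fix_crlf.py -- hypothetical helper, not part of this commit.
from pathlib import Path
import sys

for name in sys.argv[1:]:
    path = Path(name)
    data = path.read_bytes()
    if b"\r\n" in data:
        # Rewrite the file in place with Unix line endings.
        path.write_bytes(data.replace(b"\r\n", b"\n"))
        print(f"normalized {name}")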

View File

@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

View File

@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
             norm_layer=config.get("norm_layer", "group_norm"),
             causal=config.get("causal_decoder", False),
             timestep_conditioning=self.timestep_conditioning,
-            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+            spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
         )
         self.per_channel_statistics = processor()

View File

@ -18,6 +18,7 @@ import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import yaml
 import math
+import os

 import comfy.utils
@ -977,6 +978,12 @@ def load_gligen(ckpt_path):
         model = model.half()
     return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

+def model_detection_error_hint(path, state_dict):
+    filename = os.path.basename(path)
+    if 'lora' in filename.lower():
+        return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
+    return ""
+
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
     model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@ -1005,7 +1012,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
     out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
     if out is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
     return out

 def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@ -1177,7 +1184,7 @@ def load_diffusion_model(unet_path, model_options={}):
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
+        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
     return model

 def load_unet(unet_path, dtype=None):

View File

@ -18,7 +18,7 @@
"single_word": false "single_word": false
}, },
"errors": "replace", "errors": "replace",
"model_max_length": 77, "model_max_length": 8192,
"name_or_path": "openai/clip-vit-large-patch14", "name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json", "special_tokens_map_file": "./special_tokens_map.json",

View File

@ -1214,7 +1214,7 @@ class Omnigen2(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))

 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

View File

@ -31,6 +31,7 @@ from einops import rearrange
 from comfy.cli_args import args

 MMAP_TORCH_FILES = args.mmap_torch_files
+DISABLE_MMAP = args.disable_mmap

 ALWAYS_SAFE_LOAD = False
 if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
             with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                 sd = {}
                 for k in f.keys():
-                    sd[k] = f.get_tensor(k)
+                    tensor = f.get_tensor(k)
+                    if DISABLE_MMAP:  # TODO: Not sure if this is the best way to bypass the mmap issues
+                        tensor = tensor.to(device=device, copy=True)
+                    sd[k] = tensor
                 if return_metadata:
                     metadata = f.metadata()
         except Exception as e:
@ -998,11 +1002,12 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function

 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, node_id=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.node_id = node_id

     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@ -1011,7 +1016,7 @@ class ProgressBar:
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+            self.hook(self.current, self.total, preview, node_id=self.node_id)

     def update(self, value):
         self.update_absolute(self.current + value)
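Because the bar now forwards node_id as a keyword argument, any globally installed hook needs to accept it. A minimal sketch of a compatible hook, assuming it is installed through the existing set_progress_bar_global_hook API (the hook body and node id are illustrative):

import comfy.utils

def my_progress_hook(value, total, preview_image, node_id=None):
    # node_id is None for ProgressBar instances created with the old call signature.
    print(f"node {node_id or '?'}: {value}/{total}")

comfy.utils.set_progress_bar_global_hook(my_progress_hook)
pbar = comfy.utils.ProgressBar(10, node_id="42")  # node_id is the new optional argument
pbar.update(1)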

View File

@ -0,0 +1,69 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.
This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
}
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
) -> Any:
"""
Get a feature flag value for a specific connection.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
default: Default value if feature not found
Returns:
Feature value or default if not found
"""
if sid not in sockets_metadata:
return default
return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
"""
Check if a connection supports a specific feature.
Args:
sockets_metadata: Dictionary of socket metadata
sid: Session ID of the connection
feature_name: Name of the feature to check
Returns:
Boolean indicating if feature is supported
"""
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.
Returns:
Dictionary of server feature flags
"""
return SERVER_FEATURE_FLAGS.copy()
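A sketch of how the backend might consult these helpers when deciding which protocol variant to emit; the sockets_metadata layout mirrors what the functions above expect, and the session id is illustrative:

from comfy_api import feature_flags

# Metadata as stored per connection once the client has announced its flags.
sockets_metadata = {
    "sid-1": {"feature_flags": {"supports_preview_metadata": True}},
}

if feature_flags.supports_feature(sockets_metadata, "sid-1", "supports_preview_metadata"):
    print("send previews with metadata")
else:
    print("fall back to the legacy preview event")

# Flags the server advertises back to the client during negotiation.
print(feature_flags.get_server_features())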

View File

@ -112,3 +112,15 @@ def lock_class(cls):
     locked_dict['__setattr__'] = locked_instance_setattr

     return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
+
+def make_locked_method_func(type_obj, func, class_clone):
+    """
+    Returns a function that, when called with **inputs, will execute:
+        getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
+    """
+    locked_class = lock_class(class_clone)
+    method = getattr(type_obj, func).__func__
+    def wrapped_func(**inputs):
+        return method(locked_class, **inputs)
+    return wrapped_func

View File

@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
     def create_file_part(self, file_path: str) -> GeminiPart:
         mime_type = (
-            GeminiMimeType.pdf
+            GeminiMimeType.application_pdf
             if file_path.endswith(".pdf")
             else GeminiMimeType.text_plain
         )

View File

@ -132,6 +132,8 @@ def poll_until_finished(
         result_url_extractor=result_url_extractor,
         estimated_duration=estimated_duration,
         node_id=node_id,
+        poll_interval=16.0,
+        max_poll_attempts=256,
     ).execute()

View File

@ -1,6 +1,7 @@
 import itertools
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
+from abc import ABC, abstractmethod

 import nodes
@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
     NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
     return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

-class CacheKeySet:
+class CacheKeySet(ABC):
     def __init__(self, dynprompt, node_ids, is_changed_cache):
         self.keys = {}
         self.subcache_keys = {}

-    def add_keys(self, node_ids):
+    @abstractmethod
+    async def add_keys(self, node_ids):
         raise NotImplementedError()

     def all_node_ids(self):
@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed_cache):
         super().__init__(dynprompt, node_ids, is_changed_cache)
         self.dynprompt = dynprompt
-        self.add_keys(node_ids)

-    def add_keys(self, node_ids):
+    async def add_keys(self, node_ids):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
         super().__init__(dynprompt, node_ids, is_changed_cache)
         self.dynprompt = dynprompt
         self.is_changed_cache = is_changed_cache
-        self.add_keys(node_ids)

     def include_node_id_in_input(self) -> bool:
         return False

-    def add_keys(self, node_ids):
+    async def add_keys(self, node_ids):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
             if not self.dynprompt.has_node(node_id):
                 continue
             node = self.dynprompt.get_node(node_id)
-            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
+            self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
             self.subcache_keys[node_id] = (node_id, node["class_type"])

-    def get_node_signature(self, dynprompt, node_id):
+    async def get_node_signature(self, dynprompt, node_id):
         signature = []
         ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
-        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
+        signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
         for ancestor_id in ancestors:
-            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
+            signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
         return to_hashable(signature)

-    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+    async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
         if not dynprompt.has_node(node_id):
             # This node doesn't exist -- we can't cache it.
             return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        signature = [class_type, self.is_changed_cache.get(node_id)]
+        signature = [class_type, await self.is_changed_cache.get(node_id)]
         if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
             signature.append(node_id)
         inputs = node["inputs"]
@ -150,9 +150,10 @@ class BasicCache:
         self.cache = {}
         self.subcaches = {}

-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
         self.dynprompt = dynprompt
         self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
+        await self.cache_key_set.add_keys(node_ids)
         self.is_changed_cache = is_changed_cache
         self.initialized = True
@ -201,13 +202,13 @@ class BasicCache:
         else:
             return None

-    def _ensure_subcache(self, node_id, children_ids):
+    async def _ensure_subcache(self, node_id, children_ids):
         subcache_key = self.cache_key_set.get_subcache_key(node_id)
         subcache = self.subcaches.get(subcache_key, None)
         if subcache is None:
             subcache = BasicCache(self.key_class)
             self.subcaches[subcache_key] = subcache
-        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
+        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
         return subcache

     def _get_subcache(self, node_id):
@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
         assert cache is not None
         cache._set_immediate(node_id, value)

-    def ensure_subcache_for(self, node_id, children_ids):
+    async def ensure_subcache_for(self, node_id, children_ids):
         cache = self._get_cache_for(node_id)
         assert cache is not None
-        return cache._ensure_subcache(node_id, children_ids)
+        return await cache._ensure_subcache(node_id, children_ids)

 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
@ -273,8 +274,8 @@ class LRUCache(BasicCache):
         self.used_generation = {}
         self.children = {}

-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
         self.generation += 1
         for node_id in node_ids:
             self._mark_used(node_id)
@ -303,11 +304,11 @@ class LRUCache(BasicCache):
         self._mark_used(node_id)
         return self._set_immediate(node_id, value)

-    def ensure_subcache_for(self, node_id, children_ids):
+    async def ensure_subcache_for(self, node_id, children_ids):
         # Just uses subcaches for tracking 'live' nodes
-        super()._ensure_subcache(node_id, children_ids)
+        await super()._ensure_subcache(node_id, children_ids)

-        self.cache_key_set.add_keys(children_ids)
+        await self.cache_key_set.add_keys(children_ids)
         self._mark_used(node_id)
         cache_key = self.cache_key_set.get_data_key(node_id)
         self.children[cache_key] = []
@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
         self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
         self.executed_nodes = set()  # Tracks nodes that have been executed

-    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
         """
         Clear the entire cache and rebuild the dependency graph.
@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
         self.executed_nodes.clear()

         # Call the parent method to initialize the cache with the new prompt
-        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+        await super().set_prompt(dynprompt, node_ids, is_changed_cache)

         # Rebuild the dependency graph
         self._build_dependency_graph(dynprompt, node_ids)
@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
""" """
return self._get_immediate(node_id) return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
""" """
Ensure a subcache exists for a node and update dependencies. Ensure a subcache exists for a node and update dependencies.
@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
         Returns:
             The subcache object for the node.
         """
-        subcache = super()._ensure_subcache(node_id, children_ids)
+        subcache = await super()._ensure_subcache(node_id, children_ids)
         for child_id in children_ids:
             self.descendants[node_id].add(child_id)
             self.ancestors[child_id].add(node_id)

View File

@ -2,6 +2,8 @@ from __future__ import annotations
 from typing import Type, Literal

 import nodes
+import asyncio
+import inspect

 from comfy_execution.graph_utils import is_link
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
@ -100,6 +102,8 @@ class TopologicalSort:
         self.pendingNodes = {}
         self.blockCount = {}  # Number of nodes this node is directly blocked by
         self.blocking = {}  # Which nodes are blocked by this node
+        self.externalBlocks = 0
+        self.unblockedEvent = asyncio.Event()

     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
@ -153,6 +157,16 @@ class TopologicalSort:
         for link in links:
             self.add_strong_link(*link)

+    def add_external_block(self, node_id):
+        assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
+        self.externalBlocks += 1
+        self.blockCount[node_id] += 1
+        def unblock():
+            self.externalBlocks -= 1
+            self.blockCount[node_id] -= 1
+            self.unblockedEvent.set()
+        return unblock
+
     def is_cached(self, node_id):
         return False
@ -181,11 +195,16 @@ class ExecutionList(TopologicalSort):
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None

-    def stage_node_execution(self):
+    async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None, None, None
         available = self.get_ready_nodes()
+        while len(available) == 0 and self.externalBlocks > 0:
+            # Wait for an external block to be released
+            await self.unblockedEvent.wait()
+            self.unblockedEvent.clear()
+            available = self.get_ready_nodes()
         if len(available) == 0:
             cycled_nodes = self.get_nodes_in_cycle()
             # Because cycles composed entirely of static nodes are caught during initial validation,
@ -221,8 +240,15 @@ class ExecutionList(TopologicalSort):
                 return True
             return False

+        # If an available node is async, do that first.
+        # This will execute the asynchronous function earlier, reducing the overall time.
+        def is_async(node_id):
+            class_type = self.dynprompt.get_node(node_id)["class_type"]
+            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+            return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
+
         for node_id in node_list:
-            if is_output(node_id):
+            if is_output(node_id) or is_async(node_id):
                 return node_id

         #This should handle the VAEDecode -> preview case
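For reference, the kind of node this prioritization targets is simply one whose FUNCTION attribute names a coroutine. A minimal sketch of such a hypothetical v1-style node (not part of this PR; registration in NODE_CLASS_MAPPINGS is assumed):

import asyncio

class SleepThenPassThrough:
    # Hypothetical example node for illustration only.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",), "seconds": ("FLOAT", {"default": 1.0})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "example"

    async def run(self, image, seconds):
        # Because run is a coroutine function, is_async() above returns True and the
        # scheduler stages this node early; other ready nodes can run while it awaits.
        await asyncio.sleep(seconds)
        return (image,)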

comfy_execution/progress.py (new file, 347 lines)
View File

@ -0,0 +1,347 @@
from typing import TypedDict, Dict, Optional
from typing_extensions import override
from PIL import Image
from enum import Enum
from abc import ABC
from tqdm import tqdm
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from comfy_execution.graph import DynamicPrompt

from protocol import BinaryEventTypes
from comfy_api import feature_flags


class NodeState(Enum):
    Pending = "pending"
    Running = "running"
    Finished = "finished"
    Error = "error"


class NodeProgressState(TypedDict):
    """
    A class to represent the state of a node's progress.
    """
    state: NodeState
    value: float
    max: float


class ProgressHandler(ABC):
    """
    Abstract base class for progress handlers.
    Progress handlers receive progress updates and display them in various ways.
    """
    def __init__(self, name: str):
        self.name = name
        self.enabled = True

    def set_registry(self, registry: "ProgressRegistry"):
        pass

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Called when a node starts processing"""
        pass

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        """Called when a node's progress is updated"""
        pass

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Called when a node finishes processing"""
        pass

    def reset(self):
        """Called when the progress registry is reset"""
        pass

    def enable(self):
        """Enable this handler"""
        self.enabled = True

    def disable(self):
        """Disable this handler"""
        self.enabled = False


class CLIProgressHandler(ProgressHandler):
    """
    Handler that displays progress using tqdm progress bars in the CLI.
    """
    def __init__(self):
        super().__init__("cli")
        self.progress_bars: Dict[str, tqdm] = {}

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Create a new tqdm progress bar
        if node_id not in self.progress_bars:
            self.progress_bars[node_id] = tqdm(
                total=state["max"],
                desc=f"Node {node_id}",
                unit="steps",
                leave=True,
                position=len(self.progress_bars),
            )

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        # Handle case where start_handler wasn't called
        if node_id not in self.progress_bars:
            self.progress_bars[node_id] = tqdm(
                total=max_value,
                desc=f"Node {node_id}",
                unit="steps",
                leave=True,
                position=len(self.progress_bars),
            )
            self.progress_bars[node_id].update(value)
        else:
            # Update existing progress bar
            if max_value != self.progress_bars[node_id].total:
                self.progress_bars[node_id].total = max_value
            # Calculate the update amount (difference from current position)
            current_position = self.progress_bars[node_id].n
            update_amount = value - current_position
            if update_amount > 0:
                self.progress_bars[node_id].update(update_amount)

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Complete and close the progress bar if it exists
        if node_id in self.progress_bars:
            # Ensure the bar shows 100% completion
            remaining = state["max"] - self.progress_bars[node_id].n
            if remaining > 0:
                self.progress_bars[node_id].update(remaining)
            self.progress_bars[node_id].close()
            del self.progress_bars[node_id]

    @override
    def reset(self):
        # Close all progress bars
        for bar in self.progress_bars.values():
            bar.close()
        self.progress_bars.clear()


class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.
    """
    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance

    def set_registry(self, registry: "ProgressRegistry"):
        self.registry = registry

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """Send the current progress state to the client"""
        if self.server_instance is None:
            return

        # Only send info for non-pending nodes
        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                "node_id": node_id,
                "prompt_id": prompt_id,
                "display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
                "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
                "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        # Send a combined progress_state message with all node states
        self.server_instance.send_sync(
            "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
        )

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Send progress state of all nodes
        if self.registry:
            self._send_progress_state(prompt_id, self.registry.nodes)

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        # Send progress state of all nodes
        if self.registry:
            self._send_progress_state(prompt_id, self.registry.nodes)
        if image:
            # Only send new format if client supports it
            if feature_flags.supports_feature(
                self.server_instance.sockets_metadata,
                self.server_instance.client_id,
                "supports_preview_metadata",
            ):
                metadata = {
                    "node_id": node_id,
                    "prompt_id": prompt_id,
                    "display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
                    "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
                    "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
                }
                self.server_instance.send_sync(
                    BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
                    (image, metadata),
                    self.server_instance.client_id,
                )

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Send progress state of all nodes
        if self.registry:
            self._send_progress_state(prompt_id, self.registry.nodes)


class ProgressRegistry:
    """
    Registry that maintains node progress state and notifies registered handlers.
    """
    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        self.nodes: Dict[str, NodeProgressState] = {}
        self.handlers: Dict[str, ProgressHandler] = {}

    def register_handler(self, handler: ProgressHandler) -> None:
        """Register a progress handler"""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Unregister a progress handler"""
        if handler_name in self.handlers:
            # Allow handler to clean up resources
            self.handlers[handler_name].reset()
            del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Enable a progress handler"""
        if handler_name in self.handlers:
            self.handlers[handler_name].enable()

    def disable_handler(self, handler_name: str) -> None:
        """Disable a progress handler"""
        if handler_name in self.handlers:
            self.handlers[handler_name].disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Ensure a node entry exists"""
        if node_id not in self.nodes:
            self.nodes[node_id] = NodeProgressState(
                state=NodeState.Pending, value=0, max=1
            )
        return self.nodes[node_id]

    def start_progress(self, node_id: str) -> None:
        """Start progress tracking for a node"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = 0.0
        entry["max"] = 1.0

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
    ) -> None:
        """Update progress for a node"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = value
        entry["max"] = max_value

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.update_handler(
                    node_id, value, max_value, entry, self.prompt_id, image
                )

    def finish_progress(self, node_id: str) -> None:
        """Finish progress tracking for a node"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Reset all handlers"""
        for handler in self.handlers.values():
            handler.reset()

# Global registry instance
global_progress_registry: ProgressRegistry = None

def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
    global global_progress_registry

    # Reset existing handlers if registry exists
    if global_progress_registry is not None:
        global_progress_registry.reset_handlers()

    # Create new registry
    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)

def add_progress_handler(handler: ProgressHandler) -> None:
    registry = get_progress_state()
    handler.set_registry(registry)
    registry.register_handler(handler)

def get_progress_state() -> ProgressRegistry:
    global global_progress_registry
    if global_progress_registry is None:
        from comfy_execution.graph import DynamicPrompt
        global_progress_registry = ProgressRegistry(
            prompt_id="", dynprompt=DynamicPrompt({})
        )
    return global_progress_registry
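A sketch of how an executor might wire these pieces together; DynamicPrompt({}) stands in for the real prompt graph, and the prompt and node ids are illustrative:

from comfy_execution.graph import DynamicPrompt
from comfy_execution.progress import (
    CLIProgressHandler, add_progress_handler, get_progress_state, reset_progress_state,
)

reset_progress_state("prompt-1", DynamicPrompt({}))
add_progress_handler(CLIProgressHandler())  # tqdm bars on the CLI

registry = get_progress_state()
registry.start_progress("5")
registry.update_progress("5", value=3, max_value=10, image=None)
registry.finish_progress("5")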

comfy_execution/utils.py (new file, 46 lines)
View File

@ -0,0 +1,46 @@
import contextvars
from typing import Optional, NamedTuple

class ExecutionContext(NamedTuple):
    """
    Context information about the currently executing node.

    Attributes:
        node_id: The ID of the currently executing node
        list_index: The index in a list being processed (for operations on batches/lists)
    """
    prompt_id: str
    node_id: str
    list_index: Optional[int]

current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)

def get_executing_context() -> Optional[ExecutionContext]:
    return current_executing_context.get(None)

class CurrentNodeContext:
    """
    Context manager for setting the current executing node context.
    Sets the current_executing_context on enter and resets it on exit.

    Example:
        with CurrentNodeContext(node_id="123", list_index=0):
            # Code that should run with the current node context set
            process_image()
    """
    def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
        self.context = ExecutionContext(
            prompt_id= prompt_id,
            node_id= node_id,
            list_index= list_index
        )
        self.token = None

    def __enter__(self):
        self.token = current_executing_context.set(self.context)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.token is not None:
            current_executing_context.reset(self.token)
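A sketch of how code running inside a node can read this context back; the executor wraps node calls in CurrentNodeContext, so get_executing_context() reports the active prompt and node (the ids below are illustrative):

from comfy_execution.utils import CurrentNodeContext, get_executing_context

def report_current_node():
    ctx = get_executing_context()
    if ctx is not None:
        print(f"prompt {ctx.prompt_id}, node {ctx.node_id}, list index {ctx.list_index}")

with CurrentNodeContext(prompt_id="prompt-1", node_id="7", list_index=0):
    report_current_node()   # context is set here
report_current_node()       # context has been reset, nothing is printed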

View File

@ -40,6 +40,33 @@ class CFGZeroStar:
         m.set_model_sampler_post_cfg_function(cfg_zero_star)
         return (m, )

+class CFGNorm:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
+                             }}
+
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+    CATEGORY = "advanced/guidance"
+    EXPERIMENTAL = True
+
+    def patch(self, model, strength):
+        m = model.clone()
+        def cfg_norm(args):
+            cond_p = args['cond_denoised']
+            pred_text_ = args["denoised"]
+
+            norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
+            norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
+            scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
+            return pred_text_ * scale * strength
+
+        m.set_model_sampler_post_cfg_function(cfg_norm)
+        return (m, )
+
 NODE_CLASS_MAPPINGS = {
-    "CFGZeroStar": CFGZeroStar
+    "CFGZeroStar": CFGZeroStar,
+    "CFGNorm": CFGNorm,
 }

View File

@ -71,8 +71,11 @@ class FreSca:
     DESCRIPTION = "Applies frequency-dependent scaling to the guidance"

     def patch(self, model, scale_low, scale_high, freq_cutoff):
         def custom_cfg_function(args):
-            cond = args["conds_out"][0]
-            uncond = args["conds_out"][1]
+            conds_out = args["conds_out"]
+            if len(conds_out) <= 1 or None in args["conds"][:2]:
+                return conds_out
+
+            cond = conds_out[0]
+            uncond = conds_out[1]
             guidance = cond - uncond

             filtered_guidance = Fourier_filter(
@ -83,7 +86,7 @@ class FreSca:
             )
             filtered_cond = filtered_guidance + uncond

-            return [filtered_cond, uncond]
+            return [filtered_cond, uncond] + conds_out[2:]

         m = model.clone()
         m.set_model_sampler_pre_cfg_function(custom_cfg_function)

View File

@ -247,7 +247,7 @@ class MaskComposite:
         visible_width, visible_height = (right - left, bottom - top,)

         source_portion = source[:, :visible_height, :visible_width]
-        destination_portion = destination[:, top:bottom, left:right]
+        destination_portion = output[:, top:bottom, left:right]

         if operation == "multiply":
             output[:, top:bottom, left:right] = destination_portion * source_portion

View File

@ -23,38 +23,78 @@ from comfy.comfy_types.node_typing import IO
 from comfy.weight_adapter import adapters

+def make_batch_extra_option_dict(d, indicies, full_size=None):
+    new_dict = {}
+    for k, v in d.items():
+        newv = v
+        if isinstance(v, dict):
+            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
+        elif isinstance(v, torch.Tensor):
+            if full_size is None or v.size(0) == full_size:
+                newv = v[indicies]
+        elif isinstance(v, (list, tuple)) and len(v) == full_size:
+            newv = [v[i] for i in indicies]
+        new_dict[k] = newv
+    return new_dict
+
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
+        self.batch_size = batch_size
+        self.total_steps = total_steps
+        self.seed = seed
+        self.training_dtype = training_dtype

     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-        self.optimizer.zero_grad()
-        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
-        latent = model_wrap.inner_model.model_sampling.noise_scaling(
-            torch.zeros_like(sigmas),
-            torch.zeros_like(noise, requires_grad=True),
-            latent_image,
-            False
-        )
-
-        # Ensure model is in training mode and computing gradients
-        # x0 pred
-        denoised = model_wrap(noise, sigmas, **extra_args)
-        try:
-            loss = self.loss_fn(denoised, latent.clone())
-        except RuntimeError as e:
-            if "does not require grad and does not have a grad_fn" in str(e):
-                logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
-        loss.backward()
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-
-        self.optimizer.step()
-        # torch.cuda.memory._dump_snapshot("trainn.pickle")
-        # torch.cuda.memory._record_memory_history(enabled=None)
+        cond = model_wrap.conds["positive"]
+        dataset_size = sigmas.size(0)
+        torch.cuda.empty_cache()
+        for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
+            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
+
+            batch_latent = torch.stack([latent_image[i] for i in indicies])
+            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+            batch_sigmas = [
+                model_wrap.inner_model.model_sampling.percent_to_sigma(
+                    torch.rand((1,)).item()
+                ) for _ in range(min(self.batch_size, dataset_size))
+            ]
+            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+            xt = model_wrap.inner_model.model_sampling.noise_scaling(
+                batch_sigmas,
+                batch_noise,
+                batch_latent,
+                False
+            )
+            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+                torch.zeros_like(batch_sigmas),
+                torch.zeros_like(batch_noise),
+                batch_latent,
+                False
+            )
+
+            model_wrap.conds["positive"] = [
+                cond[i] for i in indicies
+            ]
+            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+            with torch.autocast(xt.device.type, dtype=self.training_dtype):
+                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+                loss = self.loss_fn(x0_pred, x0)
+            loss.backward()
+            if self.loss_callback:
+                self.loss_callback(loss.item())
+            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            torch.cuda.empty_cache()
+
         return torch.zeros_like(latent_image)
@ -584,36 +624,34 @@ class TrainLoraNode:
loss_map = {"loss": []} loss_map = {"loss": []}
def loss_callback(loss): def loss_callback(loss):
loss_map["loss"].append(loss) loss_map["loss"].append(loss)
pbar.set_postfix({"loss": f"{loss:.4f}"})
train_sampler = TrainSampler( train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
seed=seed,
training_dtype=dtype
) )
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input guider.set_conds(positive) # Set conditioning from input
# yoland: this currently resize to the first image in the dataset
# Training loop # Training loop
torch.cuda.empty_cache()
try: try:
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): # Generate dummy sigmas and noise
# Generate random sigma sigmas = torch.tensor(range(num_images))
sigmas = [mp.model.model_sampling.percent_to_sigma( noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
torch.rand((1,)).item() guider.sample(
) for _ in range(min(batch_size, num_images))] noise.generate_noise({"samples": latents}),
sigmas = torch.tensor(sigmas) latents,
train_sampler,
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) sigmas,
seed=noise.seed
indices = torch.randperm(num_images)[:batch_size] )
batch_latent = latents[indices].clone()
guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
finally: finally:
for m in mp.model.modules(): for m in mp.model.modules():
unpatch(m) unpatch(m)
del train_sampler, optimizer del train_sampler, optimizer
torch.cuda.empty_cache()
for adapter in all_weight_adapters: for adapter in all_weight_adapters:
adapter.requires_grad_(False) adapter.requires_grad_(False)

View File

@ -8,12 +8,14 @@ import time
 import traceback
 from enum import Enum
 from typing import List, Literal, NamedTuple, Optional
+import asyncio

 import torch

 import comfy.model_management
 import nodes
 from comfy_execution.caching import (
+    BasicCache,
     CacheKeySetID,
     CacheKeySetInputSignature,
     DependencyAwareCache,
@ -28,7 +30,9 @@ from comfy_execution.graph import (
 )
 from comfy_execution.graph_utils import GraphBuilder, is_link
 from comfy_execution.validation import validate_node_input
-from comfy_api.internal import ComfyNodeInternal, lock_class, first_real_override, is_class
+from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
+from comfy_execution.utils import CurrentNodeContext
+from comfy_api.internal import ComfyNodeInternal, first_real_override, is_class, make_locked_method_func
 from comfy_api.v3 import io
@ -41,12 +45,13 @@ class DuplicateNodeError(Exception):
     pass

 class IsChangedCache:
-    def __init__(self, dynprompt, outputs_cache):
+    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
+        self.prompt_id = prompt_id
         self.dynprompt = dynprompt
         self.outputs_cache = outputs_cache
         self.is_changed = {}

-    def get(self, node_id):
+    async def get(self, node_id):
         if node_id in self.is_changed:
             return self.is_changed[node_id]
@ -72,7 +77,8 @@ class IsChangedCache:
         # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
         input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
         try:
-            is_changed = _map_node_over_list(class_def, input_data_all, is_changed_name)
+            is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
+            is_changed = await resolve_map_node_over_list_results(is_changed)
             node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
         except Exception as e:
             logging.warning("WARNING: {}".format(e))
@ -127,6 +133,8 @@ class CacheSet:
         }
         return result

+SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
+
 def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
     is_v3 = issubclass(class_def, ComfyNodeInternal)
     if is_v3:
@ -194,7 +202,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
 map_node_over_list = None #Don't hook this please

-def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def resolve_map_node_over_list_results(results):
+    remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
+    if len(remaining) == 0:
+        return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
+    else:
+        done, pending = await asyncio.wait(remaining)
+        for task in done:
+            exc = task.exception()
+            if exc is not None:
+                raise exc
+        return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
+
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
     # check if node wants the lists
     input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -208,7 +228,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
         return {k: v[i if len(v) > i else -1] for k, v in d.items()}

     results = []
-    def process_inputs(inputs, index=None, input_is_list=False):
+    async def process_inputs(inputs, index=None, input_is_list=False):
         if allow_interrupt:
             nodes.before_node_execution()
         execution_block = None
@ -256,23 +276,41 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
                        dynamic_list.append(real_inputs.pop(d.id, None))
                    dynamic_list = [x for x in dynamic_list if x is not None]
                    inputs = {**real_inputs, add_key: dynamic_list}
-                results.append(getattr(type_obj, func).__func__(lock_class(class_clone), **inputs))
+                # TODO: make checkign for async work, this will currently always return False for iscoroutinefunction
+                f = make_locked_method_func(type_obj, func, class_clone)
            # V1
            else:
-                results.append(getattr(obj, func)(**inputs))
+                f = getattr(obj, func)
+
+            if inspect.iscoroutinefunction(f):
+                async def async_wrapper(f, prompt_id, unique_id, list_index, args):
+                    with CurrentNodeContext(prompt_id, unique_id, list_index):
+                        return await f(**args)
+                task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
+                # Give the task a chance to execute without yielding
+                await asyncio.sleep(0)
+                if task.done():
+                    result = task.result()
+                    results.append(result)
+                else:
+                    results.append(task)
+            else:
+                with CurrentNodeContext(prompt_id, unique_id, index):
+                    result = f(**inputs)
+                results.append(result)
        else:
            results.append(execution_block)

    if input_is_list:
-        process_inputs(input_data_all, 0, input_is_list=input_is_list)
+        await process_inputs(input_data_all, 0, input_is_list=input_is_list)
    elif max_len_input == 0:
-        process_inputs({})
+        await process_inputs({})
    else:
        for i in range(max_len_input):
            input_dict = slice_dict(input_data_all, i)
-            process_inputs(input_dict, i)
+            await process_inputs(input_dict, i)
    return results

def merge_result_data(results, obj):
    # check which outputs need concatenating
    output = []
@ -294,11 +332,18 @@ def merge_result_data(results, obj):
             output.append([o[i] for o in results])
     return output

-def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
+    if has_pending_task:
+        return return_values, {}, False, has_pending_task
+    output, ui, has_subgraph = get_output_from_returns(return_values, obj)
+    return output, ui, has_subgraph, False
+
+def get_output_from_returns(return_values, obj):
     results = []
     uis = []
     subgraph_results = []
-    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
     has_subgraph = False
     for i in range(len(return_values)):
         r = return_values[i]
@ -352,6 +397,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
     else:
         output = []
     ui = dict()
+    # TODO: Think there's an existing bug here
+    # If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
+    # They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
+    # any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
     if len(uis) > 0:
         ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
     return output, ui, has_subgraph
@ -364,7 +413,7 @@ def format_value(x):
     else:
         return str(x)

-def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@ -376,11 +425,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if server.client_id is not None: if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {} cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
try: try:
if unique_id in pending_subgraph_results: if unique_id in pending_async_nodes:
results = []
for r in pending_async_nodes[unique_id]:
if isinstance(r, asyncio.Task):
try:
results.append(r.result())
except Exception as ex:
# An async task failed - propagate the exception up
del pending_async_nodes[unique_id]
raise ex
else:
results.append(r)
del pending_async_nodes[unique_id]
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
elif unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id] cached_results = pending_subgraph_results[unique_id]
resolved_outputs = [] resolved_outputs = []
for is_subgraph, result in cached_results: for is_subgraph, result in cached_results:
@ -402,6 +466,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = [] output_ui = []
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
@ -417,7 +482,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else: else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present: if lazy_status_present:
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys x not in input_data_all or x in missing_keys
@ -446,8 +512,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else: else:
return block return block
def pre_execute_cb(call_index): def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
async def await_completion():
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { caches.ui.set(unique_id, {
"meta": { "meta": {
@ -490,7 +566,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
cached_outputs.append((True, node_outputs)) cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids) new_node_ids = set(new_node_ids)
for cache in caches.all: for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused()
for node_id in new_output_ids: for node_id in new_output_ids:
execution_list.add_node(node_id) execution_list.add_node(node_id)
for link in new_output_links: for link in new_output_links:
@ -535,6 +612,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.FAILURE, error_details, ex) return (ExecutionResult.FAILURE, error_details, ex)
get_progress_state().finish_progress(unique_id)
executed.add(unique_id) executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
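The pending-task branch above (has_pending_tasks leading to ExecutionResult.PENDING) is what lets a node declare its FUNCTION as a coroutine: the executor parks the node behind add_external_block and revisits it through pending_async_nodes once the tasks complete. Purely as illustration, a minimal async node exercising that path might look like the hypothetical sketch below; the async test nodes added later in this changeset follow the same pattern.
import asyncio

class ExampleAsyncNode:
    """Hypothetical node whose FUNCTION is a coroutine (illustrative only)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"value": ("FLOAT", {"default": 1.0})}}

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "run"
    CATEGORY = "_for_testing/async"

    async def run(self, value):
        # Stand-in for real asynchronous work (network I/O, a remote service, etc.)
        await asyncio.sleep(0.1)
        return (value,)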
@ -589,6 +667,11 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
@ -601,9 +684,11 @@ class PromptExecutor:
with torch.inference_mode(): with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all: for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused() cache.clean_unused()
cached_nodes = [] cached_nodes = []
@ -616,6 +701,7 @@ class PromptExecutor:
{ "nodes": cached_nodes, "prompt_id": prompt_id}, { "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False) broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
@ -623,12 +709,13 @@ class PromptExecutor:
execution_list.add_node(node_id) execution_list.add_node(node_id)
while not execution_list.is_empty(): while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution() node_id, error, ex = await execution_list.stage_node_execution()
if error is not None: if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
self.success = result != ExecutionResult.FAILURE self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE: if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -658,7 +745,7 @@ class PromptExecutor:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item unique_id = item
if unique_id in validated: if unique_id in validated:
return validated[unique_id] return validated[unique_id]
@ -741,7 +828,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
try: try:
r = validate_inputs(prompt, o_id, validated) r = await validate_inputs(prompt_id, prompt, o_id, validated)
if r[0] is False: if r[0] is False:
# `r` will be set in `validated[o_id]` already # `r` will be set in `validated[o_id]` already
valid = False valid = False
@ -865,7 +952,8 @@ def validate_inputs(prompt, item, validated):
if 'input_types' in validate_function_inputs: if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
ret = _map_node_over_list(obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker): if r is not True and not isinstance(r, ExecutionBlocker):
@ -898,7 +986,7 @@ def full_type_name(klass):
return klass.__qualname__ return klass.__qualname__
return module + '.' + klass.__qualname__ return module + '.' + klass.__qualname__
def validate_prompt(prompt): async def validate_prompt(prompt_id, prompt):
outputs = set() outputs = set()
for x in prompt: for x in prompt:
if 'class_type' not in prompt[x]: if 'class_type' not in prompt[x]:
@ -941,7 +1029,7 @@ def validate_prompt(prompt):
valid = False valid = False
reasons = [] reasons = []
try: try:
m = validate_inputs(prompt, o, validated) m = await validate_inputs(prompt_id, prompt, o, validated)
valid = m[0] valid = m[0]
reasons = m[1] reasons = m[1]
except Exception as ex: except Exception as ex:
@ -1054,6 +1142,11 @@ class PromptQueue:
if status is not None: if status is not None:
status_dict = copy.deepcopy(status._asdict()) status_dict = copy.deepcopy(status._asdict())
# Remove sensitive data from extra_data before storing in history
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
self.history[prompt[1]] = { self.history[prompt[1]] = {
"prompt": prompt, "prompt": prompt,
"outputs": {}, "outputs": {},

main.py

@ -11,6 +11,9 @@ import itertools
import utils.extra_config import utils.extra_config
import logging import logging
import sys import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
if __name__ == "__main__": if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes. #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
@ -127,11 +130,14 @@ if __name__ == "__main__":
import cuda_malloc import cuda_malloc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy.utils import comfy.utils
import execution import execution
import server import server
from server import BinaryEventTypes from protocol import BinaryEventTypes
import nodes import nodes
import comfy.model_management import comfy.model_management
import comfyui_version import comfyui_version
@ -227,15 +233,34 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
) )
def hijack_progress(server_instance): def hijack_progress(server_instance):
def hook(value, total, preview_image): def hook(value, total, preview_image, prompt_id=None, node_id=None):
executing_context = get_executing_context()
if prompt_id is None and executing_context is not None:
prompt_id = executing_context.prompt_id
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} if prompt_id is None:
prompt_id = server_instance.last_prompt_id
if node_id is None:
node_id = server_instance.last_node_id
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", progress, server_instance.client_id) server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None: if preview_image is not None:
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) # Only send old method if client doesn't support preview metadata
if not feature_flags.supports_feature(
server_instance.sockets_metadata,
server_instance.client_id,
"supports_preview_metadata",
):
server_instance.send_sync(
BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
preview_image,
server_instance.client_id,
)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)

protocol.py (new file)

@ -0,0 +1,7 @@
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
PREVIEW_IMAGE_WITH_METADATA = 4


@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4 comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.35 comfyui-workflow-templates==0.1.36
comfyui-embedded-docs==0.2.4 comfyui-embedded-docs==0.2.4
torch torch
torchsde torchsde


@ -10,11 +10,11 @@ import urllib.parse
server_address = "127.0.0.1:8188" server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4()) client_id = str(uuid.uuid4())
def queue_prompt(prompt): def queue_prompt(prompt, prompt_id):
p = {"prompt": prompt, "client_id": client_id} p = {"prompt": prompt, "client_id": client_id, "prompt_id": prompt_id}
data = json.dumps(p).encode('utf-8') data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
return json.loads(urllib.request.urlopen(req).read()) urllib.request.urlopen(req).read()
def get_image(filename, subfolder, folder_type): def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type} data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
@ -27,7 +27,8 @@ def get_history(prompt_id):
return json.loads(response.read()) return json.loads(response.read())
def get_images(ws, prompt): def get_images(ws, prompt):
prompt_id = queue_prompt(prompt)['prompt_id'] prompt_id = str(uuid.uuid4())
queue_prompt(prompt, prompt_id)
output_images = {} output_images = {}
while True: while True:
out = ws.recv() out = ws.recv()


@ -26,6 +26,7 @@ import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy_api import feature_flags
import node_helpers import node_helpers
from comfyui_version import __version__ from comfyui_version import __version__
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
@ -36,11 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
@ -179,6 +176,7 @@ class PromptServer():
max_upload_size = round(args.max_upload_size * 1024 * 1024) max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict() self.sockets = dict()
self.sockets_metadata = dict()
self.web_root = ( self.web_root = (
FrontendManager.init_frontend(args.front_end_version) FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None if args.front_end_root is None
@ -203,20 +201,53 @@ class PromptServer():
else: else:
sid = uuid.uuid4().hex sid = uuid.uuid4().hex
# Store WebSocket for backward compatibility
self.sockets[sid] = ws self.sockets[sid] = ws
# Store metadata separately
self.sockets_metadata[sid] = {"feature_flags": {}}
try: try:
# Send initial state to the new client # Send initial state to the new client
await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid) await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node # On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None: if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid) await self.send("executing", { "node": self.last_node_id }, sid)
# Flag to track if we've received the first message
first_message = True
async for msg in ws: async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning('ws connection closed with exception %s' % ws.exception()) logging.warning('ws connection closed with exception %s' % ws.exception())
elif msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
# Check if first message is feature flags
if first_message and data.get("type") == "feature_flags":
# Store client feature flags
client_flags = data.get("data", {})
self.sockets_metadata[sid]["feature_flags"] = client_flags
# Send server feature flags in response
await self.send(
"feature_flags",
feature_flags.get_server_features(),
sid,
)
logging.info(
f"Feature flags negotiated for client {sid}: {client_flags}"
)
first_message = False
except json.JSONDecodeError:
logging.warning(
f"Invalid JSON received from client {sid}: {msg.data}"
)
except Exception as e:
logging.error(f"Error processing WebSocket message: {e}")
finally: finally:
self.sockets.pop(sid, None) self.sockets.pop(sid, None)
self.sockets_metadata.pop(sid, None)
return ws return ws
@routes.get("/") @routes.get("/")
@ -549,6 +580,10 @@ class PromptServer():
} }
return web.json_response(system_stats) return web.json_response(system_stats)
@routes.get("/features")
async def get_features(request):
return web.json_response(feature_flags.get_server_features())
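As a quick sanity check, the new endpoint can be queried directly over HTTP; a minimal sketch in the same urllib style as the bundled websocket example, assuming the default listen address and port:
import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/features") as response:
    server_features = json.loads(response.read())

# Expected keys include "supports_preview_metadata" and "max_upload_size"
print(server_features)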
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
@ -646,7 +681,8 @@ class PromptServer():
if "prompt" in json_data: if "prompt" in json_data:
prompt = json_data["prompt"] prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt) prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
valid = await execution.validate_prompt(prompt_id, prompt)
extra_data = {} extra_data = {}
if "extra_data" in json_data: if "extra_data" in json_data:
extra_data = json_data["extra_data"] extra_data = json_data["extra_data"]
@ -654,7 +690,6 @@ class PromptServer():
if "client_id" in json_data: if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"] extra_data["client_id"] = json_data["client_id"]
if valid[0]: if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2] outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
@ -769,6 +804,10 @@ class PromptServer():
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid) await self.send_image(data, sid=sid)
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
# data is (preview_image, metadata)
preview_image, metadata = data
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
elif isinstance(data, (bytes, bytearray)): elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid) await self.send_bytes(event, data, sid)
else: else:
@ -807,6 +846,43 @@ class PromptServer():
preview_bytes = bytesIO.getvalue() preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS  # Pillow < 9.1 has no Image.Resampling; fall back to the legacy constant
image = ImageOps.contain(image, (max_size, max_size), resampling)
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
# Prepare metadata
if metadata is None:
metadata = {}
metadata["image_type"] = mimetype
# Serialize metadata as JSON
import json
metadata_json = json.dumps(metadata).encode('utf-8')
metadata_length = len(metadata_json)
# Prepare image data
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
image_bytes = bytesIO.getvalue()
# Combine metadata and image
combined_data = bytearray()
combined_data.extend(struct.pack(">I", metadata_length))
combined_data.extend(metadata_json)
combined_data.extend(image_bytes)
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
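A receiver can split the combined payload back apart with the mirror of the packing above. The sketch below decodes only combined_data itself; the outer event framing added by send_bytes/encode_bytes is not shown in this diff and is ignored here.
import json
import struct

def decode_preview_with_metadata(combined_data: bytes):
    # First 4 bytes: big-endian length of the JSON metadata blob
    (metadata_length,) = struct.unpack_from(">I", combined_data, 0)
    metadata = json.loads(combined_data[4:4 + metadata_length])
    image_bytes = combined_data[4 + metadata_length:]
    # metadata["image_type"] is "image/png" or "image/jpeg", as set above
    return metadata, image_bytes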
async def send_bytes(self, event, data, sid=None): async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)
@ -848,10 +924,10 @@ class PromptServer():
ssl_ctx = None ssl_ctx = None
scheme = "http" scheme = "http"
if args.tls_keyfile and args.tls_certfile: if args.tls_keyfile and args.tls_certfile:
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
ssl_ctx.load_cert_chain(certfile=args.tls_certfile, ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
keyfile=args.tls_keyfile) keyfile=args.tls_keyfile)
scheme = "https" scheme = "https"
if verbose: if verbose:
logging.info("Starting server\n") logging.info("Starting server\n")


@ -0,0 +1,98 @@
"""Tests for feature flags functionality."""
from comfy_api.feature_flags import (
get_connection_feature,
supports_feature,
get_server_features,
SERVER_FEATURE_FLAGS,
)
class TestFeatureFlags:
"""Test suite for feature flags functions."""
def test_get_server_features_returns_copy(self):
"""Test that get_server_features returns a copy of the server flags."""
features = get_server_features()
# Verify it's a copy by modifying it
features["test_flag"] = True
# Original should be unchanged
assert "test_flag" not in SERVER_FEATURE_FLAGS
def test_get_server_features_contains_expected_flags(self):
"""Test that server features contain expected flags."""
features = get_server_features()
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))
def test_get_connection_feature_with_missing_sid(self):
"""Test getting feature for non-existent session ID."""
sockets_metadata = {}
result = get_connection_feature(sockets_metadata, "missing_sid", "some_feature")
assert result is False # Default value
def test_get_connection_feature_with_custom_default(self):
"""Test getting feature with custom default value."""
sockets_metadata = {}
result = get_connection_feature(
sockets_metadata, "missing_sid", "some_feature", default="custom_default"
)
assert result == "custom_default"
def test_get_connection_feature_with_feature_flags(self):
"""Test getting feature from connection with feature flags."""
sockets_metadata = {
"sid1": {
"feature_flags": {
"supports_preview_metadata": True,
"custom_feature": "value",
},
}
}
result = get_connection_feature(sockets_metadata, "sid1", "supports_preview_metadata")
assert result is True
result = get_connection_feature(sockets_metadata, "sid1", "custom_feature")
assert result == "value"
def test_get_connection_feature_missing_feature(self):
"""Test getting non-existent feature from connection."""
sockets_metadata = {
"sid1": {"feature_flags": {"existing_feature": True}}
}
result = get_connection_feature(sockets_metadata, "sid1", "missing_feature")
assert result is False
def test_supports_feature_returns_boolean(self):
"""Test that supports_feature always returns boolean."""
sockets_metadata = {
"sid1": {
"feature_flags": {
"bool_feature": True,
"string_feature": "value",
"none_feature": None,
},
}
}
# True boolean feature
assert supports_feature(sockets_metadata, "sid1", "bool_feature") is True
# Non-boolean values should return False
assert supports_feature(sockets_metadata, "sid1", "string_feature") is False
assert supports_feature(sockets_metadata, "sid1", "none_feature") is False
assert supports_feature(sockets_metadata, "sid1", "missing_feature") is False
def test_supports_feature_with_missing_connection(self):
"""Test supports_feature with missing connection."""
sockets_metadata = {}
assert supports_feature(sockets_metadata, "missing_sid", "any_feature") is False
def test_empty_feature_flags_dict(self):
"""Test connection with empty feature flags dictionary."""
sockets_metadata = {"sid1": {"feature_flags": {}}}
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
assert result is False
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
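The comfy_api.feature_flags module itself is not part of the hunks shown here; a minimal sketch consistent with what these tests expect could look like the following (function names are taken from the imports above, while the max_upload_size value is illustrative and the real module presumably derives it from the CLI args):
# comfy_api/feature_flags.py -- hypothetical sketch inferred from the tests above
SERVER_FEATURE_FLAGS = {
    "supports_preview_metadata": True,
    "max_upload_size": 100 * 1024 * 1024,  # illustrative value
}

def get_server_features():
    # Return a copy so callers cannot mutate the module-level dict
    return SERVER_FEATURE_FLAGS.copy()

def get_connection_feature(sockets_metadata, sid, feature, default=False):
    if sid not in sockets_metadata:
        return default
    return sockets_metadata[sid].get("feature_flags", {}).get(feature, default)

def supports_feature(sockets_metadata, sid, feature):
    # Only an explicit boolean True counts as support
    return get_connection_feature(sockets_metadata, sid, feature, default=False) is True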


@ -1,3 +1,4 @@
pytest>=7.8.0 pytest>=7.8.0
pytest-aiohttp pytest-aiohttp
pytest-asyncio pytest-asyncio
websocket-client


@ -0,0 +1,77 @@
"""Simplified tests for WebSocket feature flags functionality."""
from comfy_api import feature_flags
class TestWebSocketFeatureFlags:
"""Test suite for WebSocket feature flags integration."""
def test_server_feature_flags_response(self):
"""Test server feature flags are properly formatted."""
features = feature_flags.get_server_features()
# Check expected server features
assert "supports_preview_metadata" in features
assert features["supports_preview_metadata"] is True
assert "max_upload_size" in features
assert isinstance(features["max_upload_size"], (int, float))
def test_progress_py_checks_feature_flags(self):
"""Test that progress.py checks feature flags before sending metadata."""
# This simulates the check in progress.py
client_id = "test_client"
sockets_metadata = {"test_client": {"feature_flags": {}}}
# The actual check would be in progress.py
supports_metadata = feature_flags.supports_feature(
sockets_metadata, client_id, "supports_preview_metadata"
)
assert supports_metadata is False
def test_multiple_clients_different_features(self):
"""Test handling multiple clients with different feature support."""
sockets_metadata = {
"modern_client": {
"feature_flags": {"supports_preview_metadata": True}
},
"legacy_client": {
"feature_flags": {}
}
}
# Check modern client
assert feature_flags.supports_feature(
sockets_metadata, "modern_client", "supports_preview_metadata"
) is True
# Check legacy client
assert feature_flags.supports_feature(
sockets_metadata, "legacy_client", "supports_preview_metadata"
) is False
def test_feature_negotiation_message_format(self):
"""Test the format of feature negotiation messages."""
# Client message format
client_message = {
"type": "feature_flags",
"data": {
"supports_preview_metadata": True,
"api_version": "1.0.0"
}
}
# Verify structure
assert client_message["type"] == "feature_flags"
assert "supports_preview_metadata" in client_message["data"]
# Server response format (what would be sent)
server_features = feature_flags.get_server_features()
server_message = {
"type": "feature_flags",
"data": server_features
}
# Verify structure
assert server_message["type"] == "feature_flags"
assert "supports_preview_metadata" in server_message["data"]
assert server_message["data"]["supports_preview_metadata"] is True
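For completeness, a client driving this negotiation over the websocket (using the websocket-client package added to the test requirements above) might look roughly like the sketch below; the /ws endpoint path and the default port are assumptions, not taken from the hunks shown here.
import json
import uuid
import websocket  # websocket-client package

client_id = str(uuid.uuid4())
ws = websocket.WebSocket()
ws.connect("ws://127.0.0.1:8188/ws?clientId={}".format(client_id))  # path assumed

# The first message must be the feature_flags announcement (see the handler above)
ws.send(json.dumps({
    "type": "feature_flags",
    "data": {"supports_preview_metadata": True},
}))

# The server sends "status" on connect and replies to the announcement with its own flags
while True:
    raw = ws.recv()
    if not isinstance(raw, str):
        continue  # skip binary preview frames
    message = json.loads(raw)
    if message.get("type") == "feature_flags":
        print("server features:", message.get("data"))
        break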


@ -1,4 +1,4 @@
# Config for testing nodes # Config for testing nodes
testing: testing:
custom_nodes: tests/inference/testing_nodes custom_nodes: testing_nodes


@ -0,0 +1,410 @@
import pytest
import time
import torch
import urllib.error
import numpy as np
import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient
@pytest.mark.execution
class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[
(False, 0),
(True, 0),
(True, 100),
])
def _server(self, args_pytest, request):
pargs = [
'python','main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
]
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
# Running server with args: pargs
p = subprocess.Popen(pargs)
yield
p.kill()
torch.cuda.empty_cache()
@fixture(scope="class", autouse=True)
def shared_client(self, args_pytest, _server):
client = ComfyClient()
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
except ConnectionRefusedError:
# Retrying...
pass
else:
break
yield client
del client
torch.cuda.empty_cache()
@fixture
def client(self, shared_client, request):
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
yield shared_client
@fixture
def builder(self, request):
yield GraphBuilder(prefix=request.node.name)
# Happy Path Tests
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that a basic async node executes correctly."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
output = g.node("SaveImage", images=sleep_node.out(0))
result = client.run(g)
# Verify execution completed
assert result.did_run(sleep_node), "Async sleep node should have executed"
assert result.did_run(output), "Output node should have executed"
# Verify the image passed through correctly
result_images = result.get_images(output)
assert len(result_images) == 1, "Should have 1 image"
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create multiple async sleep nodes with different durations
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
# Add outputs for each
_output1 = g.node("PreviewImage", images=sleep1.out(0))
_output2 = g.node("PreviewImage", images=sleep2.out(0))
_output3 = g.node("PreviewImage", images=sleep3.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
# Verify all nodes executed
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with proper dependency handling."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Chain of async operations
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
# Average depends on both async results
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
output = g.node("SaveImage", images=average.out(0))
result = client.run(g)
# Verify execution order
assert result.did_run(sleep1) and result.did_run(sleep2)
assert result.did_run(average) and result.did_run(output)
# Verify averaged result
result_images = result.get_images(output)
avg_value = np.array(result_images[0]).mean()
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
"""Test async VALIDATE_INPUTS function."""
g = builder
# Create a test node with async validation
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
g.node("SaveImage", images=validation_node.out(0))
# Should pass validation
result = client.run(g)
assert result.did_run(validation_node)
# Test validation failure
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
with pytest.raises(urllib.error.HTTPError):
client.run(g)
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation."""
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
# Create async nodes that will be evaluated lazily
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
# Use lazy mix that only needs sleep1 (mask=0.0)
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
g.node("SaveImage", images=lazy_mix.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should only execute sleep1, not sleep2
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
assert result.did_run(sleep1), "Sleep1 should have executed"
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
"""Test async check_lazy_status function."""
g = builder
# Create a node with async check_lazy_status
lazy_node = g.node("TestAsyncLazyCheck",
input1="value1",
input2="value2",
condition=True)
g.node("SaveImage", images=lazy_node.out(0))
result = client.run(g)
assert result.did_run(lazy_node)
# Error Handling Tests
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async execution errors are properly handled."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create an async node that will error
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
g.node("SaveImage", images=error_node.out(0))
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test async validation error handling."""
g = builder
# Node with async validation that will fail
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
g.node("SaveImage", images=validation_node.out(0))
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
# Verify it's a validation error
assert exc_info.value.code == 400
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling of async operations that timeout."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Very long sleep that would timeout
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
g.node("SaveImage", images=timeout_node.out(0))
try:
client.run(g)
assert False, "Should have raised a timeout error"
except Exception as e:
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
"""Test that workflow can recover after async errors."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# First run with error
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
g.node("SaveImage", images=error_node.out(0))
try:
client.run(g)
except Exception:
pass # Expected
# Second run should succeed
g2 = GraphBuilder(prefix="recovery_test")
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
g2.node("SaveImage", images=sleep_node.out(0))
result = client.run(g2)
assert result.did_run(sleep_node), "Should be able to run after error"
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling when sync node errors while async node is executing."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Async node that takes time
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
# Sync node that will error immediately
error_node = g.node("TestSyncError", value=image.out(0))
# Both feed into output
g.node("PreviewImage", images=sleep_node.out(0))
g.node("PreviewImage", images=error_node.out(0))
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
# Verify the sync error was caught even though async was running
assert 'prompt_id' in e.args[0]
# Edge Cases
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with execution blockers."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Async sleep nodes
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
# Create list of images
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
# Create list of blocking conditions - [False, True] to block only the second item
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
# Block based on the comparison results
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
output = g.node("PreviewImage", images=blocker.out(0))
result = client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have blocked second image"
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
g.node("SaveImage", images=sleep_node.out(0))
# First run
result1 = client.run(g)
assert result1.did_run(sleep_node), "Should run first time"
# Second run - should be cached
start_time = time.time()
result2 = client.run(g)
elapsed_time = time.time() - start_time
assert not result2.did_run(sleep_node), "Should be cached"
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Node that generates async nodes dynamically
dynamic_async = g.node("TestDynamicAsyncGeneration",
image1=image1.out(0),
image2=image2.out(0),
num_async_nodes=3,
sleep_duration=0.2)
g.node("SaveImage", images=dynamic_async.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should execute async nodes in parallel within dynamic prompt
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
assert result.did_run(dynamic_async)
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async resources are properly cleaned up."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create multiple async nodes that use resources
resource_nodes = []
for i in range(5):
node = g.node("TestAsyncResourceUser",
value=image.out(0),
resource_id=f"resource_{i}",
duration=0.1)
resource_nodes.append(node)
g.node("PreviewImage", images=node.out(0))
result = client.run(g)
# Verify all nodes executed
for node in resource_nodes:
assert result.did_run(node)
# Run again to ensure resources were cleaned up
result2 = client.run(g)
# Should be cached but not error due to resource conflicts
for node in resource_nodes:
assert not result2.did_run(node), "Should be cached"
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
"""Test cancellation of async operations."""
# This would require implementing cancellation in the client
# For now, we'll test that long-running async operations can be interrupted
pass # TODO: Implement when cancellation API is available
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test workflows with both sync and async nodes."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
# Mix of sync and async operations
# Sync: lazy mix images
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
# Async: sleep
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
# Sync: custom validation
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
# Async: sleep again
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
output = g.node("SaveImage", images=async_op2.out(0))
result = client.run(g)
# Verify all nodes executed in correct order
assert result.did_run(sync_op1)
assert result.did_run(async_op1)
assert result.did_run(sync_op2)
assert result.did_run(async_op2)
# Image should be a mix of black and white (gray)
result_images = result.get_images(output)
avg_value = np.array(result_images[0]).mean()
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"


@ -252,7 +252,7 @@ class TestExecution:
@pytest.mark.parametrize("test_type, test_value", [ @pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5), ("StubInt", 5),
("StubFloat", 5.0) ("StubMask", 5.0)
]) ])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
@ -497,6 +497,69 @@ class TestExecution:
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached" assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution
_output1 = g.node("PreviewImage", images=sleep_node1.out(0))
_output2 = g.node("PreviewImage", images=sleep_node2.out(0))
_output3 = g.node("PreviewImage", images=sleep_node3.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (8.7s)
# We'll allow for up to 4.0s total to account for overhead
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 4.0s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Create input images with different values
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
parallel_sleep = g.node("TestParallelSleep",
image1=image1.out(0),
image2=image2.out(0),
image3=image3.out(0),
sleep1=0.4,
sleep2=0.5,
sleep3=0.6)
output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
# Verify we get an image as output (blend of the three input images)
result_images = result.get_images(output)
assert len(result_images) == 1, "Should have 1 image"
# Average pixel value should be around 170 (255 * 2 // 3)
avg_value = numpy.array(result_images[0]).mean()
assert avg_value == 170, f"Image average value {avg_value} should be 170"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked. # only that one entry in the list is blocked.


@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) # NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) # NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)


@ -0,0 +1,343 @@
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestAsyncValidation(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"threshold": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, threshold):
# Simulate async validation (e.g., checking remote service)
await asyncio.sleep(0.05)
if value > threshold:
return f"Value {value} exceeds threshold {threshold}"
return True
def process(self, value, threshold):
# Create image based on value
intensity = value / 10.0
image = torch.ones([1, 512, 512, 3]) * intensity
return (image,)
class TestAsyncError(ComfyNodeABC):
"""Test node that errors during async execution."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution"
CATEGORY = "_for_testing/async"
async def error_execution(self, value, error_after):
await asyncio.sleep(error_after)
raise RuntimeError("Intentional async execution error for testing")
class TestAsyncValidationError(ComfyNodeABC):
"""Test node with async validation that always fails."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"max_value": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, max_value):
await asyncio.sleep(0.05)
# Always fail validation for values > max_value
if value > max_value:
return f"Async validation failed: {value} > {max_value}"
return True
def process(self, value, max_value):
# This won't be reached if validation fails
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
return (image,)
class TestAsyncTimeout(ComfyNodeABC):
"""Test node that simulates timeout scenarios."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async"
async def timeout_execution(self, value, timeout, operation_time):
try:
# This will timeout if operation_time > timeout
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
return (value,)
except asyncio.TimeoutError:
raise RuntimeError(f"Operation timed out after {timeout} seconds")
class TestSyncError(ComfyNodeABC):
"""Test node that errors synchronously (for mixed sync/async testing)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error"
CATEGORY = "_for_testing/async"
def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing")
class TestAsyncLazyCheck(ComfyNodeABC):
"""Test node with async check_lazy_status."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": (IO.ANY, {"lazy": True}),
"input2": (IO.ANY, {"lazy": True}),
"condition": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service)
await asyncio.sleep(0.05)
needed = []
if condition and input1 is None:
needed.append("input1")
if not condition and input2 is None:
needed.append("input2")
return needed
def process(self, input1, input2, condition):
# Return a simple image
return (torch.ones([1, 512, 512, 3]),)
class TestDynamicAsyncGeneration(ComfyNodeABC):
"""Test node that dynamically generates async nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder()
# Create multiple async sleep nodes
sleep_nodes = []
for i in range(num_async_nodes):
image = image1 if i % 2 == 0 else image2
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
sleep_nodes.append(sleep_node)
# Average all results
if len(sleep_nodes) == 1:
final_node = sleep_nodes[0]
else:
avg_inputs = {"input1": sleep_nodes[0].out(0)}
for i, node in enumerate(sleep_nodes[1:], 2):
avg_inputs[f"input{i}"] = node.out(0)
final_node = g.node("TestVariadicAverage", **avg_inputs)
return {
"result": (final_node.out(0),),
"expand": g.finalize(),
}
class TestAsyncResourceUser(ComfyNodeABC):
"""Test node that uses resources during async execution."""
# Class-level resource tracking for testing
_active_resources: Dict[str, bool] = {}
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"resource_id": ("STRING", {"default": "resource_0"}),
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource"
CATEGORY = "_for_testing/async"
async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use
if self._active_resources.get(resource_id, False):
raise RuntimeError(f"Resource {resource_id} is already in use!")
# Mark resource as in use
self._active_resources[resource_id] = True
try:
# Simulate resource usage
await asyncio.sleep(duration)
return (value,)
finally:
# Always clean up resource
self._active_resources[resource_id] = False
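# Illustrative helper (not registered as a node): two concurrent executions sharing a
# resource_id are expected to collide, with the second raising before the first releases
# the resource in its finally block, while distinct ids run side by side without error.
async def _demo_resource_conflict():
    node = TestAsyncResourceUser()
    return await asyncio.gather(
        node.use_resource("a", "shared_resource", 0.2),
        node.use_resource("b", "shared_resource", 0.2),
        return_exceptions=True,
    )  # expected: [("a",), RuntimeError(...)]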
class TestAsyncBatchProcessing(ComfyNodeABC):
"""Test async processing of batched inputs."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch"
CATEGORY = "_for_testing/async"
async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0]
pbar = ProgressBar(batch_size, node_id=unique_id)
# Process each image in the batch
processed = []
for i in range(batch_size):
# Simulate async processing
await asyncio.sleep(process_time_per_item)
# Simple processing: invert the image
processed_image = 1.0 - images[i:i+1]
processed.append(processed_image)
pbar.update(1)
# Stack processed images
result = torch.cat(processed, dim=0)
return (result,)
class TestAsyncConcurrentLimit(ComfyNodeABC):
"""Test concurrent execution limits for async nodes."""
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
"node_id": ("INT", {"default": 0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async"
async def limited_execution(self, value, duration, node_id):
async with self._semaphore:
# Semaphore acquired: at most two executions run this section concurrently
await asyncio.sleep(duration)
# Semaphore is released when the async with block exits
return (value,)
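# Minimal standalone sketch (illustrative, not registered as a node) of the pattern the
# node above exercises: a shared asyncio.Semaphore caps how many coroutines can be inside
# the guarded section at the same time.
async def _demo_semaphore_limit(num_tasks=5, limit=2, duration=0.05):
    sem = asyncio.Semaphore(limit)
    active = 0
    peak = 0
    async def worker():
        nonlocal active, peak
        async with sem:
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(duration)
            active -= 1
    await asyncio.gather(*(worker() for _ in range(num_tasks)))
    return peak  # never expected to exceed `limit`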
# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
"TestAsyncValidation": TestAsyncValidation,
"TestAsyncError": TestAsyncError,
"TestAsyncValidationError": TestAsyncValidationError,
"TestAsyncTimeout": TestAsyncTimeout,
"TestSyncError": TestSyncError,
"TestAsyncLazyCheck": TestAsyncLazyCheck,
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
"TestAsyncResourceUser": TestAsyncResourceUser,
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAsyncValidation": "Test Async Validation",
"TestAsyncError": "Test Async Error",
"TestAsyncValidationError": "Test Async Validation Error",
"TestAsyncTimeout": "Test Async Timeout",
"TestSyncError": "Test Sync Error",
"TestAsyncLazyCheck": "Test Async Lazy Check",
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
"TestAsyncResourceUser": "Test Async Resource User",
"TestAsyncBatchProcessing": "Test Async Batch Processing",
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}
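# Hedged sketch of how a testing package's __init__.py might expose these mappings to
# ComfyUI; the module name in the import below is an assumption about the package layout.
# ComfyUI only requires NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS to be
# importable from the package root.
#
#   from .async_test_nodes import (
#       ASYNC_TEST_NODE_CLASS_MAPPINGS,
#       ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS,
#   )
#
#   NODE_CLASS_MAPPINGS = {**ASYNC_TEST_NODE_CLASS_MAPPINGS}
#   NODE_DISPLAY_NAME_MAPPINGS = {**ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS}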

View File

@@ -1,6 +1,11 @@
import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestLazyMixImages:
@classmethod
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
"expand": g.finalize(),
}
class TestSamplingInExpansion:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sampling_in_expansion"
CATEGORY = "Testing/Nodes"
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
g = GraphBuilder()
# Create a basic image generation workflow using the input model, clip and vae
# 1. Set up the text prompts using the provided CLIP model
positive = g.node("CLIPTextEncode",
text=prompt,
clip=clip)
negative = g.node("CLIPTextEncode",
text=negative_prompt,
clip=clip)
# 2. Create empty latent with specified size
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
# 3. Set up the sampler and generate the image latent
sampler = g.node("KSampler",
model=model,
positive=positive.out(0),
negative=negative.out(0),
latent_image=empty_latent.out(0),
seed=seed,
steps=steps,
cfg=cfg,
sampler_name="euler_ancestral",
scheduler="normal")
# 4. Decode latent to image using VAE
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
return {
"result": (output.out(0),),
"expand": g.finalize(),
}
class TestSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.01)
return (value,)
class TestParallelSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE", ),
"image2": ("IMAGE", ),
"image3": ("IMAGE", ),
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
# Create a graph dynamically with three TestSleep nodes
g = GraphBuilder()
# Create sleep nodes for each duration and image
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
# Blend the results using TestVariadicAverage
blend = g.node("TestVariadicAverage",
input1=sleep_node1.out(0),
input2=sleep_node2.out(0),
input3=sleep_node3.out(0))
return {
"result": (blend.out(0),),
"expand": g.finalize(),
}
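# Note: each TestSleep created above is async, so the expanded sub-graph is expected to
# let the three sleeps overlap (assuming the executor runs independent async nodes
# concurrently); the wall-clock cost is then roughly max(sleep1, sleep2, sleep3) rather
# than their sum, which is what makes this node useful for checking parallel execution.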
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestCustomValidation5": TestCustomValidation5,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
"TestMixedExpansionReturns": TestMixedExpansionReturns,
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestCustomValidation5": "Custom Validation 5",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
"TestMixedExpansionReturns": "Mixed Expansion Returns",
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
}