diff --git a/README.md b/README.md index a99aca0e7..cf6df7e55 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/) - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) + - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) diff --git a/app/frontend_management.py b/app/frontend_management.py index c56ea86e0..7b7923b79 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -184,6 +184,27 @@ comfyui-frontend-package is not installed. ) sys.exit(-1) + @classmethod + def templates_path(cls) -> str: + try: + import comfyui_workflow_templates + + return str( + importlib.resources.files(comfyui_workflow_templates) / "templates" + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9b5e5332c..2a30497c5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context): + def forward(self, x, context, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context): + def forward(self, x, context, context_img_len): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - context_img = context[:, :257] - context = context[:, 257:] + context_img = context[:, :context_img_len] + context = context[:, context_img_len:] # compute query, key, value q = self.norm_q(self.q(x)) @@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, + context_img_len=257, ): r""" Args: @@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) x = x + y * e[5] return x @@ -250,7 +251,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = 
nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -284,6 +293,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -373,7 +383,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None @@ -420,9 +430,12 @@ class WanModel(torch.nn.Module): # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -430,12 +443,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) # head x = self.head(x, e) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index a4da1afcd..6499bf238 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 59b42b746..a2799b52e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,6 +1,9 @@ -import nodes +from __future__ import annotations +from typing import Type, Literal +import nodes from comfy_execution.graph_utils import is_link +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions class DependencyCycleError(Exception): pass @@ -54,7 +57,22 @@ class DynamicPrompt: def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, 
input_name, valid_inputs=None): +def get_input_info( + class_def: Type[ComfyNodeABC], + input_name: str, + valid_inputs: InputTypeDict | None = None +) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + """Get the input type, category, and extra info for a given input name. + + Arguments: + class_def: The class definition of the node. + input_name: The name of the input to get info for. + valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES(). + + Returns: + tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name. + """ + valid_inputs = valid_inputs or class_def.INPUT_TYPES() input_info = None input_category = None @@ -126,7 +144,7 @@ class TopologicalSort: from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 000000000..ee310c874 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,100 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. 
+ """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # Initialize mask with high-frequency scaling factor + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) + + # Apply low-frequency scaling factor to center region + m[:] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index db30030fb..53d892bc4 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -21,8 +21,8 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -41,7 +41,7 @@ class Load3D(): normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, normal_image, lineart_image + return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'] class Load3DAnimation(): @classmethod @@ -59,8 +59,8 @@ class Load3DAnimation(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", 
"IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -77,13 +77,16 @@ class Load3DAnimation(): ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, normal_image + return output_image, output_mask, model_file, normal_image, image['camera_info'] class Preview3D(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -95,13 +98,22 @@ class Preview3D(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } class Preview3DAnimation(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -113,7 +125,13 @@ class Preview3DAnimation(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2d0f31ac8..8ad358ce8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -4,6 +4,7 @@ import torch import comfy.model_management import comfy.utils import comfy.latent_formats +import comfy.clip_vision class WanImageToVideo: @@ -99,6 +100,72 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class WanFirstLastFrameToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), + "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = 
torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+
+        if start_image is not None:
+            image[:start_image.shape[0]] = start_image
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+
+        if end_image is not None:
+            image[-end_image.shape[0]:] = end_image
+            mask[:, :, -end_image.shape[0]:] = 0.0
+
+        concat_latent_image = vae.encode(image[:, :, :, :3])
+        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        clip_vision_output = None
+        if clip_vision_start_image is not None:
+            clip_vision_output = clip_vision_start_image
+
+        if clip_vision_end_image is not None:
+            if clip_vision_output is not None:
+                states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
+                clip_vision_output = comfy.clip_vision.Output()
+                clip_vision_output.penultimate_hidden_states = states
+            else:
+                clip_vision_output = clip_vision_end_image
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+
 class WanFunInpaintToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
     CATEGORY = "conditioning/video_models"
 
     def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
-        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        if start_image is not None:
-            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-        if end_image is not None:
-            end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        flfv = WanFirstLastFrameToVideo()
+        return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
-        image = torch.ones((length, height, width, 3)) * 0.5
-        mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
-
-        if start_image is not None:
-            image[:start_image.shape[0]] = start_image
-            mask[:, :, :start_image.shape[0] + 3] = 0.0
-
-        if end_image is not None:
-            image[-end_image.shape[0]:] = end_image
-            mask[:, :, -end_image.shape[0]:] = 0.0
-
-        concat_latent_image = vae.encode(image[:, :, :, :3])
-        mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
-        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-
-        if clip_vision_output is not None:
-            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
-            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
-
-        out_latent = {}
-        out_latent["samples"] = latent
-        
return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, + "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, } diff --git a/comfyui_version.py b/comfyui_version.py index a44538d1a..f9161b37e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.28" +__version__ = "0.3.29" diff --git a/execution.py b/execution.py index 9a5e27771..d09102f55 100644 --- a/execution.py +++ b/execution.py @@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + _, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated): continue val = inputs[x] - info = (type_input, extra_info) + info = (input_type, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", @@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated): val = val["__value__"] inputs[x] = val - if type_input == "INT": + if input_type == "INT": val = int(val) inputs[x] = val - if type_input == "FLOAT": + if input_type == "FLOAT": val = float(val) inputs[x] = val - if type_input == "STRING": + if input_type == "STRING": val = str(val) inputs[x] = val - if type_input == "BOOLEAN": + if input_type == "BOOLEAN": val = bool(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", + "message": f"Failed to convert an input value to a {input_type} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, @@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if isinstance(type_input, list): - if val not in type_input: + if isinstance(input_type, list): + combo_options = input_type + if val not in combo_options: input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" + if len(combo_options) > 20: + list_info = f"(list of 
length {len(combo_options)})" input_config = None else: - list_info = str(type_input) + list_info = str(combo_options) error = { "type": "value_not_in_list", diff --git a/nodes.py b/nodes.py index f71802454..443d1efbb 100644 --- a/nodes.py +++ b/nodes.py @@ -930,26 +930,7 @@ class CLIPLoader: DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" def load_clip(self, clip_name, type="stable_diffusion", device="default"): - if type == "stable_cascade": - clip_type = comfy.sd.CLIPType.STABLE_CASCADE - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "stable_audio": - clip_type = comfy.sd.CLIPType.STABLE_AUDIO - elif type == "mochi": - clip_type = comfy.sd.CLIPType.MOCHI - elif type == "ltxv": - clip_type = comfy.sd.CLIPType.LTXV - elif type == "pixart": - clip_type = comfy.sd.CLIPType.PIXART - elif type == "cosmos": - clip_type = comfy.sd.CLIPType.COSMOS - elif type == "lumina2": - clip_type = comfy.sd.CLIPType.LUMINA2 - elif type == "wan": - clip_type = comfy.sd.CLIPType.WAN - else: - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} if device == "cpu": @@ -977,16 +958,10 @@ class DualCLIPLoader: DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" def load_clip(self, clip_name1, clip_name2, type, device="default"): + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - if type == "sdxl": - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "flux": - clip_type = comfy.sd.CLIPType.FLUX - elif type == "hunyuan_video": - clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO model_options = {} if device == "cpu": @@ -2282,7 +2257,8 @@ def init_builtin_extra_nodes(): "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", - "nodes_hidream.py" + "nodes_hidream.py", + "nodes_fresca.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 6eb1704db..e8fc9555d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.28" +version = "0.3.29" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 851db23bd..ff9f66a77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -comfyui-frontend-package==1.15.13 +comfyui-frontend-package==1.16.9 +comfyui-workflow-templates==0.1.1 torch torchsde torchvision diff --git a/server.py b/server.py index 62667ce18..0cc97b248 100644 --- a/server.py +++ b/server.py @@ -736,6 +736,12 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + workflow_templates_path = FrontendManager.templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ])
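
Note (illustrative only, not part of the patch): the sketch below re-derives, under stated assumptions, what the new FreSca node in comfy_extras/nodes_fresca.py does each sampling step — it splits the guidance (cond - uncond) into low and high spatial-frequency bands with an FFT and scales them separately before re-adding uncond. The helper name fourier_filter and the toy tensors are hypothetical; inside ComfyUI the equivalent work happens in the sampler_pre_cfg hook registered by FreSca.patch.

# Standalone approximation of the patch's Fourier_filter plus the pre-CFG rescale (assumed shape: B, C, H, W)
import torch
import torch.fft as fft

def fourier_filter(x, scale_low=1.0, scale_high=1.25, freq_cutoff=20):
    dtype = x.dtype
    x = x.to(torch.float32)
    # FFT over the two spatial dims, with low frequencies shifted to the center
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))
    # Start with the high-frequency scale everywhere, then overwrite the centered
    # window (2 * freq_cutoff indices per spatial dim) with the low-frequency scale
    mask = torch.full(x_freq.shape, scale_high, device=x.device)
    window = mask
    for d in range(x_freq.ndim - 2):
        dim = d + 2
        center = x_freq.shape[dim] // 2
        half = min(freq_cutoff, center)
        window = window.narrow(dim, center - half, half * 2)
    window[:] = scale_low
    x_freq = x_freq * mask
    # Back to the spatial domain and the original dtype
    return fft.ifftn(fft.ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real.to(dtype)

# Toy stand-ins for the two denoised conds the pre-CFG hook receives
cond = torch.randn(1, 4, 64, 64)
uncond = torch.randn(1, 4, 64, 64)
guidance = cond - uncond
filtered_cond = fourier_filter(guidance, scale_low=1.0, scale_high=1.25, freq_cutoff=20) + uncond
print(filtered_cond.shape)  # torch.Size([1, 4, 64, 64])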