Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-04-18 15:39:34 -05:00
commit b5cccf1325
14 changed files with 298 additions and 98 deletions

View File

@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)

View File

@ -184,6 +184,27 @@ comfyui-frontend-package is not installed.
)
sys.exit(-1)
@classmethod
def templates_path(cls) -> str:
try:
import comfyui_workflow_templates
return str(
importlib.resources.files(comfyui_workflow_templates) / "templates"
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""

View File

@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
def forward(self, x, context, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
def forward(self, x, context, context_img_len):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
context_img = context[:, :context_img_len]
context = context[:, context_img_len:]
# compute query, key, value
q = self.norm_q(self.q(x))
@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module):
e,
freqs,
context,
context_img_len=257,
):
r"""
Args:
@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module):
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
@ -250,7 +251,7 @@ class Head(nn.Module):
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
@ -258,7 +259,15 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
if flf_pos_embed_token_number is not None:
self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
else:
self.emb_pos = None
def forward(self, image_embeds):
if self.emb_pos is not None:
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -284,6 +293,7 @@ class WanModel(torch.nn.Module):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
device=None,
dtype=None,
@ -373,7 +383,7 @@ class WanModel(torch.nn.Module):
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
else:
self.img_emb = None
@ -420,9 +430,12 @@ class WanModel(torch.nn.Module):
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@ -430,12 +443,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)

View File

@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

View File

@ -1,6 +1,9 @@
import nodes
from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception):
pass
@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None):
def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)

View File

@ -0,0 +1,100 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
"""
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
Parameters:
x: Input tensor of shape (B, C, H, W)
scale_low: Scaling factor for low-frequency components (default: 1.0)
scale_high: Scaling factor for high-frequency components (default: 1.5)
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
Returns:
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
"""
# Preserve input dtype and device
dtype, device = x.dtype, x.device
# Convert to float32 for FFT computations
x = x.to(torch.float32)
# 1) Apply FFT and shift low frequencies to center
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
# Initialize mask with high-frequency scaling factor
mask = torch.ones(x_freq.shape, device=device) * scale_high
m = mask
for d in range(len(x_freq.shape) - 2):
dim = d + 2
cc = x_freq.shape[dim] // 2
f_c = min(freq_cutoff, cc)
m = m.narrow(dim, cc - f_c, f_c * 2)
# Apply low-frequency scaling factor to center region
m[:] = scale_low
# 3) Apply frequency-specific scaling
x_freq = x_freq * mask
# 4) Convert back to spatial domain
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
# 5) Restore original dtype
x_filtered = x_filtered.to(dtype)
return x_filtered
class FreSca:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for low-frequency components"}),
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for high-frequency components"}),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
guidance = cond - uncond
filtered_guidance = Fourier_filter(
guidance,
scale_low=scale_low,
scale_high=scale_high,
freq_cutoff=freq_cutoff,
)
filtered_cond = filtered_guidance + uncond
return [filtered_cond, uncond]
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"FreSca": FreSca,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca",
}

View File

@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
FUNCTION = "process"
EXPERIMENTAL = True
@ -41,7 +41,7 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
class Load3DAnimation():
@classmethod
@ -59,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
FUNCTION = "process"
EXPERIMENTAL = True
@ -77,13 +77,16 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image
return output_image, output_mask, model_file, normal_image, image['camera_info']
class Preview3D():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@ -95,13 +98,22 @@ class Preview3D():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
@ -113,7 +125,13 @@ class Preview3DAnimation():
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
return {"ui": {"model_file": [model_file]}, "result": ()}
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
}
}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,

View File

@ -4,6 +4,7 @@ import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision
class WanImageToVideo:
@ -99,6 +100,72 @@ class WanFunControlToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image
if clip_vision_end_image is not None:
if clip_vision_output is not None:
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
clip_vision_output = comfy.clip_vision.Output()
clip_vision_output.penultimate_hidden_states = states
else:
clip_vision_output = clip_vision_end_image
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.28"
__version__ = "0.3.29"

View File

@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated):
continue
val = inputs[x]
info = (type_input, extra_info)
info = (input_type, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"]
inputs[x] = val
if type_input == "INT":
if input_type == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
if input_type == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
if input_type == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
if input_type == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if isinstance(type_input, list):
if val not in type_input:
if isinstance(input_type, list):
combo_options = input_type
if val not in combo_options:
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
if len(combo_options) > 20:
list_info = f"(list of length {len(combo_options)})"
input_config = None
else:
list_info = str(type_input)
list_info = str(combo_options)
error = {
"type": "value_not_in_list",

View File

@ -930,26 +930,7 @@ class CLIPLoader:
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
elif type == "cosmos":
clip_type = comfy.sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = comfy.sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
model_options = {}
if device == "cpu":
@ -977,16 +958,10 @@ class DualCLIPLoader:
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
model_options = {}
if device == "cpu":
@ -2282,7 +2257,8 @@ def init_builtin_extra_nodes():
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
"nodes_hidream.py"
"nodes_hidream.py",
"nodes_fresca.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.28"
version = "0.3.29"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,4 +1,5 @@
comfyui-frontend-package==1.15.13
comfyui-frontend-package==1.16.9
comfyui-workflow-templates==0.1.1
torch
torchsde
torchvision

View File

@ -736,6 +736,12 @@ class PromptServer():
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir)])
workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])