From 206595f854c67538d5921d36326acbfeb69c5ac2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 9 Sep 2025 18:33:36 -0700 Subject: [PATCH 01/13] Change validate_inputs' output typehint to 'bool | str' and update docstrings (#9786) --- comfy_api/latest/_io.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e0ee943a7..f770109d5 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1190,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: - """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + def validate_inputs(cls, **kwargs) -> bool | str: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS. + + If the function returns a string, it will be used as the validation error message for the node. + """ raise NotImplementedError @classmethod def fingerprint_inputs(cls, **kwargs) -> Any: - """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED. + + If this function returns the same value as last run, the node will not be executed.""" raise NotImplementedError @classmethod From 5c33872e2f355e51adf212d5b5c83815b7fe77b0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 21:23:47 -0700 Subject: [PATCH 02/13] Fix issue on old torch. (#9791) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index ca1a83001..d48d9d642 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -426,7 +426,7 @@ class HunYuanDiTBlock(nn.Module): text_states_dim=1024, qk_norm=False, norm_layer=nn.LayerNorm, - qk_norm_layer=nn.RMSNorm, + qk_norm_layer=True, qkv_bias=True, skip_connection=True, timested_modulate=False, From 85e34643f874aec2ab9eed6a8499f2aefa81486e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:05:07 -0700 Subject: [PATCH 03/13] Support hunyuan image 2.1 regular model. (#9792) --- comfy/latent_formats.py | 5 + comfy/ldm/hunyuan_video/model.py | 102 +- comfy/ldm/hunyuan_video/vae.py | 136 ++ comfy/model_base.py | 24 + comfy/model_detection.py | 28 +- comfy/sd.py | 31 +- comfy/supported_models.py | 27 +- .../byt5_config_small_glyph.json | 22 + .../byt5_tokenizer/added_tokens.json | 127 ++ .../byt5_tokenizer/special_tokens_map.json | 150 +++ .../byt5_tokenizer/tokenizer_config.json | 1163 +++++++++++++++++ comfy/text_encoders/hunyuan_image.py | 100 ++ comfy_extras/nodes_hunyuan.py | 15 + nodes.py | 6 +- 14 files changed, 1906 insertions(+), 30 deletions(-) create mode 100644 comfy/ldm/hunyuan_video/vae.py create mode 100644 comfy/text_encoders/byt5_config_small_glyph.json create mode 100644 comfy/text_encoders/byt5_tokenizer/added_tokens.json create mode 100644 comfy/text_encoders/byt5_tokenizer/special_tokens_map.json create mode 100644 comfy/text_encoders/byt5_tokenizer/tokenizer_config.json create mode 100644 comfy/text_encoders/hunyuan_image.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 0d84994b0..859ae8421 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -533,6 +533,11 @@ class Wan22(Wan21): 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 ]).view(1, self.latent_channels, 1, 1, 1) +class HunyuanImage21(LatentFormat): + latent_channels = 64 + latent_dimensions = 2 + scale_factor = 0.75289 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index da1011596..ca289c5bd 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -40,6 +40,7 @@ class HunyuanVideoParams: patch_size: list qkv_bias: bool guidance_embed: bool + byt5: bool class SelfAttentionRef(nn.Module): @@ -161,6 +162,30 @@ class TokenRefiner(nn.Module): x = self.individual_token_refiner(x, c, mask) return x + +class ByT5Mapper(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None): + super().__init__() + self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device) + self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device) + self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device) + self.use_res = use_res + self.act_fn = nn.GELU() + + def forward(self, x): + if self.use_res: + res = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x2 = self.act_fn(x) + x2 = self.fc3(x2) + if self.use_res: + x2 = x2 + res + return x2 + class HunyuanVideo(nn.Module): """ Transformer model for flow matching on sequences. @@ -185,9 +210,13 @@ class HunyuanVideo(nn.Module): self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations) + self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) @@ -215,6 +244,18 @@ class HunyuanVideo(nn.Module): ] ) + if params.byt5: + self.byt5_in = ByT5Mapper( + in_dim=1472, + out_dim=2048, + hidden_dim=2048, + out_dim1=self.hidden_size, + use_res=False, + dtype=dtype, device=device, operations=operations + ) + else: + self.byt5_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) @@ -226,7 +267,8 @@ class HunyuanVideo(nn.Module): txt_ids: Tensor, txt_mask: Tensor, timesteps: Tensor, - y: Tensor, + y: Tensor = None, + txt_byt5=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, @@ -250,13 +292,17 @@ class HunyuanVideo(nn.Module): if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) - vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) - vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + if self.vector_in is not None: + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + else: + vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims_txt = [(0, None, 1)] else: - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) modulation_dims = None modulation_dims_txt = None @@ -269,6 +315,12 @@ class HunyuanVideo(nn.Module): txt = self.txt_in(txt, timesteps, txt_mask) + if self.byt5_in is not None and txt_byt5 is not None: + txt_byt5 = self.byt5_in(txt_byt5) + txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt = torch.cat((txt, txt_byt5), dim=1) + txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -328,12 +380,16 @@ class HunyuanVideo(nn.Module): img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) - shape = initial_shape[-3:] + shape = initial_shape[-len(self.patch_size):] for i in range(len(shape)): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) - img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + if img.ndim == 8: + img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + else: + img = img.permute(0, 3, 1, 4, 2, 5) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3]) return img def img_ids(self, x): @@ -348,16 +404,30 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def img_ids_2d(self, x): + bs, c, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size[0] // 2)) // patch_size[0]) + w_len = ((w + (patch_size[1] // 2)) // patch_size[1]) + img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - img_ids = self.img_ids(x) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + bs = x.shape[0] + if len(self.patch_size) == 3: + img_ids = self.img_ids(x) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + else: + img_ids = self.img_ids_2d(x) + txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py new file mode 100644 index 000000000..8d406089b --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae.py @@ -0,0 +1,136 @@ +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class PixelShuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim >> 2, 3, 1, 1) + self.ratio = (in_dim << 2) // out_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h >> 1, w >> 1 + y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2) + r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2) + return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2) + + +class PixelUnshuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim << 2, 3, 1, 1) + self.scale = (out_dim << 2) // in_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h << 1, w << 1 + y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + return y + r + + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) + + def forward(self, x): + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, h, w).mean(2) + + return self.conv_out(F.silu(self.norm_out(x))) + skip + + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.up.append(stage) + + self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) + + def forward(self, z): + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/model_base.py b/comfy/model_base.py index 39a3344bc..993ff65e6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1408,3 +1408,27 @@ class QwenImage(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class HunyuanImage21(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 75552ede9..dbcbe5f5a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} + in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] + out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = [1, 2, 2] - dit_config["out_channels"] = 16 - dit_config["vec_in_dim"] = 768 - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 + dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w.shape[2:]) + dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict: + dit_config["vec_in_dim"] = 768 + dit_config["axes_dim"] = [16, 56, 56] + else: + dit_config["vec_in_dim"] = None + dit_config["axes_dim"] = [64, 64] + + dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_w.shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["num_heads"] = in_w.shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 256 dit_config["qkv_bias"] = True + if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict: + dit_config["byt5"] = True + else: + dit_config["byt5"] = False + guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config diff --git a/comfy/sd.py b/comfy/sd.py index be5aa8dc8..9dd9a74d4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.hunyuan_video.vae import yaml import math import os @@ -48,6 +49,7 @@ import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image import comfy.model_patcher import comfy.lora @@ -328,6 +330,19 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif "decoder.conv_in.weight" in sd: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -785,6 +800,7 @@ class CLIPType(Enum): ACE = 16 OMNIGEN2 = 17 QWEN_IMAGE = 18 + HUNYUAN_IMAGE = 19 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -806,6 +822,7 @@ class TEModel(Enum): GEMMA_2_2B = 9 QWEN25_3B = 10 QWEN25_7B = 11 + BYT5_SMALL_GLYPH = 12 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -823,6 +840,9 @@ def detect_te_model(sd): if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: + weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] + if weight.shape[0] == 384: + return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: return TEModel.GEMMA_2_2B @@ -937,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer elif te_model == TEModel.QWEN25_7B: - clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + if clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + else: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -982,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 75dad277d..aa953b462 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -20,6 +20,7 @@ import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image from . import supported_models_base from . import latent_formats @@ -1295,7 +1296,31 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class HunyuanImage21(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vec_in_dim": None, + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] + sampling_settings = { + "shift": 5.0, + } + + latent_format = latent_formats.HunyuanImage21 + + memory_usage_factor = 7.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json new file mode 100644 index 000000000..0239c7164 --- /dev/null +++ b/comfy/text_encoders/byt5_config_small_glyph.json @@ -0,0 +1,22 @@ +{ + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "dense_act_fn": "gelu_pytorch_tanh", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 1510 +} diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json new file mode 100644 index 000000000..93c190b56 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json @@ -0,0 +1,127 @@ +{ + "": 259, + "": 359, + "": 360, + "": 361, + "": 362, + "": 363, + "": 364, + "": 365, + "": 366, + "": 367, + "": 368, + "": 269, + "": 369, + "": 370, + "": 371, + "": 372, + "": 373, + "": 374, + "": 375, + "": 376, + "": 377, + "": 378, + "": 270, + "": 379, + "": 380, + "": 381, + "": 382, + "": 383, + "": 271, + "": 272, + "": 273, + "": 274, + "": 275, + "": 276, + "": 277, + "": 278, + "": 260, + "": 279, + "": 280, + "": 281, + "": 282, + "": 283, + "": 284, + "": 285, + "": 286, + "": 287, + "": 288, + "": 261, + "": 289, + "": 290, + "": 291, + "": 292, + "": 293, + "": 294, + "": 295, + "": 296, + "": 297, + "": 298, + "": 262, + "": 299, + "": 300, + "": 301, + "": 302, + "": 303, + "": 304, + "": 305, + "": 306, + "": 307, + "": 308, + "": 263, + "": 309, + "": 310, + "": 311, + "": 312, + "": 313, + "": 314, + "": 315, + "": 316, + "": 317, + "": 318, + "": 264, + "": 319, + "": 320, + "": 321, + "": 322, + "": 323, + "": 324, + "": 325, + "": 326, + "": 327, + "": 328, + "": 265, + "": 329, + "": 330, + "": 331, + "": 332, + "": 333, + "": 334, + "": 335, + "": 336, + "": 337, + "": 338, + "": 266, + "": 339, + "": 340, + "": 341, + "": 342, + "": 343, + "": 344, + "": 345, + "": 346, + "": 347, + "": 348, + "": 267, + "": 349, + "": 350, + "": 351, + "": 352, + "": 353, + "": 354, + "": 355, + "": 356, + "": 357, + "": 358, + "": 268 +} diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json new file mode 100644 index 000000000..04fd58b5f --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json @@ -0,0 +1,150 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + } +} diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json new file mode 100644 index 000000000..5b1fe24c1 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json @@ -0,0 +1,1163 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "259": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "260": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "261": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "262": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "263": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "264": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "265": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "266": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "267": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "268": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "269": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "270": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "271": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "272": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "273": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "274": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "275": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "276": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "277": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "278": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "279": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "280": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "281": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "282": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "283": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "284": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "285": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "286": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "287": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "288": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "289": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "290": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "291": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "292": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "293": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "294": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "295": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "296": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "297": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "298": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "299": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "300": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "301": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "302": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "303": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "304": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "305": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "306": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "307": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "308": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "309": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "310": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "311": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "312": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "313": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "314": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "315": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "316": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "317": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "318": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "319": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "320": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "321": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "322": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "323": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "324": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "325": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "326": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "327": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "328": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "329": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "330": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "331": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "332": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "333": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "334": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "335": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "336": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "337": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "338": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "339": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "340": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "341": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "342": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "343": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "344": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "345": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "346": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "347": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "348": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "349": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "350": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "351": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "352": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "353": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "354": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "355": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "356": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "357": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "358": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "359": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "360": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "361": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "362": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "363": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "364": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "365": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "366": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "367": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "368": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "369": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "370": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "371": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "372": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "373": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "374": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "375": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "376": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "377": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "378": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "379": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "380": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "381": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "382": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "383": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_ids": 0, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "ByT5Tokenizer", + "unk_token": "" +} diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py new file mode 100644 index 000000000..be396cae7 --- /dev/null +++ b/comfy/text_encoders/hunyuan_image.py @@ -0,0 +1,100 @@ +from comfy import sd1_clip +import comfy.text_encoders.llama +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from transformers import ByT5Tokenizer +import os +import re + +class ByT5SmallTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class HunyuanImageTokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # self.llama_template_images = "{}" + self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + + # ByT5 processing for HunyuanImage + text_prompt_texts = [] + pattern_quote_single = r'\'(.*?)\'' + pattern_quote_double = r'\"(.*?)\"' + pattern_quote_chinese_single = r'‘(.*?)’' + pattern_quote_chinese_double = r'“(.*?)”' + + matches_quote_single = re.findall(pattern_quote_single, text) + matches_quote_double = re.findall(pattern_quote_double, text) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) + + text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if len(text_prompt_texts) > 0: + out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs) + return out + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ByT5SmallModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) + + +class HunyuanImageTEModel(QwenImageTEModel): + def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + if byt5: + self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options) + else: + self.byt5_small = None + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs) + if self.byt5_small is not None and "byt5" in token_weight_pairs: + out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) + extra["conditioning_byt5small"] = out[0] + return cond, p, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + if self.byt5_small is not None: + self.byt5_small.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + if self.byt5_small is not None: + self.byt5_small.reset_clip_options() + + def load_sd(self, sd): + if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd: + return self.byt5_small.load_sd(sd) + else: + return super().load_sd(sd) + +def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): + class QwenImageTEModel_(HunyuanImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d7278e7a7..ce031ceb2 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -113,6 +113,20 @@ class HunyuanImageToVideo: out_latent["samples"] = latent return (positive, out_latent) +class EmptyHunyuanImageLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent" + + def generate(self, width, height, batch_size=1): + latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return ({"samples":latent}, ) NODE_CLASS_MAPPINGS = { @@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = { "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, + "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, } diff --git a/nodes.py b/nodes.py index 6c2f9dd14..2befb4b75 100644 --- a/nodes.py +++ b/nodes.py @@ -925,7 +925,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -953,7 +953,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -963,7 +963,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) From 70fc0425b36515926c6414aee9f2269b27880cc2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 10 Sep 2025 14:09:16 +0800 Subject: [PATCH 04/13] Update template to 0.1.76 (#9793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3008a5dc3..ea1931d78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.75 +comfyui-workflow-templates==0.1.76 comfyui-embedded-docs==0.2.6 torch torchsde From 543888d3d84a6ec4c4273838d5179845840e3226 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:15:34 -0700 Subject: [PATCH 05/13] Fix lowvram issue with hunyuan image vae. (#9794) --- comfy/ldm/hunyuan_video/vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py index 8d406089b..40c12b183 100644 --- a/comfy/ldm/hunyuan_video/vae.py +++ b/comfy/ldm/hunyuan_video/vae.py @@ -65,7 +65,7 @@ class Encoder(nn.Module): self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) - self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) def forward(self, x): @@ -120,7 +120,7 @@ class Decoder(nn.Module): ch = nxt self.up.append(stage) - self.norm_out = nn.GroupNorm(32, ch, 1e-6, True) + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) def forward(self, z): From de44b95db6c7ef107f26e7edf30748b608afaa48 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 10 Sep 2025 12:06:47 +0300 Subject: [PATCH 06/13] add StabilityAudio API nodes (#9749) --- comfy_api_nodes/apinode_utils.py | 65 +++++ comfy_api_nodes/apis/stability_api.py | 22 ++ comfy_api_nodes/nodes_stability.py | 312 ++++++++++++++++++++++- comfy_api_nodes/util/validation_utils.py | 20 +- 4 files changed, 415 insertions(+), 4 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index f953f86df..37438f835 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi( return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(io.BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + def audio_to_base64_string( audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" ) -> str: diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py index 47c87daec..718360187 100644 --- a/comfy_api_nodes/apis/stability_api.py +++ b/comfy_api_nodes/apis/stability_api.py @@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel): class StabilityAsyncResponse(BaseModel): id: Optional[str] = Field(None) + + +class StabilityTextToAudioRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + duration: int = Field(190, ge=1, le=190) + seed: int = Field(0, ge=0, le=4294967294) + steps: int = Field(8, ge=4, le=8) + output_format: str = Field("wav") + + +class StabilityAudioToAudioRequest(StabilityTextToAudioRequest): + strength: float = Field(0.01, ge=0.01, le=1.0) + + +class StabilityAudioInpaintRequest(StabilityTextToAudioRequest): + mask_start: int = Field(30, ge=0, le=190) + mask_end: int = Field(190, ge=0, le=190) + + +class StabilityAudioResponse(BaseModel): + audio: Optional[str] = Field(None) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index e05cb6bb2..5ba5ed986 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -2,7 +2,7 @@ from inspect import cleandoc from typing import Optional from typing_extensions import override -from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api.latest import ComfyExtension, Input, io as comfy_io from comfy_api_nodes.apis.stability_api import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, @@ -15,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import ( Stability_SD3_5_Model, Stability_SD3_5_GenerationMode, get_stability_style_presets, + StabilityTextToAudioRequest, + StabilityAudioToAudioRequest, + StabilityAudioInpaintRequest, + StabilityAudioResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -27,7 +31,10 @@ from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, tensor_to_bytesio, validate_string, + audio_bytes_to_audio_input, + audio_input_to_mp3, ) +from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 @@ -649,6 +656,306 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode): return comfy_io.NodeOutput(returned_image) +class StabilityTextToAudio(comfy_io.ComfyNode): + """Generates high-quality music and sound effects from text descriptions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityTextToAudio", + display_name="Stability AI Text To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", + method=HttpMethod.POST, + request_model=StabilityTextToAudioRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioToAudio(comfy_io.ComfyNode): + """Transforms existing audio samples into new high-quality compositions using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioToAudio", + display_name="Stability AI Audio To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Float.Input( + "strength", + default=1, + min=0.01, + max=1.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Parameter controls how much influence the audio parameter has on the generated audio.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + validate_audio_duration(audio, 6, 190) + payload = StabilityAudioToAudioRequest( + prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", + method=HttpMethod.POST, + request_model=StabilityAudioToAudioRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioInpaint(comfy_io.ComfyNode): + """Transforms part of existing audio sample using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioInpaint", + display_name="Stability AI Audio Inpaint", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Int.Input( + "mask_start", + default=30, + min=0, + max=190, + step=1, + optional=True, + ), + comfy_io.Int.Input( + "mask_end", + default=190, + min=0, + max=190, + step=1, + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + audio: Input.Audio, + duration: int, + seed: int, + steps: int, + mask_start: int, + mask_end: int, + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + if mask_end <= mask_start: + raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})") + validate_audio_duration(audio, 6, 190) + + payload = StabilityAudioInpaintRequest( + prompt=prompt, + model=model, + duration=duration, + seed=seed, + steps=steps, + mask_start=mask_start, + mask_end=mask_end, + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", + method=HttpMethod.POST, + request_model=StabilityAudioInpaintRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + class StabilityExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: @@ -658,6 +965,9 @@ class StabilityExtension(ComfyExtension): StabilityUpscaleConservativeNode, StabilityUpscaleCreativeNode, StabilityUpscaleFastNode, + StabilityTextToAudio, + StabilityAudioToAudio, + StabilityAudioInpaint, ] diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 606b794bf..ca913e9b3 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,7 +2,7 @@ import logging from typing import Optional import torch -from comfy_api.input.video_types import VideoInput +from comfy_api.latest import Input def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: @@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness( def validate_video_dimensions( - video: VideoInput, + video: Input.Video, min_width: Optional[int] = None, max_width: Optional[int] = None, min_height: Optional[int] = None, @@ -126,7 +126,7 @@ def validate_video_dimensions( def validate_video_duration( - video: VideoInput, + video: Input.Video, min_duration: Optional[float] = None, max_duration: Optional[float] = None, ): @@ -151,3 +151,17 @@ def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 return len(images) + + +def validate_audio_duration( + audio: Input.Audio, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, +) -> None: + sr = int(audio["sample_rate"]) + dur = int(audio["waveform"].shape[-1]) / sr + eps = 1.0 / sr + if min_duration is not None and dur + eps < min_duration: + raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") + if max_duration is not None and dur - eps > max_duration: + raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") From 8d7c930246bd33c32eb957b01ab0d364af6e81c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Sep 2025 10:51:02 -0400 Subject: [PATCH 07/13] ComfyUI version v0.3.58 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4cc3c8647..37361bd75 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.57" +__version__ = "0.3.58" diff --git a/pyproject.toml b/pyproject.toml index d75cd04a2..f02ab9126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.57" +version = "0.3.58" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 9b0553809cbac084aac0576892aca3e448eb07c7 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 11 Sep 2025 00:13:18 +0300 Subject: [PATCH 08/13] add new ByteDanceSeedream (4.0) node (#9802) --- comfy_api_nodes/nodes_bytedance.py | 208 ++++++++++++++++++++++++++++- 1 file changed, 207 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 064df2d10..369a3a4fe 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -77,6 +77,22 @@ class Image2ImageTaskCreationRequest(BaseModel): watermark: Optional[bool] = Field(True) +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field("seedream-4-0-250828") + prompt: str = Field(...) + response_format: str = Field("url") + image: Optional[list[str]] = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + class ImageTaskCreationResponse(BaseModel): model: str = Field(...) created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") @@ -143,6 +159,19 @@ RECOMMENDED_PRESETS = [ ("Custom", None, None), ] +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + # The time in this dictionary are given for 10 seconds duration. VIDEO_TASKS_EXECUTION_TIME = { "seedance-1-0-lite-t2v-250428": { @@ -348,7 +377,7 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): return comfy_io.Schema( node_id="ByteDanceImageEditNode", display_name="ByteDance Image Edit", - category="api node/video/ByteDance", + category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ comfy_io.Combo.Input( @@ -451,6 +480,182 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) +class ByteDanceSeedreamNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceSeedreamNode", + display_name="ByteDance Seedream 4", + category="api node/image/ByteDance", + description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["seedream-4-0-250828"], + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for creating or editing an image.", + ), + comfy_io.Image.Input( + "image", + tooltip="Input image(s) for image-to-image generation. " + "List of 1-10 images for single or multi-reference generation.", + optional=True, + ), + comfy_io.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS_SEEDREAM_4], + tooltip="Pick a recommended size. Select Custom to use the width and height below.", + ), + comfy_io.Int.Input( + "width", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Int.Input( + "height", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Combo.Input( + "sequential_image_generation", + options=["disabled", "auto"], + tooltip="Group image generation mode. " + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", + optional=True, + ), + comfy_io.Int.Input( + "max_images", + default=1, + min=1, + max=15, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " + "Total images (input + generated) cannot exceed 15.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: torch.Tensor = None, + size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], + width: int = 2048, + height: int = 2048, + sequential_image_generation: str = "disabled", + max_images: int = 1, + seed: int = 0, + watermark: bool = True, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): + raise ValueError( + f"Custom size out of range: {w}x{h}. " + "Both width and height must be between 1024 and 4096 pixels." + ) + n_input_images = get_number_of_images(image) if image is not None else 0 + if n_input_images > 10: + raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + if sequential_image_generation == "auto" and n_input_images + max_images > 15: + raise ValueError( + "The maximum number of generated images plus the number of reference images cannot exceed 15." + ) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + reference_images_urls = [] + if n_input_images: + for i in image: + validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) + reference_images_urls = (await upload_images_to_comfyapi( + image, + max_images=n_input_images, + mime_type="image/png", + auth_kwargs=auth_kwargs, + )) + payload = Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, + ) + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_IMAGE_ENDPOINT, + method=HttpMethod.POST, + request_model=Seedream4TaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + + if len(response.data) == 1: + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + return comfy_io.NodeOutput( + torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data]) + ) + + class ByteDanceTextToVideoNode(comfy_io.ComfyNode): @classmethod @@ -1001,6 +1206,7 @@ class ByteDanceExtension(ComfyExtension): return [ ByteDanceImageNode, ByteDanceImageEditNode, + ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, ByteDanceFirstLastFrameNode, From df34f1549a431c85a6326e87075a206228697cde Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 11 Sep 2025 05:16:41 +0800 Subject: [PATCH 09/13] Update template to 0.1.78 (#9806) * Update template to 0.1.77 * Update template to 0.1.78 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ea1931d78..d31df0fec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.76 +comfyui-workflow-templates==0.1.78 comfyui-embedded-docs==0.2.6 torch torchsde From 72212fef660bcd7d9702fa52011d089c027a64d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Sep 2025 17:25:41 -0400 Subject: [PATCH 10/13] ComfyUI version 0.3.59 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 37361bd75..ee58205f5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.58" +__version__ = "0.3.59" diff --git a/pyproject.toml b/pyproject.toml index f02ab9126..a7fc1a5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.58" +version = "0.3.59" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From e01e99d075852b94e93f27ea64ab862a49a7d2cc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 10 Sep 2025 20:17:34 -0700 Subject: [PATCH 11/13] Support hunyuan image distilled model. (#9807) --- comfy/ldm/hunyuan_video/model.py | 14 ++++++++++++++ comfy/model_detection.py | 12 ++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index ca289c5bd..7732182a4 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -41,6 +41,7 @@ class HunyuanVideoParams: qkv_bias: bool guidance_embed: bool byt5: bool + meanflow: bool class SelfAttentionRef(nn.Module): @@ -256,6 +257,11 @@ class HunyuanVideo(nn.Module): else: self.byt5_in = None + if params.meanflow: + self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.time_r_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) @@ -282,6 +288,14 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + if self.time_r_in is not None: + w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved + if len(w) > 0: + timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] + timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) + vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) + vec = (vec + vec_r) / 2 + if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) ref_latent = self.img_in(ref_latent) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index dbcbe5f5a..fe983cede 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels dit_config["patch_size"] = list(in_w.shape[2:]) dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) - if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict: + if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): dit_config["vec_in_dim"] = 768 - dit_config["axes_dim"] = [16, 56, 56] else: dit_config["vec_in_dim"] = None + + if len(dit_config["patch_size"]) == 2: dit_config["axes_dim"] = [64, 64] + else: + dit_config["axes_dim"] = [16, 56, 56] + + if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["meanflow"] = True + else: + dit_config["meanflow"] = False dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] dit_config["hidden_size"] = in_w.shape[0] From df6850fae8a75126cb7a645e38d58cebcfd51096 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 12 Sep 2025 02:59:26 +0800 Subject: [PATCH 12/13] Update template to 0.1.81 (#9811) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d31df0fec..0e21967ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.78 +comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch torchsde From 18de0b28305fd8bf002d74e91c0630bd76b01d6b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:33:02 -0700 Subject: [PATCH 13/13] Fast preview for hunyuan image. (#9814) --- comfy/latent_formats.py | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 859ae8421..f975b5e11 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -538,6 +538,74 @@ class HunyuanImage21(LatentFormat): latent_dimensions = 2 scale_factor = 0.75289 + latent_rgb_factors = [ + [-0.0154, -0.0397, -0.0521], + [ 0.0005, 0.0093, 0.0006], + [-0.0805, -0.0773, -0.0586], + [-0.0494, -0.0487, -0.0498], + [-0.0212, -0.0076, -0.0261], + [-0.0179, -0.0417, -0.0505], + [ 0.0158, 0.0310, 0.0239], + [ 0.0409, 0.0516, 0.0201], + [ 0.0350, 0.0553, 0.0036], + [-0.0447, -0.0327, -0.0479], + [-0.0038, -0.0221, -0.0365], + [-0.0423, -0.0718, -0.0654], + [ 0.0039, 0.0368, 0.0104], + [ 0.0655, 0.0217, 0.0122], + [ 0.0490, 0.1638, 0.2053], + [ 0.0932, 0.0829, 0.0650], + [-0.0186, -0.0209, -0.0135], + [-0.0080, -0.0076, -0.0148], + [-0.0284, -0.0201, 0.0011], + [-0.0642, -0.0294, -0.0777], + [-0.0035, 0.0076, -0.0140], + [ 0.0519, 0.0731, 0.0887], + [-0.0102, 0.0095, 0.0704], + [ 0.0068, 0.0218, -0.0023], + [-0.0726, -0.0486, -0.0519], + [ 0.0260, 0.0295, 0.0263], + [ 0.0250, 0.0333, 0.0341], + [ 0.0168, -0.0120, -0.0174], + [ 0.0226, 0.1037, 0.0114], + [ 0.2577, 0.1906, 0.1604], + [-0.0646, -0.0137, -0.0018], + [-0.0112, 0.0309, 0.0358], + [-0.0347, 0.0146, -0.0481], + [ 0.0234, 0.0179, 0.0201], + [ 0.0157, 0.0313, 0.0225], + [ 0.0423, 0.0675, 0.0524], + [-0.0031, 0.0027, -0.0255], + [ 0.0447, 0.0555, 0.0330], + [-0.0152, 0.0103, 0.0299], + [-0.0755, -0.0489, -0.0635], + [ 0.0853, 0.0788, 0.1017], + [-0.0272, -0.0294, -0.0471], + [ 0.0440, 0.0400, -0.0137], + [ 0.0335, 0.0317, -0.0036], + [-0.0344, -0.0621, -0.0984], + [-0.0127, -0.0630, -0.0620], + [-0.0648, 0.0360, 0.0924], + [-0.0781, -0.0801, -0.0409], + [ 0.0363, 0.0613, 0.0499], + [ 0.0238, 0.0034, 0.0041], + [-0.0135, 0.0258, 0.0310], + [ 0.0614, 0.1086, 0.0589], + [ 0.0428, 0.0350, 0.0205], + [ 0.0153, 0.0173, -0.0018], + [-0.0288, -0.0455, -0.0091], + [ 0.0344, 0.0109, -0.0157], + [-0.0205, -0.0247, -0.0187], + [ 0.0487, 0.0126, 0.0064], + [-0.0220, -0.0013, 0.0074], + [-0.0203, -0.0094, -0.0048], + [-0.0719, 0.0429, -0.0442], + [ 0.1042, 0.0497, 0.0356], + [-0.0659, -0.0578, -0.0280], + [-0.0060, -0.0322, -0.0234]] + + latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1