ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)
Support hunyuan image 2.1 regular model. (#9792)
@@ -533,6 +533,11 @@ class Wan22(Wan21):
             0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
         ]).view(1, self.latent_channels, 1, 1, 1)
 
+class HunyuanImage21(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 2
+    scale_factor = 0.75289
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
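Illustration (not part of the commit): ComfyUI's LatentFormat base class applies scale_factor in process_in/process_out, so the new format's 0.75289 scales the 64-channel latents on the way into the sampler and back out. A minimal sketch:

import torch

scale_factor = 0.75289               # HunyuanImage21.scale_factor from the hunk above
latent = torch.randn(1, 64, 64, 64)  # latent_channels=64, latent_dimensions=2 -> NCHW
model_input = latent * scale_factor        # what process_in() does
restored = model_input / scale_factor      # what process_out() does
print(torch.allclose(latent, restored))    # True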
@@ -40,6 +40,7 @@ class HunyuanVideoParams:
     patch_size: list
     qkv_bias: bool
     guidance_embed: bool
+    byt5: bool
 
 
 class SelfAttentionRef(nn.Module):
@@ -161,6 +162,30 @@ class TokenRefiner(nn.Module):
         x = self.individual_token_refiner(x, c, mask)
         return x
 
 
+class ByT5Mapper(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+        self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+        self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+        self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+        self.use_res = use_res
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        if self.use_res:
+            res = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x2 = self.act_fn(x)
+        x2 = self.fc3(x2)
+        if self.use_res:
+            x2 = x2 + res
+        return x2
+
+
 class HunyuanVideo(nn.Module):
     """
     Transformer model for flow matching on sequences.
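A self-contained sketch (not from the commit) of what ByT5Mapper computes with use_res=False, using plain torch.nn layers instead of ComfyUI's operations wrappers; the 3584 output width is only an assumption standing in for self.hidden_size:

import torch
import torch.nn as nn

class ByT5MapperSketch(nn.Module):
    def __init__(self, in_dim=1472, hidden_dim=2048, out_dim=2048, out_dim1=3584):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.act_fn = nn.GELU()

    def forward(self, x):
        x = self.act_fn(self.fc1(self.layernorm(x)))  # 1472 -> 2048
        x = self.act_fn(self.fc2(x))                  # 2048 -> 2048
        return self.fc3(x)                            # 2048 -> assumed DiT hidden size

byt5_tokens = torch.randn(1, 77, 1472)       # dummy ByT5-small glyph features
print(ByT5MapperSketch()(byt5_tokens).shape)  # torch.Size([1, 77, 3584])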
@@ -185,9 +210,13 @@ class HunyuanVideo(nn.Module):
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 
-        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
+
         self.guidance_in = (
             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
@@ -215,6 +244,18 @@ class HunyuanVideo(nn.Module):
             ]
         )
 
+        if params.byt5:
+            self.byt5_in = ByT5Mapper(
+                in_dim=1472,
+                out_dim=2048,
+                hidden_dim=2048,
+                out_dim1=self.hidden_size,
+                use_res=False,
+                dtype=dtype, device=device, operations=operations
+            )
+        else:
+            self.byt5_in = None
+
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
 
@@ -226,7 +267,8 @@ class HunyuanVideo(nn.Module):
         txt_ids: Tensor,
         txt_mask: Tensor,
         timesteps: Tensor,
-        y: Tensor,
+        y: Tensor = None,
+        txt_byt5=None,
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -250,12 +292,16 @@ class HunyuanVideo(nn.Module):
 
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
-            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
-            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            if self.vector_in is not None:
+                vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+                vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            else:
+                vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
             frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
             modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
             modulation_dims_txt = [(0, None, 1)]
         else:
-            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            if self.vector_in is not None:
+                vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
             modulation_dims = None
             modulation_dims_txt = None
@@ -269,6 +315,12 @@ class HunyuanVideo(nn.Module):
 
         txt = self.txt_in(txt, timesteps, txt_mask)
 
+        if self.byt5_in is not None and txt_byt5 is not None:
+            txt_byt5 = self.byt5_in(txt_byt5)
+            txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt = torch.cat((txt, txt_byt5), dim=1)
+            txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
 
@@ -328,12 +380,16 @@ class HunyuanVideo(nn.Module):
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
 
-        shape = initial_shape[-3:]
+        shape = initial_shape[-len(self.patch_size):]
         for i in range(len(shape)):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
-        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        if img.ndim == 8:
+            img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+        else:
+            img = img.permute(0, 3, 1, 4, 2, 5)
+            img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
         return img
 
     def img_ids(self, x):
@@ -348,16 +404,30 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
 
+    def img_ids_2d(self, x):
+        bs, c, h, w = x.shape
+        patch_size = self.patch_size
+        h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+        w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+        img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
-        bs, c, t, h, w = x.shape
-        img_ids = self.img_ids(x)
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs = x.shape[0]
+        if len(self.patch_size) == 3:
+            img_ids = self.img_ids(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        else:
+            img_ids = self.img_ids_2d(x)
+            txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
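A standalone sketch (not from the commit) of the coordinate grid img_ids_2d builds for the 2D RoPE path, assuming patch_size=[2, 2] and a 64x64 latent (a 2048x2048 image at the 32x VAE factor):

import torch

def img_ids_2d_sketch(x, patch_size=(2, 2)):
    bs, c, h, w = x.shape
    h_len = (h + patch_size[0] // 2) // patch_size[0]
    w_len = (w + patch_size[1] // 2) // patch_size[1]
    ids = torch.zeros((h_len, w_len, 2))
    ids[:, :, 0] += torch.arange(h_len, dtype=torch.float32).unsqueeze(1)  # row index per patch
    ids[:, :, 1] += torch.arange(w_len, dtype=torch.float32).unsqueeze(0)  # column index per patch
    return ids.reshape(1, -1, 2).repeat(bs, 1, 1)

latent = torch.zeros(1, 64, 64, 64)        # 64-channel latent at 1/32 resolution
print(img_ids_2d_sketch(latent).shape)     # torch.Size([1, 1024, 2])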
comfy/ldm/hunyuan_video/vae.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops
ops = comfy.ops.disable_weight_init


class PixelShuffle2D(nn.Module):
    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
        self.ratio = (in_dim << 2) // out_dim

    def forward(self, x):
        b, c, h, w = x.shape
        h2, w2 = h >> 1, w >> 1
        y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
        r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
        return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)


class PixelUnshuffle2D(nn.Module):
    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
        self.scale = (out_dim << 2) // in_dim

    def forward(self, x):
        b, c, h, w = x.shape
        h2, w2 = h << 1, w << 1
        y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
        r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
        return y + r


class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, downsample_match_channel=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        depth = (ffactor_spatial >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=ops.Conv2d)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
                stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        self.norm_out = nn.GroupNorm(32, ch, 1e-6, True)
        self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)

    def forward(self, x):
        x = self.conv_in(x)

        for stage in self.down:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'downsample'):
                x = stage.downsample(x)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        b, c, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, h, w).mean(2)

        return self.conv_out(F.silu(self.norm_out(x))) + skip


class Decoder(nn.Module):
    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, upsample_match_channel=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        ch = block_out_channels[0]
        self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        self.up = nn.ModuleList()
        depth = (ffactor_spatial >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=ops.Conv2d)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
                stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.up.append(stage)

        self.norm_out = nn.GroupNorm(32, ch, 1e-6, True)
        self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)

    def forward(self, z):
        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        for stage in self.up:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'upsample'):
                x = stage.upsample(x)

        return self.conv_out(F.silu(self.norm_out(x)))
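Shape check (not from the commit; in_dim=128 and out_dim=256 are arbitrary) for the PixelShuffle2D downsampler above: the conv path and the channel-averaged space-to-depth residual both land at (out_dim, H/2, W/2).

import torch
import torch.nn as nn

in_dim, out_dim = 128, 256
conv = nn.Conv2d(in_dim, out_dim >> 2, 3, 1, 1)   # out_dim // 4 channels before space-to-depth
ratio = (in_dim << 2) // out_dim                  # 4 * in_dim / out_dim = 2

x = torch.randn(1, in_dim, 64, 64)
b, c, h, w = x.shape
h2, w2 = h >> 1, w >> 1
y = conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
out = y + r.view(b, y.shape[1], ratio, h2, w2).mean(2)
print(out.shape)  # torch.Size([1, 256, 32, 32])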
@@ -1408,3 +1408,27 @@ class QwenImage(BaseModel):
         if ref_latents is not None:
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
         return out
+
+
+class HunyuanImage21(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            if torch.numel(attention_mask) != attention_mask.sum():
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+        if conditioning_byt5small is not None:
+            out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+        guidance = kwargs.get("guidance", 6.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        return out
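For orientation, a hedged sketch (not from the commit; plain tensors stand in for the comfy.conds wrappers) of how the new extra_conds hook routes conditioning: conditioning_byt5small becomes the model's txt_byt5 input, and guidance defaults to 6.0 when not supplied.

import torch

def extra_conds_sketch(**kwargs):
    # mirrors the branches above, minus the CONDRegular wrapping
    out = {}
    if kwargs.get("cross_attn") is not None:
        out["c_crossattn"] = kwargs["cross_attn"]
    if kwargs.get("conditioning_byt5small") is not None:
        out["txt_byt5"] = kwargs["conditioning_byt5small"]
    guidance = kwargs.get("guidance", 6.0)
    if guidance is not None:
        out["guidance"] = torch.FloatTensor([guidance])
    return out

conds = extra_conds_sketch(cross_attn=torch.zeros(1, 77, 4096),
                           conditioning_byt5small=torch.zeros(1, 32, 1472))
print(sorted(conds.keys()))  # ['c_crossattn', 'guidance', 'txt_byt5']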
@@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
     if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
         dit_config = {}
+        in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
+        out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
         dit_config["image_model"] = "hunyuan_video"
-        dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
-        dit_config["patch_size"] = [1, 2, 2]
-        dit_config["out_channels"] = 16
-        dit_config["vec_in_dim"] = 768
-        dit_config["context_in_dim"] = 4096
-        dit_config["hidden_size"] = 3072
+        dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
+        dit_config["patch_size"] = list(in_w.shape[2:])
+        dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+        if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
+            dit_config["vec_in_dim"] = 768
+            dit_config["axes_dim"] = [16, 56, 56]
+        else:
+            dit_config["vec_in_dim"] = None
+            dit_config["axes_dim"] = [64, 64]
+
+        dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
+        dit_config["hidden_size"] = in_w.shape[0]
         dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = 24
+        dit_config["num_heads"] = in_w.shape[0] // 128
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 256
         dit_config["qkv_bias"] = True
+        if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
+            dit_config["byt5"] = True
+        else:
+            dit_config["byt5"] = False
+
         guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
         dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config
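Illustration (hypothetical tensor shapes, not from the commit) of the new shape-driven detection: a 2D checkpoint whose img_in.proj weight has an assumed hidden size of 3584, 64 input channels, and a 2x2 patch kernel produces this config.

import math
import torch

in_w = torch.zeros(3584, 64, 2, 2)       # assumed img_in.proj.weight of a 2D checkpoint
out_w = torch.zeros(64 * 2 * 2, 3584)    # assumed final_layer.linear.weight

patch_size = list(in_w.shape[2:])                        # [2, 2]
out_channels = out_w.shape[0] // math.prod(patch_size)   # 64
num_heads = in_w.shape[0] // 128                         # 28
print(patch_size, out_channels, num_heads)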
comfy/sd.py
@@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
 import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.hunyuan_video.vae
 import yaml
 import math
 import os
@@ -48,6 +49,7 @@ import comfy.text_encoders.hidream
 import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
 
 import comfy.model_patcher
 import comfy.lora
@@ -328,6 +330,19 @@ class VAE:
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
+        elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
+            ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+            self.downscale_ratio = 32
+            self.upscale_ratio = 32
+            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+            self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                        encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+                                                        decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+            self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+
         elif "decoder.conv_in.weight" in sd:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
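Quick arithmetic (not from the commit) behind the new branch: the ddconfig above passes ffactor_spatial=32, and the Encoder/Decoder derive their number of down/upsampling stages from it, which is why downscale_ratio and upscale_ratio are also 32.

ffactor_spatial = 32
depth = (ffactor_spatial >> 1).bit_length()   # number of PixelShuffle2D / PixelUnshuffle2D stages
print(depth, 2 ** depth)                      # 5 32  -> a 2048x2048 image maps to a 64x64 latent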
@@ -785,6 +800,7 @@ class CLIPType(Enum):
     ACE = 16
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
+    HUNYUAN_IMAGE = 19
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -806,6 +822,7 @@ class TEModel(Enum):
     GEMMA_2_2B = 9
     QWEN25_3B = 10
     QWEN25_7B = 11
+    BYT5_SMALL_GLYPH = 12
 
 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
|||||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||||
return TEModel.T5_XXL_OLD
|
return TEModel.T5_XXL_OLD
|
||||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||||
|
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
|
||||||
|
if weight.shape[0] == 384:
|
||||||
|
return TEModel.BYT5_SMALL_GLYPH
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
@@ -937,6 +957,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||||
elif te_model == TEModel.QWEN25_7B:
|
elif te_model == TEModel.QWEN25_7B:
|
||||||
|
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||||
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||||
|
else:
|
||||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||||
else:
|
else:
|
||||||
@@ -982,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
|
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
|
||||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||||
|
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||||
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
 import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
 
 from . import supported_models_base
 from . import latent_formats
@@ -1295,7 +1296,31 @@ class QwenImage(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
 
+
+class HunyuanImage21(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "vec_in_dim": None,
+    }
+
+    sampling_settings = {
+        "shift": 5.0,
+    }
+
+    latent_format = latent_formats.HunyuanImage21
+
+    memory_usage_factor = 7.7
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanImage21(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
 
 models += [SVD_img2vid]
comfy/text_encoders/byt5_config_small_glyph.json (new file, 22 lines)
@@ -0,0 +1,22 @@
{
  "d_ff": 3584,
  "d_kv": 64,
  "d_model": 1472,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "dense_act_fn": "gelu_pytorch_tanh",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 4,
  "num_heads": 6,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "vocab_size": 1510
}
comfy/text_encoders/byt5_tokenizer/added_tokens.json (new file, 127 lines)
@@ -0,0 +1,127 @@
{
  "<extra_id_0>": 259,
  "<extra_id_100>": 359,
  "<extra_id_101>": 360,
  "<extra_id_102>": 361,
  "<extra_id_103>": 362,
  "<extra_id_104>": 363,
  "<extra_id_105>": 364,
  "<extra_id_106>": 365,
  "<extra_id_107>": 366,
  "<extra_id_108>": 367,
  "<extra_id_109>": 368,
  "<extra_id_10>": 269,
  "<extra_id_110>": 369,
  "<extra_id_111>": 370,
  "<extra_id_112>": 371,
  "<extra_id_113>": 372,
  "<extra_id_114>": 373,
  "<extra_id_115>": 374,
  "<extra_id_116>": 375,
  "<extra_id_117>": 376,
  "<extra_id_118>": 377,
  "<extra_id_119>": 378,
  "<extra_id_11>": 270,
  "<extra_id_120>": 379,
  "<extra_id_121>": 380,
  "<extra_id_122>": 381,
  "<extra_id_123>": 382,
  "<extra_id_124>": 383,
  "<extra_id_12>": 271,
  "<extra_id_13>": 272,
  "<extra_id_14>": 273,
  "<extra_id_15>": 274,
  "<extra_id_16>": 275,
  "<extra_id_17>": 276,
  "<extra_id_18>": 277,
  "<extra_id_19>": 278,
  "<extra_id_1>": 260,
  "<extra_id_20>": 279,
  "<extra_id_21>": 280,
  "<extra_id_22>": 281,
  "<extra_id_23>": 282,
  "<extra_id_24>": 283,
  "<extra_id_25>": 284,
  "<extra_id_26>": 285,
  "<extra_id_27>": 286,
  "<extra_id_28>": 287,
  "<extra_id_29>": 288,
  "<extra_id_2>": 261,
  "<extra_id_30>": 289,
  "<extra_id_31>": 290,
  "<extra_id_32>": 291,
  "<extra_id_33>": 292,
  "<extra_id_34>": 293,
  "<extra_id_35>": 294,
  "<extra_id_36>": 295,
  "<extra_id_37>": 296,
  "<extra_id_38>": 297,
  "<extra_id_39>": 298,
  "<extra_id_3>": 262,
  "<extra_id_40>": 299,
  "<extra_id_41>": 300,
  "<extra_id_42>": 301,
  "<extra_id_43>": 302,
  "<extra_id_44>": 303,
  "<extra_id_45>": 304,
  "<extra_id_46>": 305,
  "<extra_id_47>": 306,
  "<extra_id_48>": 307,
  "<extra_id_49>": 308,
  "<extra_id_4>": 263,
  "<extra_id_50>": 309,
  "<extra_id_51>": 310,
  "<extra_id_52>": 311,
  "<extra_id_53>": 312,
  "<extra_id_54>": 313,
  "<extra_id_55>": 314,
  "<extra_id_56>": 315,
  "<extra_id_57>": 316,
  "<extra_id_58>": 317,
  "<extra_id_59>": 318,
  "<extra_id_5>": 264,
  "<extra_id_60>": 319,
  "<extra_id_61>": 320,
  "<extra_id_62>": 321,
  "<extra_id_63>": 322,
  "<extra_id_64>": 323,
  "<extra_id_65>": 324,
  "<extra_id_66>": 325,
  "<extra_id_67>": 326,
  "<extra_id_68>": 327,
  "<extra_id_69>": 328,
  "<extra_id_6>": 265,
  "<extra_id_70>": 329,
  "<extra_id_71>": 330,
  "<extra_id_72>": 331,
  "<extra_id_73>": 332,
  "<extra_id_74>": 333,
  "<extra_id_75>": 334,
  "<extra_id_76>": 335,
  "<extra_id_77>": 336,
  "<extra_id_78>": 337,
  "<extra_id_79>": 338,
  "<extra_id_7>": 266,
  "<extra_id_80>": 339,
  "<extra_id_81>": 340,
  "<extra_id_82>": 341,
  "<extra_id_83>": 342,
  "<extra_id_84>": 343,
  "<extra_id_85>": 344,
  "<extra_id_86>": 345,
  "<extra_id_87>": 346,
  "<extra_id_88>": 347,
  "<extra_id_89>": 348,
  "<extra_id_8>": 267,
  "<extra_id_90>": 349,
  "<extra_id_91>": 350,
  "<extra_id_92>": 351,
  "<extra_id_93>": 352,
  "<extra_id_94>": 353,
  "<extra_id_95>": 354,
  "<extra_id_96>": 355,
  "<extra_id_97>": 356,
  "<extra_id_98>": 357,
  "<extra_id_99>": 358,
  "<extra_id_9>": 268
}
comfy/text_encoders/byt5_tokenizer/special_tokens_map.json (new file, 150 lines)
@@ -0,0 +1,150 @@
{
  "additional_special_tokens": [
    "<extra_id_0>",
    "<extra_id_1>",
    "<extra_id_2>",
    "<extra_id_3>",
    "<extra_id_4>",
    "<extra_id_5>",
    "<extra_id_6>",
    "<extra_id_7>",
    "<extra_id_8>",
    "<extra_id_9>",
    "<extra_id_10>",
    "<extra_id_11>",
    "<extra_id_12>",
    "<extra_id_13>",
    "<extra_id_14>",
    "<extra_id_15>",
    "<extra_id_16>",
    "<extra_id_17>",
    "<extra_id_18>",
    "<extra_id_19>",
    "<extra_id_20>",
    "<extra_id_21>",
    "<extra_id_22>",
    "<extra_id_23>",
    "<extra_id_24>",
    "<extra_id_25>",
    "<extra_id_26>",
    "<extra_id_27>",
    "<extra_id_28>",
    "<extra_id_29>",
    "<extra_id_30>",
    "<extra_id_31>",
    "<extra_id_32>",
    "<extra_id_33>",
    "<extra_id_34>",
    "<extra_id_35>",
    "<extra_id_36>",
    "<extra_id_37>",
    "<extra_id_38>",
    "<extra_id_39>",
    "<extra_id_40>",
    "<extra_id_41>",
    "<extra_id_42>",
    "<extra_id_43>",
    "<extra_id_44>",
    "<extra_id_45>",
    "<extra_id_46>",
    "<extra_id_47>",
    "<extra_id_48>",
    "<extra_id_49>",
    "<extra_id_50>",
    "<extra_id_51>",
    "<extra_id_52>",
    "<extra_id_53>",
    "<extra_id_54>",
    "<extra_id_55>",
    "<extra_id_56>",
    "<extra_id_57>",
    "<extra_id_58>",
    "<extra_id_59>",
    "<extra_id_60>",
    "<extra_id_61>",
    "<extra_id_62>",
    "<extra_id_63>",
    "<extra_id_64>",
    "<extra_id_65>",
    "<extra_id_66>",
    "<extra_id_67>",
    "<extra_id_68>",
    "<extra_id_69>",
    "<extra_id_70>",
    "<extra_id_71>",
    "<extra_id_72>",
    "<extra_id_73>",
    "<extra_id_74>",
    "<extra_id_75>",
    "<extra_id_76>",
    "<extra_id_77>",
    "<extra_id_78>",
    "<extra_id_79>",
    "<extra_id_80>",
    "<extra_id_81>",
    "<extra_id_82>",
    "<extra_id_83>",
    "<extra_id_84>",
    "<extra_id_85>",
    "<extra_id_86>",
    "<extra_id_87>",
    "<extra_id_88>",
    "<extra_id_89>",
    "<extra_id_90>",
    "<extra_id_91>",
    "<extra_id_92>",
    "<extra_id_93>",
    "<extra_id_94>",
    "<extra_id_95>",
    "<extra_id_96>",
    "<extra_id_97>",
    "<extra_id_98>",
    "<extra_id_99>",
    "<extra_id_100>",
    "<extra_id_101>",
    "<extra_id_102>",
    "<extra_id_103>",
    "<extra_id_104>",
    "<extra_id_105>",
    "<extra_id_106>",
    "<extra_id_107>",
    "<extra_id_108>",
    "<extra_id_109>",
    "<extra_id_110>",
    "<extra_id_111>",
    "<extra_id_112>",
    "<extra_id_113>",
    "<extra_id_114>",
    "<extra_id_115>",
    "<extra_id_116>",
    "<extra_id_117>",
    "<extra_id_118>",
    "<extra_id_119>",
    "<extra_id_120>",
    "<extra_id_121>",
    "<extra_id_122>",
    "<extra_id_123>",
    "<extra_id_124>"
  ],
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
comfy/text_encoders/byt5_tokenizer/tokenizer_config.json (new file, 1163 lines): file diff suppressed because it is too large.
comfy/text_encoders/hunyuan_image.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from comfy import sd1_clip
import comfy.text_encoders.llama
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from transformers import ByT5Tokenizer
import os
import re


class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)


class HunyuanImageTokenizer(QwenImageTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
        # self.llama_template_images = "{}"
        self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = super().tokenize_with_weights(text, return_word_ids, **kwargs)

        # ByT5 processing for HunyuanImage
        text_prompt_texts = []
        pattern_quote_single = r'\'(.*?)\''
        pattern_quote_double = r'\"(.*?)\"'
        pattern_quote_chinese_single = r'‘(.*?)’'
        pattern_quote_chinese_double = r'“(.*?)”'

        matches_quote_single = re.findall(pattern_quote_single, text)
        matches_quote_double = re.findall(pattern_quote_double, text)
        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

        text_prompt_texts.extend(matches_quote_single)
        text_prompt_texts.extend(matches_quote_double)
        text_prompt_texts.extend(matches_quote_chinese_single)
        text_prompt_texts.extend(matches_quote_chinese_double)

        if len(text_prompt_texts) > 0:
            out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
        return out


class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
        if llama_scaled_fp8 is not None:
            model_options = model_options.copy()
            model_options["scaled_fp8"] = llama_scaled_fp8
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class ByT5SmallModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)


class HunyuanImageTEModel(QwenImageTEModel):
    def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

        if byt5:
            self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
        else:
            self.byt5_small = None

    def encode_token_weights(self, token_weight_pairs):
        cond, p, extra = super().encode_token_weights(token_weight_pairs)
        if self.byt5_small is not None and "byt5" in token_weight_pairs:
            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
            extra["conditioning_byt5small"] = out[0]
        return cond, p, extra

    def set_clip_options(self, options):
        super().set_clip_options(options)
        if self.byt5_small is not None:
            self.byt5_small.set_clip_options(options)

    def reset_clip_options(self):
        super().reset_clip_options()
        if self.byt5_small is not None:
            self.byt5_small.reset_clip_options()

    def load_sd(self, sd):
        if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
            return self.byt5_small.load_sd(sd)
        else:
            return super().load_sd(sd)


def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
    class QwenImageTEModel_(HunyuanImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
    return QwenImageTEModel_
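Standalone check (not from the commit) of the quote-extraction used above to route glyph text to ByT5: anything inside ASCII or Chinese single/double quotes becomes a 'Text "...". ' prompt for the byt5_small encoder.

import re

text = 'a poster that says "Grand Opening" with \'50% off\' in the corner'
patterns = [r'\'(.*?)\'', r'\"(.*?)\"', r'‘(.*?)’', r'“(.*?)”']
found = []
for p in patterns:
    found.extend(re.findall(p, text))
print(''.join('Text "{}". '.format(t) for t in found))
# Text "50% off". Text "Grand Opening".   (order follows the pattern list)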
@@ -113,6 +113,20 @@ class HunyuanImageToVideo:
         out_latent["samples"] = latent
         return (positive, out_latent)
 
+class EmptyHunyuanImageLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
+                              "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "latent"
+
+    def generate(self, width, height, batch_size=1):
+        latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
+        return ({"samples":latent}, )
+
+
 NODE_CLASS_MAPPINGS = {
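Shape note (not from the commit): the node allocates latents at 1/32 of the pixel resolution with 64 channels, matching the new VAE and latent format.

import torch

width, height, batch_size = 2048, 2048, 1
latent = torch.zeros([batch_size, 64, height // 32, width // 32])
print(latent.shape)  # torch.Size([1, 64, 64, 64])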
@@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = {
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,
+    "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
 }
nodes.py
@@ -925,7 +925,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
                              },
                 "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -953,7 +953,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ),
                              },
                 "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -963,7 +963,7 @@ class DualCLIPLoader:
 
     CATEGORY = "advanced/loaders"
 
-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
 
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
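Usage sketch (hypothetical file names; requires a ComfyUI checkout containing this commit): the new "hunyuan_image" option maps to CLIPType.HUNYUAN_IMAGE, so the Qwen2.5-VL 7B and ByT5-small glyph encoders can be loaded together through the existing load_clip path.

import comfy.sd

clip = comfy.sd.load_clip(
    ckpt_paths=["models/text_encoders/qwen_2.5_vl_7b.safetensors",      # assumed path
                "models/text_encoders/byt5_small_glyph.safetensors"],   # assumed path
    clip_type=comfy.sd.CLIPType.HUNYUAN_IMAGE,
)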