diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 2a30497c..5e7848bd 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module): return x +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0, + operation_settings={} + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -395,6 +423,7 @@ class WanModel(torch.nn.Module): clip_fea=None, freqs=None, transformer_options={}, + **kwargs, ): r""" Forward pass through the diffusion model @@ -457,7 +486,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -471,7 +500,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" @@ -496,3 +525,114 @@ class WanModel(torch.nn.Module): u = torch.einsum('bfhwpqrc->bcfphqwr', u) u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) return u + + +class VaceWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='vace', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + vace_layers=None, + vace_in_dim=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + # Vace + if vace_layers is not None: + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings) + for i in range(self.vace_layers) + ]) + + self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))} + # vace patch embeddings + self.vace_patch_embedding = operations.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32 + ) + + def forward_orig( + self, + x, + t, + context, + vace_context, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) + c = c.flatten(2).transpose(1, 2) + + # arguments + x_orig = x + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + ii = self.vace_layers_mapping.get(i, None) + if ii is not None: + c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 8dab1740..04a10152 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1043,6 +1043,34 @@ class WAN21(BaseModel): out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class WAN21_Vace(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) + noise_shape = list(noise.shape) + vace_frames = kwargs.get("vace_frames", None) + if vace_frames is None: + noise_shape[1] = 32 + vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + + for i in range(0, vace_frames.shape[1], 16): + vace_frames = vace_frames.clone() + vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + + mask = kwargs.get("vace_mask", None) + if mask is None: + noise_shape[1] = 64 + mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + + out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + return out + + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6499bf23..76de78a8 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -317,10 +317,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] - if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "i2v" + if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "vace" + dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') else: - dit_config["model_type"] = "t2v" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "i2v" + else: + dit_config["model_type"] = "t2v" flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 81c47ac6..5e55035c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Vace(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "vace", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Vace(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE): return None # TODO -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8ad358ce..19a6cdfa 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -193,9 +193,115 @@ class WanFunInpaintToVideo: return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) +class WanVaceToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"control_video": ("IMAGE", ), + "control_masks": ("MASK", ), + "reference_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") + RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + EXPERIMENTAL = True + + def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None): + latent_length = ((length - 1) // 4) + 1 + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if control_video.shape[0] < length: + control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5) + else: + control_video = torch.ones((length, height, width, 3)) * 0.5 + + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + reference_image = vae.encode(reference_image[:, :, :, :3]) + reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1) + + if control_masks is None: + mask = torch.ones((length, height, width, 1)) + else: + mask = control_masks + if mask.ndim == 3: + mask = mask.unsqueeze(1) + mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1) + if mask.shape[0] < length: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0) + + control_video = control_video - 0.5 + inactive = (control_video * (1 - mask)) + 0.5 + reactive = (control_video * mask) + 0.5 + + inactive = vae.encode(inactive[:, :, :, :3]) + reactive = vae.encode(reactive[:, :, :, :3]) + control_video_latent = torch.cat((inactive, reactive), dim=1) + if reference_image is not None: + control_video_latent = torch.cat((reference_image, control_video_latent), dim=2) + + vae_stride = 8 + height_mask = height // vae_stride + width_mask = width // vae_stride + mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride) + mask = mask.permute(2, 4, 0, 1, 3) + mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask) + mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0) + + trim_latent = 0 + if reference_image is not None: + mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + latent_length += reference_image.shape[2] + trim_latent = reference_image.shape[2] + + mask = mask.unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent, trim_latent) + +class TrimVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/video" + + EXPERIMENTAL = True + + def op(self, samples, trim_amount): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1[:, :, trim_amount:] + return (samples_out,) + + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, + "WanVaceToVideo": WanVaceToVideo, + "TrimVideoLatent": TrimVideoLatent, }