diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 86d0795e9..4e2d99566 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -391,6 +391,7 @@ class WanModel(torch.nn.Module): cross_attn_norm=True, eps=1e-6, flf_pos_embed_token_number=None, + in_dim_ref_conv=None, image_model=None, device=None, dtype=None, @@ -484,6 +485,11 @@ class WanModel(torch.nn.Module): else: self.img_emb = None + if in_dim_ref_conv is not None: + self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + else: + self.ref_conv = None + def forward_orig( self, x, @@ -526,6 +532,13 @@ class WanModel(torch.nn.Module): e = e.reshape(t.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + # context context = self.text_embedding(context) @@ -552,6 +565,9 @@ class WanModel(torch.nn.Module): # head x = self.head(x, e) + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + # unpatchify x = self.unpatchify(x, grid_sizes) return x @@ -570,6 +586,9 @@ class WanModel(torch.nn.Module): x = torch.cat([x, time_dim_concat], dim=2) t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + if self.ref_conv is not None and "reference_latent" in kwargs: + t_len += 1 + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a2d9cbe6..cde61df7c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1124,7 +1124,11 @@ class WAN21(BaseModel): mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) - return torch.cat((mask, image), dim=1) + concat_mask_index = kwargs.get("concat_mask_index", 0) + if concat_mask_index != 0: + return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) + else: + return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1140,6 +1144,10 @@ class WAN21(BaseModel): if time_dim_concat is not None: out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8b57ebd2f..8acc51e20 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -373,6 +373,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + + ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) + if ref_conv_weight is not None: + dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0067d054d..f80c83ba6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -103,6 +103,63 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class Wan22FunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"ref_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + # "start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + ref_latent = None + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + class WanFirstLastFrameToVideo: @classmethod def INPUT_TYPES(s): @@ -733,6 +790,7 @@ NODE_CLASS_MAPPINGS = { "WanTrackToVideo": WanTrackToVideo, "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, + "Wan22FunControlToVideo": Wan22FunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo,