diff --git a/comfy/model_base.py b/comfy/model_base.py index ce29fdc49..56a6798be 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,9 +1110,10 @@ class WAN21(BaseModel): shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: + latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - for i in range(0, image.shape[1], 16): - image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) if extra_channels != image.shape[1] + 4: @@ -1245,18 +1246,14 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out -class WAN22(BaseModel): +class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) - cross_attn = kwargs.get("cross_attn", None) - if cross_attn is not None: - out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) - - denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + denoise_mask = kwargs.get("denoise_mask", None) if denoise_mask is not None: out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) return out diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2cbc93ceb..8c1d36613 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -139,16 +139,21 @@ class Wan22FunControlToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + if latent_channels == 48: + concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent) + else: + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) - concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] mask[:, :, :start_image.shape[0] + 3] = 0.0 ref_latent = None @@ -159,11 +164,11 @@ class Wan22FunControlToVideo(io.ComfyNode): if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(control_video[:, :, :, :3]) - concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) if ref_latent is not None: positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)