From 491755325cc189d0aa1513b12fac738c87e38de6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 Aug 2025 16:02:42 -0700 Subject: [PATCH] Better s2v memory estimation. (#9584) --- comfy/ldm/wan/model.py | 2 ++ comfy/model_base.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e70446c86..47857dc2b 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1278,6 +1278,7 @@ class WanModel_S2V(WanModel): x = torch.cat([x, ref], dim=1) freqs = torch.cat([freqs, freqs_ref], dim=1) t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1) + del ref, freqs_ref if reference_motion is not None: motion_encoded, freqs_motion = self.frame_packer(reference_motion, self) @@ -1287,6 +1288,7 @@ class WanModel_S2V(WanModel): t = torch.repeat_interleave(t, 2, dim=1) t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1) + del motion_encoded, freqs_motion # time embeddings e = self.time_embedding( diff --git a/comfy/model_base.py b/comfy/model_base.py index 18d55c1c4..ce29fdc49 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module): logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor_conds = () + self.memory_usage_shape_process = {} def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module): input_shapes = [input_shape] for c in self.memory_usage_factor_conds: shape = cond_shapes.get(c, None) - if shape is not None and len(shape) > 0: - input_shapes += shape + if shape is not None: + if c in self.memory_usage_shape_process: + out = [] + for s in shape: + out.append(self.memory_usage_shape_process[c](s)) + shape = out + + if len(shape) > 0: + input_shapes += shape if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): dtype = self.get_dtype() @@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21): class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + self.memory_usage_factor_conds = ("reference_latent", "reference_motion") + self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21): out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) return out + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = reference_motion.shape + return out + class WAN22(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)