mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Better s2v memory estimation. (#9584)
This commit is contained in:
@@ -1278,6 +1278,7 @@ class WanModel_S2V(WanModel):
|
|||||||
x = torch.cat([x, ref], dim=1)
|
x = torch.cat([x, ref], dim=1)
|
||||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||||
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del ref, freqs_ref
|
||||||
|
|
||||||
if reference_motion is not None:
|
if reference_motion is not None:
|
||||||
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||||
@@ -1287,6 +1288,7 @@ class WanModel_S2V(WanModel):
|
|||||||
|
|
||||||
t = torch.repeat_interleave(t, 2, dim=1)
|
t = torch.repeat_interleave(t, 2, dim=1)
|
||||||
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del motion_encoded, freqs_motion
|
||||||
|
|
||||||
# time embeddings
|
# time embeddings
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
|
@@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
logging.debug("adm {}".format(self.adm_channels))
|
logging.debug("adm {}".format(self.adm_channels))
|
||||||
self.memory_usage_factor = model_config.memory_usage_factor
|
self.memory_usage_factor = model_config.memory_usage_factor
|
||||||
self.memory_usage_factor_conds = ()
|
self.memory_usage_factor_conds = ()
|
||||||
|
self.memory_usage_shape_process = {}
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
@@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes = [input_shape]
|
input_shapes = [input_shape]
|
||||||
for c in self.memory_usage_factor_conds:
|
for c in self.memory_usage_factor_conds:
|
||||||
shape = cond_shapes.get(c, None)
|
shape = cond_shapes.get(c, None)
|
||||||
if shape is not None and len(shape) > 0:
|
if shape is not None:
|
||||||
input_shapes += shape
|
if c in self.memory_usage_shape_process:
|
||||||
|
out = []
|
||||||
|
for s in shape:
|
||||||
|
out.append(self.memory_usage_shape_process[c](s))
|
||||||
|
shape = out
|
||||||
|
|
||||||
|
if len(shape) > 0:
|
||||||
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
@@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21):
|
|||||||
class WAN22_S2V(WAN21):
|
class WAN22_S2V(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||||
|
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
|
||||||
|
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21):
|
|||||||
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
|
|
||||||
|
reference_motion = kwargs.get("reference_motion", None)
|
||||||
|
if reference_motion is not None:
|
||||||
|
out['reference_motion'] = reference_motion.shape
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN22(BaseModel):
|
class WAN22(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
Reference in New Issue
Block a user