Better s2v memory estimation. (#9584)

This commit is contained in:
comfyanonymous
2025-08-27 16:02:42 -07:00
committed by GitHub
parent 496888fd68
commit 491755325c
2 changed files with 25 additions and 2 deletions

View File

@@ -1278,6 +1278,7 @@ class WanModel_S2V(WanModel):
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
del ref, freqs_ref
if reference_motion is not None:
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
@@ -1287,6 +1288,7 @@ class WanModel_S2V(WanModel):
t = torch.repeat_interleave(t, 2, dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
del motion_encoded, freqs_motion
# time embeddings
e = self.time_embedding(