Improve s2v performance when generating videos longer than 120 frames. (#9582)

This commit is contained in:
comfyanonymous
2025-08-27 13:06:40 -07:00
committed by GitHub
parent b5ac6ed7ce
commit 496888fd68

View File

@@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
audio_emb = None
# embeddings
bs, _, time, height, width = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
if control_video is not None:
x = x + self.cond_encoder(control_video)
@@ -1272,7 +1273,7 @@ class WanModel_S2V(WanModel):
if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
ref = ref + cond_mask_weight[1]
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
@@ -1296,7 +1297,6 @@ class WanModel_S2V(WanModel):
# context
context = self.text_embedding(context)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):