mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Improve s2v performance when generating videos longer than 120 frames. (#9582)
This commit is contained in:
@@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
|
||||
audio_emb = None
|
||||
|
||||
# embeddings
|
||||
bs, _, time, height, width = x.shape
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if control_video is not None:
|
||||
x = x + self.cond_encoder(control_video)
|
||||
@@ -1272,7 +1273,7 @@ class WanModel_S2V(WanModel):
|
||||
if reference_latent is not None:
|
||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||
ref = ref.flatten(2).transpose(1, 2)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||
ref = ref + cond_mask_weight[1]
|
||||
x = torch.cat([x, ref], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||
@@ -1296,7 +1297,6 @@ class WanModel_S2V(WanModel):
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
|
Reference in New Issue
Block a user