mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 20:17:30 +00:00
Improve s2v performance when generating videos longer than 120 frames. (#9582)
This commit is contained in:
@@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
|
|||||||
audio_emb = None
|
audio_emb = None
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
|
bs, _, time, height, width = x.shape
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
if control_video is not None:
|
if control_video is not None:
|
||||||
x = x + self.cond_encoder(control_video)
|
x = x + self.cond_encoder(control_video)
|
||||||
@@ -1272,7 +1273,7 @@ class WanModel_S2V(WanModel):
|
|||||||
if reference_latent is not None:
|
if reference_latent is not None:
|
||||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||||
ref = ref.flatten(2).transpose(1, 2)
|
ref = ref.flatten(2).transpose(1, 2)
|
||||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
|
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||||
ref = ref + cond_mask_weight[1]
|
ref = ref + cond_mask_weight[1]
|
||||||
x = torch.cat([x, ref], dim=1)
|
x = torch.cat([x, ref], dim=1)
|
||||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||||
@@ -1296,7 +1297,6 @@ class WanModel_S2V(WanModel):
|
|||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
Reference in New Issue
Block a user