Add strength control for vace. (#7717)

This commit is contained in:
comfyanonymous 2025-04-21 16:36:20 -07:00 committed by GitHub
parent 9d57b8afd8
commit 5d0d4ee98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 4 deletions

View File

@ -582,6 +582,7 @@ class VaceWanModel(WanModel):
t, t,
context, context,
vace_context, vace_context,
vace_strength=1.0,
clip_fea=None, clip_fea=None,
freqs=None, freqs=None,
transformer_options={}, transformer_options={},
@ -629,7 +630,7 @@ class VaceWanModel(WanModel):
ii = self.vace_layers_mapping.get(i, None) ii = self.vace_layers_mapping.get(i, None)
if ii is not None: if ii is not None:
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip x += c_skip * vace_strength
# head # head
x = self.head(x, e) x = self.head(x, e)

View File

@ -1068,6 +1068,9 @@ class WAN21_Vace(WAN21):
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_strength = kwargs.get("vace_strength", 1.0)
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out return out

View File

@ -203,6 +203,7 @@ class WanVaceToVideo:
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}, },
"optional": {"control_video": ("IMAGE", ), "optional": {"control_video": ("IMAGE", ),
"control_masks": ("MASK", ), "control_masks": ("MASK", ),
@ -217,7 +218,7 @@ class WanVaceToVideo:
EXPERIMENTAL = True EXPERIMENTAL = True
def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None): def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
latent_length = ((length - 1) // 4) + 1 latent_length = ((length - 1) // 4) + 1
if control_video is not None: if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@ -267,8 +268,8 @@ class WanVaceToVideo:
trim_latent = reference_image.shape[2] trim_latent = reference_image.shape[2]
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask}) positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {} out_latent = {}