mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 20:48:22 +00:00
A few improvements to #5937.
This commit is contained in:
@@ -416,9 +416,8 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
image_noise_scale = transformer_options.get("image_noise_scale", 0.15)
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
@@ -433,20 +432,21 @@ class LTXVModel(torch.nn.Module):
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = 0.0
|
||||
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
if image_noise_scale > 0:
|
||||
if guiding_latent_noise_scale > 0:
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||
elif self.generator.device != x.device:
|
||||
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||
|
||||
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||
guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
|
||||
x[:, :, 0] += guiding_noise[:, :, 0]
|
||||
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
Reference in New Issue
Block a user