Improvements to the TAESD3 implementation.

This commit is contained in:
comfyanonymous
2024-06-16 02:04:24 -04:00
parent df7db0e027
commit 04e8798c37
3 changed files with 10 additions and 9 deletions

View File

@@ -54,6 +54,7 @@ class TAESD(nn.Module):
self.taesd_encoder = Encoder(latent_channels=latent_channels)
self.taesd_decoder = Decoder(latent_channels=latent_channels)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
@@ -70,9 +71,9 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift