mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Add taesd and taesdxl to VAELoader node.
They will show up if both the taesd_encoder and taesd_decoder or taesdxl model files are present in the models/vae_approx directory.
This commit is contained in:
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
|
||||
def __init__(self, encoder_path=None, decoder_path=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.taesd_encoder = Encoder()
|
||||
self.taesd_decoder = Decoder()
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
x_sample = self.taesd_decoder(x * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
|
||||
|
Reference in New Issue
Block a user