support TAESD3 (#3738)

This commit is contained in:
Dr.Lt.Data
2024-06-16 15:03:53 +09:00
committed by GitHub
parent bb1969cab7
commit df7db0e027
5 changed files with 28 additions and 13 deletions

View File

@@ -25,18 +25,19 @@ class Block(nn.Module):
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def Encoder():
def Encoder(latent_channels=4):
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
conv(64, latent_channels),
)
def Decoder():
def Decoder(latent_channels=4):
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -47,11 +48,11 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path=None, decoder_path=None):
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.taesd_encoder = Encoder(latent_channels=latent_channels)
self.taesd_decoder = Decoder(latent_channels=latent_channels)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))