Supports TAESD models in safetensors format

This commit is contained in:
Yukimasa Funaoka
2023-10-10 13:21:44 +09:00
parent ae3e4e9ad8
commit 9eb621c95a
3 changed files with 18 additions and 5 deletions

View File

@@ -50,9 +50,17 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if encoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
else:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
if decoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
else:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
@staticmethod
def scale_latents(x):