mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Supports TAESD models in safetensors format
This commit is contained in:
@@ -50,9 +50,17 @@ class TAESD(nn.Module):
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
||||
if encoder_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
|
||||
else:
|
||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
||||
if decoder_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
|
||||
else:
|
||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
Reference in New Issue
Block a user