Add taesd and taesdxl to VAELoader node.

They will show up if both the taesd_encoder and taesd_decoder or taesdxl
model files are present in the models/vae_approx directory.
This commit is contained in:
comfyanonymous
2023-11-21 12:54:19 -05:00
parent 6ff06fa796
commit cd4fc77d5f
4 changed files with 79 additions and 17 deletions

View File

@@ -573,9 +573,55 @@ class LoraLoader:
return (model_lora, clip_lora)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@@ -583,8 +629,11 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)