Move latent scale factor from VAE to model.

This commit is contained in:
comfyanonymous
2023-06-23 02:14:12 -04:00
parent 30a3861946
commit 8607c2d42d
7 changed files with 73 additions and 33 deletions

View File

@@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
class BaseModel(torch.nn.Module):
def __init__(self, unet_config, v_prediction=False):
def __init__(self, model_config, v_prediction=False):
super().__init__()
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
@@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module):
del to_load
return self
def process_latent_in(self, latent):
return self.latent_format.process_in(latent)
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, noise_aug_config, v_prediction=True):
super().__init__(model_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs):
@@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel):
return adm_out
class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
@@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):