Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous
2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from enum import Enum
from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__()
unet_config = model_config.unet_config
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
if self.v_prediction:
self.parameterization = "v"
else:
self.parameterization = "eps"
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("model_type", model_type.name)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
super().__init__(model_config, model_type)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs):
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
return adm_out
class SDInpaint(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):