Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-11 03:58:22 +00:00
Support base SDXL and SDXL refiner models.
Large refactor of the model detection and loading code.
@@ -2,6 +2,7 @@ import torch
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
 
 class BaseModel(torch.nn.Module):
@@ -15,9 +16,9 @@ class BaseModel(torch.nn.Module):
             self.parameterization = "v"
         else:
             self.parameterization = "eps"
-        if "adm_in_channels" in unet_config:
-            self.adm_channels = unet_config["adm_in_channels"]
-        else:
+
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
             self.adm_channels = 0
         print("v_prediction", v_prediction)
         print("adm", self.adm_channels)
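The refactored lookup in the hunk above is equivalent to the old if/else: dict.get returns None when the key is absent, and that None is then normalized to 0. A minimal standalone sketch (the config dicts here are hypothetical, not from this commit; 2816 is the base SDXL value):

unet_config = {}                                          # hypothetical config with no ADM channels
adm_channels = unet_config.get("adm_in_channels", None)   # None, since the key is absent
if adm_channels is None:
    adm_channels = 0
assert adm_channels == 0                                  # same result as the old else branch

unet_config = {"adm_in_channels": 2816}                   # hypothetical SDXL-style config
assert unet_config.get("adm_in_channels", None) == 2816   # same result as the old if branch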
@@ -55,6 +56,25 @@ class BaseModel(torch.nn.Module):
     def is_adm(self):
         return self.adm_channels > 0
 
+    def encode_adm(self, **kwargs):
+        return None
+
+    def load_model_weights(self, sd, unet_prefix=""):
+        to_load = {}
+        keys = list(sd.keys())
+        for k in keys:
+            if k.startswith(unet_prefix):
+                to_load[k[len(unet_prefix):]] = sd.pop(k)
+
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        if len(m) > 0:
+            print("unet missing:", m)
+
+        if len(u) > 0:
+            print("unet unexpected:", u)
+        del to_load
+        return self
+
 class SD21UNCLIP(BaseModel):
     def __init__(self, unet_config, noise_aug_config, v_prediction=True):
         super().__init__(unet_config, v_prediction)
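The new load_model_weights helper above strips a checkpoint prefix so the keys line up with the bare UNet's state_dict, loads non-strictly, and prints any missing or unexpected keys. A sketch of the key-remapping step in isolation (the state-dict keys below are made up for illustration):

import torch

sd = {"model.diffusion_model.out.0.weight": torch.zeros(4),   # UNet weight inside a full checkpoint
      "first_stage_model.decoder.w": torch.zeros(4)}          # VAE weight, left untouched in sd
unet_prefix = "model.diffusion_model."

to_load = {}
for k in list(sd.keys()):
    if k.startswith(unet_prefix):
        to_load[k[len(unet_prefix):]] = sd.pop(k)             # strip prefix, remove from the source dict

print(list(to_load.keys()))   # ['out.0.weight'] -- ready for diffusion_model.load_state_dict(to_load, strict=False)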
@@ -95,3 +115,55 @@ class SDInpaint(BaseModel):
     def __init__(self, unet_config, v_prediction=False):
         super().__init__(unet_config, v_prediction)
         self.concat_keys = ("mask", "masked_image")
+
+class SDXLRefiner(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        clip_pooled = kwargs["pooled_output"]
+        width = kwargs.get("width", 768)
+        height = kwargs.get("height", 768)
+        crop_w = kwargs.get("crop_w", 0)
+        crop_h = kwargs.get("crop_h", 0)
+
+        if kwargs.get("prompt_type", "") == "negative":
+            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
+        else:
+            aesthetic_score = kwargs.get("aesthetic_score", 6)
+
+        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
+        out = []
+        out.append(self.embedder(torch.Tensor([width])))
+        out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([aesthetic_score])))
+        flat = torch.flatten(torch.cat(out))[None, ]
+        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+class SDXL(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        clip_pooled = kwargs["pooled_output"]
+        width = kwargs.get("width", 768)
+        height = kwargs.get("height", 768)
+        crop_w = kwargs.get("crop_w", 0)
+        crop_h = kwargs.get("crop_h", 0)
+        target_width = kwargs.get("target_width", width)
+        target_height = kwargs.get("target_height", height)
+
+        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
+        out = []
+        out.append(self.embedder(torch.Tensor([width])))
+        out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([target_height])))
+        flat = torch.flatten(torch.cat(out))[None, ]
+        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
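Shape check for the two encode_adm implementations above: each scalar conditioning value is passed through Timestep(256), the resulting embeddings are flattened, and the flat vector is concatenated after the pooled CLIP output. With SDXL's 1280-dim pooled output, the six scalars of the base model give 1280 + 6*256 = 2816 channels and the refiner's five give 1280 + 5*256 = 2560, matching the adm_in_channels of the respective UNets. A runnable sketch, assuming the comfy package from this repo is importable (the scalar values chosen are illustrative):

import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep

embedder = Timestep(256)
clip_pooled = torch.zeros((1, 1280))     # stand-in for the pooled text-encoder output

# SDXL base: width, height, crop_w, crop_h, target_width, target_height
out = [embedder(torch.Tensor([s])) for s in (1024, 1024, 0, 0, 1024, 1024)]
flat = torch.flatten(torch.cat(out))[None, ]          # (1, 1536)
adm = torch.cat((clip_pooled.to(flat.device), flat), dim=1)
print(adm.shape)                         # torch.Size([1, 2816])

# SDXL refiner: width, height, crop_w, crop_h, aesthetic_score
out = [embedder(torch.Tensor([s])) for s in (1024, 1024, 0, 0, 6.0)]
flat = torch.flatten(torch.cat(out))[None, ]          # (1, 1280)
print(torch.cat((clip_pooled, flat), dim=1).shape)    # torch.Size([1, 2560])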