mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 05:25:23 +00:00
Refactor to make it easier to add custom conds to models.
This commit is contained in:
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
@@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
@@ -72,7 +73,8 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def cond_concat(self, **kwargs):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
cond_concat = []
|
||||
@@ -101,8 +103,12 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
return cond_concat
|
||||
return None
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['c_adm'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
to_load = {}
|
||||
|
Reference in New Issue
Block a user