mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Stable Cascade Stage C.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
import comfy.model_management
|
||||
@@ -12,9 +13,10 @@ class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@@ -27,6 +29,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -35,7 +40,7 @@ def model_sampling(model_config, model_type):
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
@@ -48,7 +53,7 @@ class BaseModel(torch.nn.Module):
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@@ -427,3 +432,32 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
if "unclip_conditioning" in kwargs:
|
||||
embeds = []
|
||||
for unclip_cond in kwargs["unclip_conditioning"]:
|
||||
weight = unclip_cond["strength"]
|
||||
embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
|
||||
clip_img = torch.cat(embeds, dim=1)
|
||||
else:
|
||||
clip_img = torch.zeros((1, 1, 768))
|
||||
out["clip_img"] = comfy.conds.CONDRegular(clip_img)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
|
Reference in New Issue
Block a user