mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Stable Cascade Stage C.
This commit is contained in:
@@ -306,5 +306,38 @@ class SD_X4Upscaler(SD20):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'c',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent_formats.SC_Prior
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
key_list = list(state_dict.keys())
|
||||
for y in ["weight", "bias"]:
|
||||
suffix = "in_proj_{}".format(y)
|
||||
keys = filter(lambda a: a.endswith(suffix), key_list)
|
||||
for k_from in keys:
|
||||
weights = state_dict.pop(k_from)
|
||||
prefix = k_from[:-(len(suffix) + 1)]
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["to_q", "to_k", "to_v"]
|
||||
k_to = "{}.{}.{}".format(prefix, p[x], y)
|
||||
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_C(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
return None
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]
|
||||
models += [SVD_img2vid]
|
||||
|
Reference in New Issue
Block a user