Stable Cascade Stage B.

This commit is contained in:
comfyanonymous
2024-02-16 12:56:11 -05:00
parent f83109f09b
commit 667c92814e
10 changed files with 430 additions and 8 deletions

View File

@@ -338,6 +338,20 @@ class Stable_Cascade_C(supported_models_base.BASE):
def clip_target(self):
return None
class Stable_Cascade_B(Stable_Cascade_C):
unet_config = {
"stable_cascade_stage": 'b',
}
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]
unet_extra_config = {}
latent_format = latent_formats.SC_B
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models += [SVD_img2vid]