Implement shift schedule for cascade stage C.

This commit is contained in:
comfyanonymous
2024-02-17 11:38:47 -05:00
parent 929e266f3e
commit 5b40e7a5ed
2 changed files with 30 additions and 3 deletions

View File

@@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE):
latent_format = latent_formats.SC_Prior
supported_inference_dtypes = [torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 2.0,
}
def process_unet_state_dict(self, state_dict):
key_list = list(state_dict.keys())
for y in ["weight", "bias"]:
@@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C):
latent_format = latent_formats.SC_B
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out