Stable Cascade Stage C.

This commit is contained in:
comfyanonymous
2024-02-16 10:55:08 -05:00
parent 5e06baf112
commit f83109f09b
11 changed files with 619 additions and 31 deletions

View File

@@ -22,13 +22,14 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
if s.unet_config[k] != unet_config[k]:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return True
@@ -80,5 +81,6 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_manual_cast(self, manual_cast_dtype):
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype