mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 04:27:21 +00:00
Stable Cascade Stage C.
This commit is contained in:
@@ -22,13 +22,14 @@ class BASE:
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
vae_key_prefix = ["first_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
for k in s.unet_config:
|
||||
if s.unet_config[k] != unet_config[k]:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -80,5 +81,6 @@ class BASE:
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_manual_cast(self, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
Reference in New Issue
Block a user