Use faster manual cast for fp8 in unet.

This commit is contained in:
comfyanonymous
2023-12-11 18:24:44 -05:00
parent ab93abd4b2
commit ba07cb748e
5 changed files with 48 additions and 12 deletions

View File

@@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@@ -71,3 +73,5 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_manual_cast(self, manual_cast_dtype):
self.manual_cast_dtype = manual_cast_dtype