Support diffusion models with scaled fp8 weights.

This commit is contained in:
comfyanonymous
2024-10-19 23:47:42 -04:00
parent 73e3a9e676
commit a68bbafddb
7 changed files with 95 additions and 23 deletions

View File

@@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8:
unet_state_dict["scaled_fp8"] = torch.tensor([])
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION: