mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Mixed precision diffusion models with scaled fp8.
This change allows supports for diffusion models where all the linears are scaled fp8 while the other weights are the original precision.
This commit is contained in:
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
|
||||
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
@@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([])
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
|
Reference in New Issue
Block a user