Automatically use fp8 for the diffusion model weights when both of the following hold:

The checkpoint contains weights stored in fp8.

There isn't enough memory to load the diffusion model into GPU VRAM at its full weight dtype.
This commit is contained in:
comfyanonymous
2024-08-03 13:45:19 -04:00
parent f123328b82
commit ba9095e5bd
4 changed files with 34 additions and 4 deletions

View File

@@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module):
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):