Use faster manual cast for fp8 in unet.

This commit is contained in:
comfyanonymous
2023-12-11 18:24:44 -05:00
parent ab93abd4b2
commit ba07cb748e
5 changed files with 48 additions and 12 deletions

View File

@@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)