Try to free enough VRAM for control LoRA inference.

commit 51dde87e97 (parent e3d0a9a490)
Author: comfyanonymous
Date:   2023-08-24 17:20:54 -04:00
4 changed files with 30 additions and 18 deletions


@@ -779,6 +779,11 @@ class ControlBase:
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
 
+    def inference_memory_requirements(self, dtype):
+        if self.previous_controlnet is not None:
+            return self.previous_controlnet.inference_memory_requirements(dtype)
+        return 0
+
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
         out = {'input':[], 'middle':[], 'output': []}
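
The base implementation contributes nothing itself and just forwards the query down the chain, so the total reflects every controlnet stacked before this one. A minimal standalone sketch of that recursion, using a made-up StubControl class and byte counts rather than the real ComfyUI objects:

# Standalone sketch of the chained estimate; StubControl and its byte
# counts are illustrative stand-ins, not ComfyUI classes.
class StubControl:
    def __init__(self, previous_controlnet=None, own_bytes=0):
        self.previous_controlnet = previous_controlnet
        self.own_bytes = own_bytes

    def inference_memory_requirements(self, dtype):
        prev = 0
        if self.previous_controlnet is not None:
            prev = self.previous_controlnet.inference_memory_requirements(dtype)
        return self.own_bytes + prev

# Two chained controls: only the inner one reports a footprint, but the
# outer one still surfaces it through the chain.
chain = StubControl(previous_controlnet=StubControl(own_bytes=700 * 1024**2))
print(chain.inference_memory_requirements("fp16"))  # 734003200 bytes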
@@ -985,6 +990,9 @@ class ControlLora(ControlNet):
         out = ControlBase.get_models(self)
         return out
 
+    def inference_memory_requirements(self, dtype):
+        return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
@@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def calculate_parameters(sd, prefix):
-    params = 0
-    for k in sd.keys():
-        if k.startswith(prefix):
-            params += sd[k].nelement()
-    return params
-
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
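
The helper is not deleted outright; it is relocated so the controlnet code above can share it. Judging by the utils.calculate_parameters calls below (one of which passes no prefix at all), it presumably now lives in the utils module with the prefix made optional, something like:

# Presumed shape of the relocated helper, inferred from the removed version
# above and the prefix-less utils.calculate_parameters(sd) call below.
def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            params += sd[k].nelement()
    return params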
@@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None
 
-    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     class WeightsLoader(torch.nn.Module):
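
Only the call site changes here; the parameter count still feeds should_use_fp16, which uses model size as one input to the precision decision. As a rough, hypothetical illustration of why a parameter count matters for that call (this is not ComfyUI's actual logic, and free_bytes is an illustrative input):

# Hypothetical sketch only: compares a model's fp16 footprint against free
# memory; the 2-bytes-per-element figure is the fp16 storage size.
def should_use_fp16_sketch(model_params, free_bytes):
    fp16_footprint = model_params * 2  # fp16 stores 2 bytes per parameter
    return fp16_footprint <= free_bytes

print(should_use_fp16_sketch(860_000_000, 8 * 1024**3))  # True: ~1.7 GB fits in 8 GB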
@@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet(unet_path): #load unet in diffusers format
     sd = utils.load_torch_file(unet_path)
-    parameters = calculate_parameters(sd, "")
+    parameters = utils.calculate_parameters(sd)
     fp16 = model_management.should_use_fp16(model_params=parameters)
     model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)