mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Try to free enough vram for control lora inference.
This commit is contained in:
19
comfy/sd.py
19
comfy/sd.py
@@ -779,6 +779,11 @@ class ControlBase:
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
return self.previous_controlnet.inference_memory_requirements(dtype)
|
||||
return 0
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
|
||||
@@ -985,6 +990,9 @@ class ControlLora(ControlNet):
|
||||
out = ControlBase.get_models(self)
|
||||
return out
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
@@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
||||
def calculate_parameters(sd, prefix):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
@@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model = None
|
||||
clip_target = None
|
||||
|
||||
parameters = calculate_parameters(sd, "model.diffusion_model.")
|
||||
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
@@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
parameters = calculate_parameters(sd, "")
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||
|
Reference in New Issue
Block a user