Smarter memory management.

Try to keep models on the vram when possible.

Better lowvram mode for controlnets.
This commit is contained in:
comfyanonymous
2023-08-17 01:06:34 -04:00
parent 2c97c30256
commit 89a0767abf
6 changed files with 230 additions and 168 deletions

View File

@@ -244,7 +244,7 @@ def set_attr(obj, attr, value):
del prev
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0):
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
@@ -253,6 +253,10 @@ class ModelPatcher:
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self):
if self.size > 0:
@@ -267,7 +271,7 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -276,6 +280,11 @@ class ModelPatcher:
n.model_keys = self.model_keys
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -390,6 +399,11 @@ class ModelPatcher:
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
@@ -482,7 +496,7 @@ class ModelPatcher:
return weight
def unpatch_model(self):
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
@@ -490,6 +504,11 @@ class ModelPatcher:
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = model_lora_keys_unet(model.model)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
@@ -630,11 +649,12 @@ class VAE:
return samples
def decode(self, samples_in):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64))
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
@@ -650,19 +670,19 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)
def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
@@ -677,7 +697,6 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@@ -757,6 +776,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
@@ -786,11 +806,9 @@ class ControlNet(ControlBase):
precision_scope = contextlib.nullcontext
with precision_scope(model_management.get_autocast_device(self.device)):
self.control_model = model_management.load_if_low_vram(self.control_model)
context = torch.cat(cond['c_crossattn'], 1)
y = cond.get('c_adm', None)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled()
@@ -825,7 +843,7 @@ class ControlNet(ControlBase):
def get_models(self):
out = super().get_models()
out.append(self.control_model)
out.append(self.control_model_wrapped)
return out
@@ -1004,7 +1022,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if 'adapter' in keys:
@@ -1090,7 +1107,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return model
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
@@ -1202,8 +1219,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
@@ -1224,7 +1246,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0:
print("left over keys:", left_over)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format