""" This file is part of ComfyUI. Copyright (C) 2024 Comfy This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ from __future__ import annotations import collections import copy import inspect import logging import math import uuid from typing import Callable, Optional, Union import torch import comfy.float import comfy.hooks import comfy.lora import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP def string_to_seed(data): crc = 0xFFFFFFFF for byte in data: if isinstance(byte, str): byte = ord(byte) crc ^= byte for _ in range(8): if crc & 1: crc = (crc >> 1) ^ 0xEDB88320 else: crc >>= 1 return crc ^ 0xFFFFFFFF def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() if "patches_replace" not in to: to["patches_replace"] = {} else: to["patches_replace"] = to["patches_replace"].copy() if name not in to["patches_replace"]: to["patches_replace"][name] = {} else: to["patches_replace"][name] = to["patches_replace"][name].copy() if transformer_index is not None: block = (block_name, number, transformer_index) else: block = (block_name, number) to["patches_replace"][name][block] = patch model_options["transformer_options"] = to return model_options def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False): model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] if disable_cfg1_optimization: model_options["disable_cfg1_optimization"] = True return model_options def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False): model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function] if disable_cfg1_optimization: model_options["disable_cfg1_optimization"] = True return model_options def add_model_options_memory_reserve(model_options, memory_reserve_gb: float): if "model_memory_reserve" not in model_options: model_options["model_memory_reserve"] = [] model_options["model_memory_reserve"].append(memory_reserve_gb) return model_options def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) def create_hook_patches_clone(orig_hook_patches): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] return new_hook_patches def wipe_lowvram_weight(m): if hasattr(m, "prev_comfy_cast_weights"): m.comfy_cast_weights = m.prev_comfy_cast_weights del m.prev_comfy_cast_weights if hasattr(m, "weight_function"): m.weight_function = [] if hasattr(m, "bias_function"): m.bias_function = [] def move_weight_functions(m, device): if device is None: return 0 memory = 0 if hasattr(m, "weight_function"): for f in m.weight_function: if hasattr(f, "move_to"): memory += f.move_to(device=device) if hasattr(m, "bias_function"): for f in m.bias_function: if hasattr(f, "move_to"): memory += f.move_to(device=device) return memory class LowVramPatch: def __init__(self, key, patches): self.key = key self.patches = patches def __call__(self, weight): intermediate_dtype = weight.dtype if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key)) return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) def get_key_weight(model, key): set_func = None convert_func = None op_keys = key.rsplit('.', 1) if len(op_keys) < 2: weight = comfy.utils.get_attr(model, key) else: op = comfy.utils.get_attr(model, op_keys[0]) try: set_func = getattr(op, "set_{}".format(op_keys[1])) except AttributeError: pass try: convert_func = getattr(op, "convert_{}".format(op_keys[1])) except AttributeError: pass weight = getattr(op, op_keys[1]) if convert_func is not None: weight = comfy.utils.get_attr(model, key) return weight, set_func, convert_func class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model self.was_injected = False self.prev_skip_injection = False self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only def __enter__(self): self.was_injected = False self.prev_skip_injection = self.model.skip_injection if self.skip_and_inject_on_exit_only: self.model.skip_injection = True if self.model.is_injected: self.model.eject_model() self.was_injected = True def __exit__(self, *args): if self.skip_and_inject_on_exit_only: self.model.skip_injection = self.prev_skip_injection self.model.inject_model() if self.was_injected and not self.model.skip_injection: self.model.inject_model() self.model.skip_injection = self.prev_skip_injection class MemoryCounter: def __init__(self, initial: int, minimum=0): self.value = initial self.minimum = minimum # TODO: add a safe limit besides 0 def use(self, weight: torch.Tensor): weight_size = weight.nelement() * weight.element_size() if self.is_useable(weight_size): self.decrement(weight_size) return True return False def is_useable(self, used: int): return self.value - used > self.minimum def decrement(self, used: int): self.value -= used class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size self.model = model if not hasattr(self.model, 'device'): logging.debug("Model doesn't have a device attribute.") self.model.device = offload_device elif self.model.device is None: self.model.device = offload_device self.patches = {} self.backup = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks() self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers() self.is_injected = False self.skip_injection = False self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 if not hasattr(self.model, 'lowvram_patch_counter'): self.model.lowvram_patch_counter = 0 if not hasattr(self.model, 'model_lowvram'): self.model.model_lowvram = False if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None def model_size(self): if self.size > 0: return self.size self.size = comfy.model_management.module_size(self.model) return self.size def loaded_size(self): return self.model.model_loaded_weight_memory def lowvram_patch_counter(self): return self.model.lowvram_patch_counter def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self n.force_cast_weights = self.force_cast_weights # attachments n.attachments = {} for k in self.attachments: if hasattr(self.attachments[k], "on_model_patcher_clone"): n.attachments[k] = self.attachments[k].on_model_patcher_clone() else: n.attachments[k] = self.attachments[k] # additional models for k, c in self.additional_models.items(): n.additional_models[k] = [x.clone() for x in c] # callbacks for k, c in self.callbacks.items(): n.callbacks[k] = {} for k1, c1 in c.items(): n.callbacks[k][k1] = c1.copy() # sample wrappers for k, w in self.wrappers.items(): n.wrappers[k] = {} for k1, w1 in w.items(): n.wrappers[k][k1] = w1.copy() # injection n.is_injected = self.is_injected n.skip_injection = self.skip_injection for k, i in self.injections.items(): n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k] n.hook_backup = self.hook_backup n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks n.is_clip = self.is_clip n.hook_mode = self.hook_mode for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False def clone_has_same_weights(self, clone: 'ModelPatcher'): if not self.is_clone(clone): return False if self.current_hooks != clone.current_hooks: return False if self.forced_hooks != clone.forced_hooks: return False if self.hook_patches.keys() != clone.hook_patches.keys(): return False if self.attachments.keys() != clone.attachments.keys(): return False if self.additional_models.keys() != clone.additional_models.keys(): return False for key in self.callbacks: if len(self.callbacks[key]) != len(clone.callbacks[key]): return False for key in self.wrappers: if len(self.wrappers[key]) != len(clone.wrappers[key]): return False if self.injections.keys() != clone.injections.keys(): return False if len(self.patches) == 0 and len(clone.patches) == 0: return True if self.patches_uuid == clone.patches_uuid: if len(self.patches) != len(clone.patches): logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") else: return True def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function if disable_cfg1_optimization: self.model_options["disable_cfg1_optimization"] = True def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization) def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function): self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function def set_model_denoise_mask_function(self, denoise_mask_function): self.model_options["denoise_mask_function"] = denoise_mask_function def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index) def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") def set_model_attn2_output_patch(self, patch): self.set_model_patch(patch, "attn2_output_patch") def set_model_input_block_patch(self, patch): self.set_model_patch(patch, "input_block_patch") def set_model_input_block_patch_after_skip(self, patch): self.set_model_patch(patch, "input_block_patch_after_skip") def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") def set_model_emb_patch(self, patch): self.set_model_patch(patch, "emb_patch") def set_model_forward_timestep_embed_patch(self, patch): self.set_model_patch(patch, "forward_timestep_embed_patch") def add_object_patch(self, name, obj): self.object_patches[name] = obj def set_model_compute_dtype(self, dtype): self.add_object_patch("manual_cast_dtype", dtype) if dtype is not None: self.force_cast_weights = True self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this def add_model_memory_reserve(self, memory_reserve_gb: float): """Adds additional expected memory usage for the model, in gigabytes.""" self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb) def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]: """Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True.""" total_reserve = sum(self.model_options.get("model_memory_reserve", [])) if convert_to_bytes: return total_reserve * 1024 * 1024 * 1024 return total_reserve def add_weight_wrapper(self, name, function): self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function] self.patches_uuid = uuid.uuid4() def get_model_object(self, name: str) -> torch.nn.Module: """Retrieves a nested attribute from an object using dot notation considering object patches. Args: name (str): The attribute path using dot notation (e.g. "model.layer.weight") Returns: The value of the requested attribute Example: patcher = ModelPatcher() weight = patcher.get_model_object("layer1.conv.weight") """ if name in self.object_patches: return self.object_patches[name] else: if name in self.object_patches_backup: return self.object_patches_backup[name] else: return comfy.utils.get_attr(self.model, name) def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): patch_list[i] = patch_list[i].to(device) if "patches_replace" in to: patches = to["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: if hasattr(patch_list[k], "to"): patch_list[k] = patch_list[k].to(device) if "model_function_wrapper" in self.model_options: wrap_func = self.model_options["model_function_wrapper"] if hasattr(wrap_func, "to"): self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): p = set() model_sd = self.model.state_dict() for k in patches: offset = None function = None if isinstance(k, str): key = k else: offset = k[1] key = k[0] if len(k) > 2: function = k[2] if key in model_sd: p.add(k) current_patches = self.patches.get(key, []) current_patches.append((strength_patch, patches[k], strength_model, offset, function)) self.patches[key] = current_patches self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): model_sd = self.model_state_dict() p = {} for k in model_sd: if filter_prefix is not None: if not k.startswith(filter_prefix): continue bk = self.backup.get(k, None) hbk = self.hook_backup.get(k, None) weight, set_func, convert_func = get_key_weight(self.model, k) if bk is not None: weight = bk.weight if hbk is not None: weight = hbk[0] if convert_func is None: convert_func = lambda a, **kwargs: a if k in self.patches: p[k] = [(weight, convert_func)] + self.patches[k] else: p[k] = [(weight, convert_func)] return p def model_state_dict(self, filter_prefix=None): with self.use_ejected(): sd = self.model.state_dict() keys = list(sd.keys()) if filter_prefix is not None: for k in keys: if not k.startswith(filter_prefix): sd.pop(k) return sd def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: return weight, set_func, convert_func = get_key_weight(self.model, key) inplace_update = self.weight_inplace_update or inplace_update if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) def _load_list(self): loading = [] for n, m in self.model.named_modules(): params = [] skip = False for name, param in m.named_parameters(recurse=False): params.append(name) for name, param in m.named_parameters(recurse=True): if name not in params: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): loading.append((comfy.model_management.module_size(m), n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): with self.use_ejected(): self.unpatch_hooks() mem_counter = 0 patch_counter = 0 lowvram_counter = 0 loading = self._load_list() load_completely = [] loading.sort(reverse=True) for x in loading: n = x[1] m = x[2] params = x[3] module_mem = x[0] lowvram_weight = False weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue cast_weight = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] m.bias_function = [] if weight_key in self.patches: if force_patch_weights: self.patch_weight_to_device(weight_key) else: m.weight_function = [LowVramPatch(weight_key, self.patches)] patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: m.bias_function = [LowVramPatch(bias_key, self.patches)] patch_counter += 1 cast_weight = True else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) if full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True if weight_key in self.weight_wrapper_patches: m.weight_function.extend(self.weight_wrapper_patches[weight_key]) if bias_key in self.weight_wrapper_patches: m.bias_function.extend(self.weight_wrapper_patches[bias_key]) mem_counter += move_weight_functions(m, device_to) load_completely.sort(reverse=True) for x in load_completely: n = x[1] m = x[2] params = x[3] if hasattr(m, "comfy_patched_weights"): if m.comfy_patched_weights == True: continue for param in params: self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True for x in load_completely: x[2].to(device_to) if lowvram_counter > 0: logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) mem_counter = self.model_size() self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) self.apply_hooks(self.forced_hooks, force_apply=True) def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): with self.use_ejected(): for k in self.object_patches: old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old if lowvram_model_memory == 0: full_load = True else: full_load = False if load_weights: self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) self.inject_model() return self.model def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() if unpatch_weights: self.unpatch_hooks() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) wipe_lowvram_weight(m) self.model.model_lowvram = False self.model.lowvram_patch_counter = 0 keys = list(self.backup.keys()) for k in keys: bk = self.backup[k] if bk.inplace_update: comfy.utils.copy_to_param(self.model, k, bk.weight) else: comfy.utils.set_attr_param(self.model, k, bk.weight) self.model.current_weight_patches_uuid = None self.backup.clear() if device_to is not None: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): del m.comfy_patched_weights keys = list(self.object_patches_backup.keys()) for k in keys: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 unload_list = self._load_list() unload_list.sort() for unload in unload_list: if memory_to_free < memory_freed: break module_mem = unload[0] n = unload[1] m = unload[2] params = unload[3] lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: key = "{}.{}".format(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: move_weight = False break if not hooks_unpatched: self.unpatch_hooks() hooks_unpatched = True if bk.inplace_update: comfy.utils.copy_to_param(self.model, key, bk.weight) else: comfy.utils.set_attr_param(self.model, key, bk.weight) self.backup.pop(key) weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: m.weight_function.append(LowVramPatch(weight_key, self.patches)) patch_counter += 1 if bias_key in self.patches: m.bias_function.append(LowVramPatch(bias_key, self.patches)) patch_counter += 1 cast_weight = True if cast_weight: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem logging.debug("freed {}".format(n)) self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model used = self.model.model_loaded_weight_memory self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) if unpatch_weights: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: self.apply_hooks(self.forced_hooks, force_apply=True) return 0 if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): full_load = True current_used = self.model.model_loaded_weight_memory try: self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) except Exception as e: self.detach() raise e return self.model.model_loaded_weight_memory - current_used def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) if unpatch_all: self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): callback(self, unpatch_all) return self.model def current_loaded_device(self): return self.model.device def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP): callback(self) def add_callback(self, call_type: str, callback: Callable): self.add_callback_with_key(call_type, None, callback) def add_callback_with_key(self, call_type: str, key: str, callback: Callable): c = self.callbacks.setdefault(call_type, {}).setdefault(key, []) c.append(callback) def remove_callbacks_with_key(self, call_type: str, key: str): c = self.callbacks.get(call_type, {}) if key in c: c.pop(key) def get_callbacks(self, call_type: str, key: str): return self.callbacks.get(call_type, {}).get(key, []) def get_all_callbacks(self, call_type: str): c_list = [] for c in self.callbacks.get(call_type, {}).values(): c_list.extend(c) return c_list def add_wrapper(self, wrapper_type: str, wrapper: Callable): self.add_wrapper_with_key(wrapper_type, None, wrapper) def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) w.append(wrapper) def remove_wrappers_with_key(self, wrapper_type: str, key: str): w = self.wrappers.get(wrapper_type, {}) if key in w: w.pop(key) def get_wrappers(self, wrapper_type: str, key: str): return self.wrappers.get(wrapper_type, {}).get(key, []) def get_all_wrappers(self, wrapper_type: str): w_list = [] for w in self.wrappers.get(wrapper_type, {}).values(): w_list.extend(w) return w_list def set_attachments(self, key: str, attachment): self.attachments[key] = attachment def remove_attachments(self, key: str): if key in self.attachments: self.attachments.pop(key) def get_attachment(self, key: str): return self.attachments.get(key, None) def set_injections(self, key: str, injections: list[PatcherInjection]): self.injections[key] = injections def remove_injections(self, key: str): if key in self.injections: self.injections.pop(key) def get_injections(self, key: str): return self.injections.get(key, None) def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models def remove_additional_models(self, key: str): if key in self.additional_models: self.additional_models.pop(key) def get_additional_models_with_key(self, key: str): return self.additional_models.get(key, []) def get_additional_models(self): all_models = [] for models in self.additional_models.values(): all_models.extend(models) return all_models def get_nested_additional_models(self): def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]): '''Make sure circular references do not cause infinite recursion.''' next_models = [] for model in prev_models: candidates = model.get_additional_models() for c in candidates: if c not in cache_set: next_models.append(c) cache_set.add(c) if len(next_models) == 0: return prev_models return prev_models + _evaluate_sub_additional_models(next_models, cache_set) all_models = self.get_additional_models() models_set = set(all_models) real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) return real_all_models def use_ejected(self, skip_and_inject_on_exit_only=False): return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) def inject_model(self): if self.is_injected or self.skip_injection: return for injections in self.injections.values(): for inj in injections: inj.inject(self) self.is_injected = True if self.is_injected: for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL): callback(self) def eject_model(self): if not self.is_injected: return for injections in self.injections.values(): for inj in injections: inj.eject(self) self.is_injected = False for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL): callback(self) def pre_run(self): if hasattr(self.model, "current_patcher"): self.model.current_patcher = self for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) def prepare_state(self, timestep): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): callback(self, timestep) def restore_hook_patches(self): if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: if current_hook == hook: reset_current_hooks = True break for cached_group in list(self.cached_hook_patches.keys()): if cached_group.contains(hook): self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() if registered is None: registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) else: registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): callback(self, hooks, target_dict, model_options, registered) return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): # NOTE: this mirrors behavior of add_patches func current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {}) p = set() model_sd = self.model.state_dict() for k in patches: offset = None function = None if isinstance(k, str): key = k else: offset = k[1] key = k[0] if len(k) > 2: function = k[2] if key in model_sd: p.add(k) current_patches: list[tuple] = current_hook_patches.get(key, []) current_patches.append((strength_patch, patches[k], strength_model, offset, function)) current_hook_patches[key] = current_patches self.hook_patches[hook.hook_ref] = current_hook_patches # since should care about these patches too to determine if same model, reroll patches_uuid self.patches_uuid = uuid.uuid4() return list(p) def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): # combined_patches will contain weights of all relevant hooks, per key combined_patches = {} if hooks is not None: for hook in hooks.hooks: hook_patches: dict = self.hook_patches.get(hook.hook_ref, {}) for key in hook_patches.keys(): current_patches: list[tuple] = combined_patches.get(key, []) if math.isclose(hook.strength, 1.0): current_patches.extend(hook_patches[key]) else: # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model) for patch in hook_patches[key]: new_patch = list(patch) new_patch[0] *= hook.strength current_patches.append(tuple(new_patch)) combined_patches[key] = current_patches return combined_patches def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected(): if hooks is not None: model_sd_keys = list(self.model_state_dict().keys()) memory_counter = None if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: minimum_counter should have a minimum that conforms to loaded model requirements memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device), minimum=comfy.model_management.minimum_inference_memory()*2) # if have cached weights for hooks, use it cached_weights = self.cached_hook_patches.get(hooks, None) if cached_weights is not None: model_sd_keys_set = set(model_sd_keys) for key in cached_weights: if key not in model_sd_keys: logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") continue self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) model_sd_keys_set.remove(key) self.unpatch_hooks(model_sd_keys_set) else: self.unpatch_hooks() relevant_patches = self.get_combined_hook_patches(hooks=hooks) original_weights = None if len(relevant_patches) > 0: original_weights = self.get_key_patches() for key in relevant_patches: if key not in model_sd_keys: logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}") continue self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, memory_counter=memory_counter) else: self.unpatch_hooks() self.current_hooks = hooks def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): if key not in self.hook_backup: weight: torch.Tensor = comfy.utils.get_attr(self.model, key) target_device = self.offload_device if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: used = memory_counter.use(weight) if used: target_device = weight.device self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1])) def clear_cached_hook_weights(self): self.cached_hook_patches.clear() self.patch_hooks(None) def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): if key not in combined_patches: return weight, set_func, convert_func = get_key_weight(self.model, key) weight: torch.Tensor if key not in self.hook_backup: target_device = self.offload_device if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: used = memory_counter.use(weight) if used: target_device = weight.device self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) # TODO: properly handle LowVramPatch, if it ends up an issue temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights) del original_weights[key] if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device used = memory_counter.use(weight) if used: target_device = weight.device self.cached_hook_patches.setdefault(hooks, {}) self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device) del temp_weight del out_weight del weight def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: with self.use_ejected(): if len(self.hook_backup) == 0: self.current_hooks = None return keys = list(self.hook_backup.keys()) if whitelist_keys_set: for k in keys: if k in whitelist_keys_set: comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) self.hook_backup.pop(k) else: for k in keys: comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) self.hook_backup.clear() self.current_hooks = None def clean_hooks(self): self.unpatch_hooks() self.clear_cached_hook_weights() def __del__(self): self.detach(unpatch_all=False)