From 8270ff312f7aefc4d29aeeed667296b2a56628ce Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 21:07:02 -0600 Subject: [PATCH] Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time --- comfy/hooks.py | 34 +++++++++++++++--------- comfy/model_patcher.py | 15 +++++------ comfy/sampler_helpers.py | 48 +++++++++++++++++++++++++-------- comfy/samplers.py | 57 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 3ead8c963..25d67b86c 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -65,7 +65,7 @@ class _HookRef: pass -def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): '''Example for how should_register function should look like.''' return True @@ -114,10 +114,10 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): @@ -154,7 +154,7 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False weights = None @@ -178,7 +178,7 @@ class WeightHook(Hook): else: weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) + registered.add(self) return True # TODO: add logs about any keys that were not applied @@ -212,11 +212,12 @@ class AddModelsHook(Hook): Note, value of hook_scope is ignored and is treated as AllConditioning. ''' - def __init__(self, key: str=None, models: list[ModelPatcher]=None): + def __init__(self, models: list[ModelPatcher]=None, key: str=None): super().__init__(hook_type=EnumHookType.AddModels) - self.key = key self.models = models + self.key = key self.append_when_same = True + '''Curently does nothing.''' def clone(self, subtype: Callable=None): if subtype is None: @@ -227,9 +228,10 @@ class AddModelsHook(Hook): c.append_when_same = self.append_when_same return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False + registered.add(self) return True class TransformerOptionsHook(Hook): @@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook): c.transformers_dict = self.transformers_dict return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.transformers_dict} - # TODO: call .to on patches/anything else in transformer_options that is expected to do something + # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks if self.hook_scope == EnumHookScope.AllConditioning: - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) + add_model_options = {"transformer_options": self.transformers_dict, + "to_load_options": self.transformers_dict} + else: + add_model_options = {"to_load_options": self.transformers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -295,6 +300,9 @@ class HookGroup: self.hooks: list[Hook] = [] self._hook_dict: dict[EnumHookType, list[Hook]] = {} + def __len__(self): + return len(self.hooks) + def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0430430e5..2a5510873 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,11 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, + registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle TransformerOptionsHooks, if model_options provided - if model_options is not None: - for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + if registered is None: + registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): @@ -956,9 +954,10 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks, target_dict) + callback(self, hooks, target_dict, model_options, registered) + return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index abd44cf6e..cb9388519 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -70,13 +70,11 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -87,14 +85,20 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [] - for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - x: comfy.hooks.AddModelsHook - hook_models.extend(x.models) - models = control_models + gligen + add_models + hook_models + models = control_models + gligen + add_models return models, inference_memory +def get_additional_models_from_model_options(model_options: dict[str]=None): + """loads additional models from registered AddModels hooks""" + models = [] + if model_options is not None and "registered_hooks" in model_options: + registered: comfy.hooks.HookGroup = model_options["registered_hooks"] + for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + models.extend(hook.models) + return models + def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: @@ -102,9 +106,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): - real_model: 'BaseModel' = None +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory @@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # add wrappers and callbacks from ModelPatcher to transformer_options model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) + # begin registering hooks + registered = comfy.hooks.HookGroup() + target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) + # handle all TransformerOptionsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): + hook: comfy.hooks.TransformerOptionsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all AddModelsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all WeightHooks by registering on ModelPatcher + model.register_all_hook_patches(hooks, target_dict, model_options, registered) + # add registered_hooks onto model_options for further reference + if len(registered) > 0: + model_options["registered_hooks"] = registered + # merge original wrappers and callbacks with hooked wrappers and callbacks + to_load_options: dict[str] = model_options.setdefault("to_load_options", {}) + for wc_name in ["wrappers", "callbacks"]: + comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], + copy_dict1=False) + return to_load_options + diff --git a/comfy/samplers.py b/comfy/samplers.py index af2b8e110..8f8345abc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): return len(hooks_set) +def cast_to_load_options(model_options: dict[str], device=None, dtype=None): + ''' + If any patches from hooks, wrappers, or callbacks have .to to be called, call it. + ''' + if model_options is None: + return + to_load_options = model_options.get("to_load_options", None) + if to_load_options is None: + return + + casts = [] + if device is not None: + casts.append(device) + if dtype is not None: + casts.append(dtype) + # if nothing to apply, do nothing + if len(casts) == 0: + return + + # Try to call .to on patches + if "patches" in to_load_options: + patches = to_load_options["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + for cast in casts: + patch_list[i] = patch_list[i].to(cast) + if "patches_replace" in to_load_options: + patches = to_load_options["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + for cast in casts: + patch_list[k] = patch_list[k].to(cast) + # Try to call .to on any wrappers/callbacks + wrappers_and_callbacks = ["wrappers", "callbacks"] + for wc_name in wrappers_and_callbacks: + if wc_name in to_load_options: + wc: dict[str, list] = to_load_options[wc_name] + for wc_dict in wc.values(): + for wc_list in wc_dict.values(): + for i in range(len(wc_list)): + if hasattr(wc_list[i], "to"): + for cast in casts: + wc_list[i] = wc_list[i].to(cast) + + class CFGGuider: - def __init__(self, model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + def __init__(self, model_patcher: ModelPatcher): + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -861,7 +910,7 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device if denoise_mask is not None: @@ -870,6 +919,7 @@ class CFGGuider: noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) try: self.model_patcher.pre_run() @@ -906,6 +956,7 @@ class CFGGuider: ) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: + cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches()