diff --git a/comfy/hooks.py b/comfy/hooks.py index 181c4996a..7ca3a8a11 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -44,7 +44,7 @@ class EnumHookType(enum.Enum): Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Wrappers = "wrappers" + TransformerOptions = "transformer_options" Injections = "add_injections" class EnumWeightTarget(enum.Enum): @@ -245,29 +245,39 @@ class AddModelsHook(Hook): if not self.should_register(model, model_options, target_dict, registered): return False -class WrapperHook(Hook): +class TransformerOptionsHook(Hook): ''' - Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = wrappers_dict def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict + c.transformers_dict = self.transformers_dict return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} + add_model_options = {"transformer_options": self.transformers_dict} + # TODO: call .to on patches/anything else in transformer_options that is expected to do something if self.hook_scope == EnumHookScope.AllConditioning: comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +class WrapperHook(TransformerOptionsHook): + ''' + For backwards compatibility, this hook is identical to TransformerOptionsHook. + ''' + pass class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 071535526..2db21bdc4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -945,7 +945,7 @@ class ModelPatcher: registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = []