diff --git a/comfy/samplers.py b/comfy/samplers.py index acf86c03d..81847dfa6 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1034,7 +1034,7 @@ class CFGGuider: self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + # comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ comfy.ldm.modules.attention.LOG_CONTENTS = {} output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: @@ -1042,8 +1042,9 @@ class CFGGuider: self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() + if comfy.ldm.modules.attention.LOG_ATTN_CALLS: + comfy.ldm.modules.attention.save_log_contents() comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ - comfy.ldm.modules.attention.save_log_contents() comfy.ldm.modules.attention.LOG_CONTENTS = {} del self.conds diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 4fedbba7d..6f4f7f676 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1005,14 +1005,18 @@ class Wan22ImageToVideoLatent(io.ComfyNode): import comfy.patcher_extension import comfy.ldm.modules.attention +import logging + class AttentionOverrideTest(io.ComfyNode): @classmethod def define_schema(cls): + attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys()) return io.Schema( node_id="AttentionOverrideTest", category="devtools", inputs=[ io.Model.Input("model"), + io.Combo.Input("attention", options=attention_function_names), ], outputs=[ io.Model.Output(), @@ -1020,9 +1024,10 @@ class AttentionOverrideTest(io.ComfyNode): ) @staticmethod - def attention_override(func, transformer_options, *args, **kwargs): - new_attention = comfy.ldm.modules.attention.attention_basic - return new_attention.__wrapped__(*args, **kwargs) + def attention_override_factory(attention_func): + def attention_override(func, *args, **kwargs): + return attention_func(*args, **kwargs) + return attention_override @staticmethod def sampler_sampler_wrapper(executor, *args, **kwargs): @@ -1033,10 +1038,14 @@ class AttentionOverrideTest(io.ComfyNode): pass @classmethod - def execute(cls, model: io.Model.Type) -> io.NodeOutput: - model = model.clone() + def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput: + attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None) + if attention_func is None: + logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.") + return model - model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override + model = model.clone() + model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper) return io.NodeOutput(model)