From 66c4eb006bcc068b202891151174ecb8d6d0bf57 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 30 Aug 2025 15:19:36 -0700 Subject: [PATCH] Remove AttentionOverrideTest node, that's something to cook up for later --- comfy_extras/nodes_wan.py | 46 --------------------------------------- 1 file changed, 46 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 66806e34d..4f73369f5 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1058,51 +1058,6 @@ class Wan22ImageToVideoLatent(io.ComfyNode): out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return io.NodeOutput(out_latent) -import comfy.patcher_extension -import comfy.ldm.modules.attention -import logging - -class AttentionOverrideTest(io.ComfyNode): - @classmethod - def define_schema(cls): - attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys()) - return io.Schema( - node_id="AttentionOverrideTest", - category="devtools", - inputs=[ - io.Model.Input("model"), - io.Combo.Input("attention", options=attention_function_names), - ], - outputs=[ - io.Model.Output(), - ], - ) - - @staticmethod - def attention_override_factory(attention_func): - def attention_override(func, *args, **kwargs): - return attention_func(*args, **kwargs) - return attention_override - - @staticmethod - def sampler_sampler_wrapper(executor, *args, **kwargs): - try: - # extra_args = args[2] - return executor(*args, **kwargs) - finally: - pass - - @classmethod - def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput: - attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None) - if attention_func is None: - logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.") - return model - - model = model.clone() - model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func) - model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper) - return io.NodeOutput(model) class WanExtension(ComfyExtension): @override @@ -1121,7 +1076,6 @@ class WanExtension(ComfyExtension): WanSoundImageToVideo, WanSoundImageToVideoExtend, Wan22ImageToVideoLatent, - AttentionOverrideTest, ] async def comfy_entrypoint() -> WanExtension: