Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only)

This commit is contained in:
Jedrzej Kosinski
2025-08-28 18:54:22 -07:00
parent 51a30c2ad7
commit 1f499f0794
2 changed files with 18 additions and 8 deletions

View File

@@ -1034,7 +1034,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$
# comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$
comfy.ldm.modules.attention.LOG_CONTENTS = {}
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
@@ -1042,8 +1042,9 @@ class CFGGuider:
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
if comfy.ldm.modules.attention.LOG_ATTN_CALLS:
comfy.ldm.modules.attention.save_log_contents()
comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$
comfy.ldm.modules.attention.save_log_contents()
comfy.ldm.modules.attention.LOG_CONTENTS = {}
del self.conds

View File

@@ -1005,14 +1005,18 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
import comfy.patcher_extension
import comfy.ldm.modules.attention
import logging
class AttentionOverrideTest(io.ComfyNode):
@classmethod
def define_schema(cls):
attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys())
return io.Schema(
node_id="AttentionOverrideTest",
category="devtools",
inputs=[
io.Model.Input("model"),
io.Combo.Input("attention", options=attention_function_names),
],
outputs=[
io.Model.Output(),
@@ -1020,9 +1024,10 @@ class AttentionOverrideTest(io.ComfyNode):
)
@staticmethod
def attention_override(func, transformer_options, *args, **kwargs):
new_attention = comfy.ldm.modules.attention.attention_basic
return new_attention.__wrapped__(*args, **kwargs)
def attention_override_factory(attention_func):
def attention_override(func, *args, **kwargs):
return attention_func(*args, **kwargs)
return attention_override
@staticmethod
def sampler_sampler_wrapper(executor, *args, **kwargs):
@@ -1033,10 +1038,14 @@ class AttentionOverrideTest(io.ComfyNode):
pass
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
model = model.clone()
def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput:
attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None)
if attention_func is None:
logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.")
return model
model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override
model = model.clone()
model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper)
return io.NodeOutput(model)