mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Remove AttentionOverrideTest node, that's something to cook up for later
This commit is contained in:
@@ -1058,51 +1058,6 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
|||||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||||
return io.NodeOutput(out_latent)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
import comfy.patcher_extension
|
|
||||||
import comfy.ldm.modules.attention
|
|
||||||
import logging
|
|
||||||
|
|
||||||
class AttentionOverrideTest(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
attention_function_names = list(comfy.ldm.modules.attention.REGISTERED_ATTENTION_FUNCTIONS.keys())
|
|
||||||
return io.Schema(
|
|
||||||
node_id="AttentionOverrideTest",
|
|
||||||
category="devtools",
|
|
||||||
inputs=[
|
|
||||||
io.Model.Input("model"),
|
|
||||||
io.Combo.Input("attention", options=attention_function_names),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Model.Output(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def attention_override_factory(attention_func):
|
|
||||||
def attention_override(func, *args, **kwargs):
|
|
||||||
return attention_func(*args, **kwargs)
|
|
||||||
return attention_override
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sampler_sampler_wrapper(executor, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
# extra_args = args[2]
|
|
||||||
return executor(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, model: io.Model.Type, attention: str) -> io.NodeOutput:
|
|
||||||
attention_func = comfy.ldm.modules.attention.get_attention_function(attention, None)
|
|
||||||
if attention_func is None:
|
|
||||||
logging.info(f"Attention type '{attention}' not found, using default optimized attention for your hardware.")
|
|
||||||
return model
|
|
||||||
|
|
||||||
model = model.clone()
|
|
||||||
model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override_factory(attention_func)
|
|
||||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper)
|
|
||||||
return io.NodeOutput(model)
|
|
||||||
|
|
||||||
class WanExtension(ComfyExtension):
|
class WanExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@@ -1121,7 +1076,6 @@ class WanExtension(ComfyExtension):
|
|||||||
WanSoundImageToVideo,
|
WanSoundImageToVideo,
|
||||||
WanSoundImageToVideoExtend,
|
WanSoundImageToVideoExtend,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
AttentionOverrideTest,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanExtension:
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
Reference in New Issue
Block a user