Add a way to set patches that modify the attn2 output.

Change the transformer patches function format to be more future proof.
This commit is contained in:
comfyanonymous
2023-06-18 22:58:22 -04:00
parent cd930d4e7f
commit 8883cb0f67
4 changed files with 19 additions and 7 deletions

View File

@@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
def __init__(self, hypernet, strength):
self.hypernet = hypernet
self.strength = strength
def __call__(self, current_index, q, k, v):
def __call__(self, q, k, v, extra_options):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]