Add some more transformer hooks and move tomesd to comfy_extras.

Tomesd now uses q instead of x to decide which tokens to merge because
it seems to give better results.
comfyanonymous
2023-06-23 20:17:45 -04:00
parent fa28d7334b
commit 05676942b7
5 changed files with 114 additions and 28 deletions
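For orientation, the call sites added in the diff below imply the following calling conventions for the new hooks. This is a sketch of the interface only; the function names are placeholders and not part of the commit (the attn2 hooks follow the same pattern).

# Calling conventions implied by the call sites in the diff below.
# Names are placeholders; only the signatures follow the code.

def attn1_replace(q, k, v, extra_options):
    # "patches_replace" hook: q, k, v arrive already projected through
    # attn1.to_q / to_k / to_v, and the return value is fed into attn1.to_out.
    ...

def attn1_output_patch(n, extra_options):
    # "attn1_output_patch" hook: runs on the attention output before it is
    # added back onto x; must return the (possibly modified) tensor.
    return n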


@@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
import comfy.ops
from . import tomesd
if model_management.xformers_enabled():
    import xformers
    import xformers.ops
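The in-tree tomesd import goes away here; the merging now lives in comfy_extras and, per the commit message, decides which tokens to merge from q rather than from x. As a rough sketch of how that can hang off the new attn1 hooks (the import path, the ratio value, and the exact wiring in comfy_extras are assumptions, not taken from this diff):

# Sketch only: ToMe-style merging driven from the attn1 patch hooks, using q
# (the hidden states headed into self-attention) to pick merge targets.
# Import path and wiring are assumptions about comfy_extras, not this commit.
from comfy_extras import tomesd  # hypothetical location after the move

unmerge = [None]

def tome_attn1_patch(q, k, v, extra_options):
    m, unmerge[0] = tomesd.get_functions(q, 0.5, extra_options["original_shape"])
    return m(q), k, v  # merge tokens going into self-attention

def tome_attn1_output_patch(n, extra_options):
    return unmerge[0](n)  # unmerge tokens on the way out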
@@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head

    def forward(self, x, context=None, transformer_options={}):
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
        extra_options = {}
        block = None
        block_index = 0
        if "current_index" in transformer_options:
            extra_options["transformer_index"] = transformer_options["current_index"]
        if "block_index" in transformer_options:
            extra_options["block_index"] = transformer_options["block_index"]
            block_index = transformer_options["block_index"]
            extra_options["block_index"] = block_index
        if "original_shape" in transformer_options:
            extra_options["original_shape"] = transformer_options["original_shape"]
        if "block" in transformer_options:
            block = transformer_options["block"]
            extra_options["block"] = block
        if "patches" in transformer_options:
            transformer_patches = transformer_options["patches"]
        else:
            transformer_patches = {}

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head

        if "patches_replace" in transformer_options:
            transformer_patches_replace = transformer_options["patches_replace"]
        else:
            transformer_patches_replace = {}

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
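Taken together, the setup above means every patch now receives an extra_options dict along these lines. The keys come from the code; the concrete values below are illustrative, not from the commit.

# Illustrative contents of extra_options after the setup above.
extra_options = {
    "transformer_index": 3,            # transformer_options["current_index"]
    "block_index": 1,                  # index of this BasicTransformerBlock in its stack
    "original_shape": [2, 4, 64, 64],  # latent shape before flattening into tokens
    "block": ("input", 4),             # which UNet block hosts this transformer (illustrative)
    "n_heads": 8,                      # new in this commit
    "dim_head": 64,                    # new in this commit
}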
@@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        if "tomesd" in transformer_options:
            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))

        transformer_block = (block[0], block[1], block_index)
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
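A replace patch compatible with the call above receives the already-projected q, k, v plus extra_options and must return a tensor that attn1.to_out can consume. A minimal sketch using PyTorch 2.x scaled_dot_product_attention and the new head-count entries (illustrative, not part of the commit):

import torch

def sdpa_attn1_replace(q, k, v, extra_options):
    # q, k, v: (batch, tokens, n_heads * dim_head), already projected.
    # The returned tensor is passed straight into attn1.to_out.
    heads = extra_options["n_heads"]
    b = q.shape[0]
    q, k, v = (t.view(b, -1, heads, t.shape[-1] // heads).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # needs PyTorch >= 2.0
    return out.transpose(1, 2).reshape(b, -1, heads * out.shape[-1])

Such a function would sit under transformer_options["patches_replace"]["attn1"], keyed either by a (block name, block number) pair or by the triple that also carries block_index, matching the lookup in the code above.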
@@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

        n = self.attn2(n, context=context_attn2, value=value_attn2)

        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
        block_attn2 = transformer_block
        if block_attn2 not in attn2_replace_patch:
            block_attn2 = block

        if block_attn2 in attn2_replace_patch:
            if value_attn2 is None:
                value_attn2 = context_attn2
            n = self.attn2.to_q(n)
            context_attn2 = self.attn2.to_k(context_attn2)
            value_attn2 = self.attn2.to_v(value_attn2)
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]