Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add some more transformer hooks and move tomesd to comfy_extras.
Tomesd now uses q instead of x to decide which tokens to merge because it seems to give better results.
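
Illustrative sketch (not part of this commit): one way the relocated tomesd helpers can be driven from the attention query. The wrapper name tome_attn1 and the in-scope names (block, a BasicTransformerBlock; n, the normalized hidden states; context_attn1, value_attn1, ratio, original_shape) are assumptions for the example, and tomesd is assumed to be importable from its new comfy_extras location; the node this commit actually adds may differ.

def tome_attn1(block, n, context_attn1, value_attn1, ratio, original_shape):
    # Build the merge/unmerge pair from q (the projected attention query)
    # instead of the raw hidden states x, then wrap attn1 with it, mirroring
    # the tomesd call that this commit removes from the transformer block below.
    q = block.attn1.to_q(n)
    m, u = tomesd.get_functions(q, ratio, original_shape)
    return u(block.attn1(m(n), context=context_attn1, value=value_attn1))
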
@@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
import comfy.ops

from . import tomesd

if model_management.xformers_enabled():
    import xformers
    import xformers.ops
@@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head

    def forward(self, x, context=None, transformer_options={}):
        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, transformer_options={}):
        extra_options = {}
        block = None
        block_index = 0
        if "current_index" in transformer_options:
            extra_options["transformer_index"] = transformer_options["current_index"]
        if "block_index" in transformer_options:
            extra_options["block_index"] = transformer_options["block_index"]
            block_index = transformer_options["block_index"]
            extra_options["block_index"] = block_index
        if "original_shape" in transformer_options:
            extra_options["original_shape"] = transformer_options["original_shape"]
        if "block" in transformer_options:
            block = transformer_options["block"]
            extra_options["block"] = block
        if "patches" in transformer_options:
            transformer_patches = transformer_options["patches"]
        else:
            transformer_patches = {}

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head

        if "patches_replace" in transformer_options:
            transformer_patches_replace = transformer_options["patches_replace"]
        else:
            transformer_patches_replace = {}

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
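
For orientation (an illustration, not part of the diff): the transformer_options dict that _forward inspects above could be assembled roughly like this by the caller; the concrete values below are made up.

transformer_options = {
    "current_index": 3,                  # running index of the transformer layer
    "block": ("input", 1),               # set per UNet block in UNetModel.forward (see the last hunk)
    "block_index": 0,                    # index of this BasicTransformerBlock within its parent transformer
    "original_shape": (2, 4, 64, 64),    # latent shape (B, C, H, W)
    "patches": {                         # additive hooks, each a list of callables
        "attn1_output_patch": [lambda n, extra_options: n],
    },
    "patches_replace": {                 # hooks that replace attn1/attn2 entirely (next hunks)
        "attn1": {},
        "attn2": {},
    },
}
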
@@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        if "tomesd" in transformer_options:
            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
        transformer_block = (block[0], block[1], block_index)
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
@@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

        n = self.attn2(n, context=context_attn2, value=value_attn2)
        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
        block_attn2 = transformer_block
        if block_attn2 not in attn2_replace_patch:
            block_attn2 = block

        if block_attn2 in attn2_replace_patch:
            if value_attn2 is None:
                value_attn2 = context_attn2
            n = self.attn2.to_q(n)
            context_attn2 = self.attn2.to_k(context_attn2)
            value_attn2 = self.attn2.to_v(value_attn2)
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
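
Illustrative sketch (not part of this commit) of the replace hooks added above: an entry under patches_replace["attn1"] receives the already-projected q, k, v tensors plus extra_options and must return the attention result before attn1.to_out is applied. Keys are either (block_type, block_id, block_index) for one specific BasicTransformerBlock or (block_type, block_id) as a fallback. The PyTorch 2.x scaled_dot_product_attention call and the concrete keys are assumptions, and transformer_options refers to the dict sketched earlier.

import torch

def my_attn1(q, k, v, extra_options):
    # q/k/v arrive as [B, tokens, n_heads * dim_head]; split heads, run plain
    # attention, and return the same layout for attn1.to_out to project.
    heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b, n_q, _ = q.shape
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, n_q, heads * dim_head)

transformer_options["patches_replace"]["attn1"] = {
    ("middle", 0, 0): my_attn1,   # only the first block of the middle stage
    ("output", 3): my_attn1,      # any block index inside output block 3
}
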
@@ -830,17 +830,20 @@ class UNetModel(nn.Module):

        h = x.type(self.dtype)
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)
            h = forward_timestep_embed(module, h, emb, context, transformer_options)
            if control is not None and 'input' in control and len(control['input']) > 0:
                ctrl = control['input'].pop()
                if ctrl is not None:
                    h += ctrl
            hs.append(h)
        transformer_options["block"] = ("middle", 0)
        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
        if control is not None and 'middle' in control and len(control['middle']) > 0:
            h += control['middle'].pop()

        for module in self.output_blocks:
        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()
            if control is not None and 'output' in control and len(control['output']) > 0:
                ctrl = control['output'].pop()
@@ -1,144 +0,0 @@
#Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable
import math

def do_nothing(x: torch.Tensor, mode:str=None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)

def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():

        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :] # src
        b_idx = rand_idx[:, :num_dst, :] # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
        src_idx = edge_idx[..., :r, :] # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
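
A toy check of the matching pair above, with made-up shapes and the imports from the top of the file; the matching itself is random, but the output shapes are fixed:

tokens = torch.randn(2, 64 * 64, 320)                # [B, N, C] with N = w * h
merge, unmerge = bipartite_soft_matching_random2d(tokens, w=64, h=64,
                                                  sx=2, sy=2, r=1024)
merged = merge(tokens)                               # 1024 tokens merged: [2, 3072, 320]
restored = unmerge(merged)                           # back to [2, 4096, 320]
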
def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing
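
Note that get_functions only merges when the token count matches the full latent resolution (max_downsample is 1); at deeper UNet levels it hands back identity functions. A small example with made-up shapes:

original_shape = (2, 4, 64, 64)              # latent (B, C, H, W) fed to the UNet
x = torch.randn(2, 64 * 64, 320)             # tokens at full latent resolution
m, u = get_functions(x, 0.5, original_shape)
print(m(x).shape)                            # torch.Size([2, 2048, 320])

x_deep = torch.randn(2, 32 * 32, 640)        # tokens after one downsample
m, u = get_functions(x_deep, 0.5, original_shape)
print(m(x_deep).shape)                       # unchanged: torch.Size([2, 1024, 640])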