mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-09 07:37:14 +00:00
Refactor the attention functions.
There's no reason for the whole CrossAttention object to be repeated when only the operation in the middle changes.
This commit is contained in:
parent
8cc75c64ff
commit
1a4bd9e9a6
@ -94,95 +94,41 @@ def zero_module(module):
|
|||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def attention_basic(q, k, v, heads, mask=None):
|
||||||
|
h = heads
|
||||||
|
scale = (q.shape[-1] // heads) ** -0.5
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
class SpatialSelfAttention(nn.Module):
|
# force cast to fp32 to avoid overflowing
|
||||||
def __init__(self, in_channels):
|
if _ATTN_PRECISION =="fp32":
|
||||||
super().__init__()
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
self.in_channels = in_channels
|
q, k = q.float(), k.float()
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
self.norm = Normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.k = torch.nn.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.v = torch.nn.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b,c,h,w = q.shape
|
|
||||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
|
||||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
|
||||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
|
||||||
|
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
|
||||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
|
||||||
w_ = rearrange(w_, 'b i j -> b j i')
|
|
||||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
|
||||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x+h_
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionBirchSan(nn.Module):
|
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
|
||||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
|
||||||
nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
query = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
key = self.to_k(context)
|
|
||||||
if value is not None:
|
|
||||||
value = self.to_v(value)
|
|
||||||
else:
|
else:
|
||||||
value = self.to_v(context)
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
|
||||||
del context, x
|
del q, k
|
||||||
|
|
||||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
if exists(mask):
|
||||||
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def attention_sub_quad(query, key, value, heads, mask=None):
|
||||||
|
scale = (query.shape[-1] // heads) ** -0.5
|
||||||
|
query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
|
||||||
del key
|
del key
|
||||||
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
|
||||||
dtype = query.dtype
|
dtype = query.dtype
|
||||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
||||||
@ -230,54 +176,19 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
query_chunk_size=query_chunk_size,
|
query_chunk_size=query_chunk_size,
|
||||||
kv_chunk_size=kv_chunk_size,
|
kv_chunk_size=kv_chunk_size,
|
||||||
kv_chunk_size_min=kv_chunk_size_min,
|
kv_chunk_size_min=kv_chunk_size_min,
|
||||||
use_checkpoint=self.training,
|
use_checkpoint=False,
|
||||||
upcast_attention=upcast_attention,
|
upcast_attention=upcast_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.to(dtype)
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
|
||||||
hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
|
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||||
|
|
||||||
out_proj, dropout = self.to_out
|
|
||||||
hidden_states = out_proj(hidden_states)
|
|
||||||
hidden_states = dropout(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def attention_split(q, k, v, heads, mask=None):
|
||||||
class CrossAttentionDoggettx(nn.Module):
|
scale = (q.shape[-1] // heads) ** -0.5
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
h = heads
|
||||||
super().__init__()
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
|
||||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
|
||||||
nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k_in = self.to_k(context)
|
|
||||||
if value is not None:
|
|
||||||
v_in = self.to_v(value)
|
|
||||||
del value
|
|
||||||
else:
|
|
||||||
v_in = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
|
||||||
del q_in, k_in, v_in
|
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
@ -310,9 +221,9 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
if _ATTN_PRECISION =="fp32":
|
if _ATTN_PRECISION =="fp32":
|
||||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||||
else:
|
else:
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||||
first_op_done = True
|
first_op_done = True
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||||
@ -339,143 +250,37 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
del r1
|
del r1
|
||||||
|
return r2
|
||||||
|
|
||||||
return self.to_out(r2)
|
def attention_xformers(q, k, v, heads, mask=None):
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
|
||||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
|
||||||
nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
if value is not None:
|
|
||||||
v = self.to_v(value)
|
|
||||||
del value
|
|
||||||
else:
|
|
||||||
v = self.to_v(context)
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
|
||||||
|
|
||||||
# force cast to fp32 to avoid overflowing
|
|
||||||
if _ATTN_PRECISION =="fp32":
|
|
||||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
|
||||||
q, k = q.float(), k.float()
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
|
||||||
else:
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
|
||||||
|
|
||||||
del q, k
|
|
||||||
|
|
||||||
if exists(mask):
|
|
||||||
mask = rearrange(mask, 'b ... -> b (...)')
|
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
|
||||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
|
||||||
sim.masked_fill_(~mask, max_neg_value)
|
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
|
||||||
sim = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
out = einsum('b i j, b j d -> b i d', sim, v)
|
|
||||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
class MemoryEfficientCrossAttention(nn.Module):
|
|
||||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.heads = heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
|
||||||
self.attention_op: Optional[Any] = None
|
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
if value is not None:
|
|
||||||
v = self.to_v(value)
|
|
||||||
del value
|
|
||||||
else:
|
|
||||||
v = self.to_v(context)
|
|
||||||
|
|
||||||
b, _, _ = q.shape
|
b, _, _ = q.shape
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.unsqueeze(3)
|
lambda t: t.unsqueeze(3)
|
||||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
.reshape(b, t.shape[1], heads, -1)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
.reshape(b * heads, t.shape[1], -1)
|
||||||
.contiguous(),
|
.contiguous(),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
# actually compute the attention, what we cannot get enough of
|
# actually compute the attention, what we cannot get enough of
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
out = (
|
out = (
|
||||||
out.unsqueeze(0)
|
out.unsqueeze(0)
|
||||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
.reshape(b, heads, out.shape[1], -1)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
.reshape(b, out.shape[1], -1)
|
||||||
)
|
)
|
||||||
return self.to_out(out)
|
return out
|
||||||
|
|
||||||
class CrossAttentionPytorch(nn.Module):
|
def attention_pytorch(q, k, v, heads, mask=None):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
b, _, dim_head = q.shape
|
||||||
super().__init__()
|
dim_head //= heads
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.heads = heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
|
|
||||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
|
||||||
self.attention_op: Optional[Any] = None
|
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
if value is not None:
|
|
||||||
v = self.to_v(value)
|
|
||||||
del value
|
|
||||||
else:
|
|
||||||
v = self.to_v(context)
|
|
||||||
|
|
||||||
b, _, _ = q.shape
|
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
|
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -484,24 +289,53 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
if exists(mask):
|
if exists(mask):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
return self.to_out(out)
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
print("Using xformers cross attention")
|
print("Using xformers cross attention")
|
||||||
CrossAttention = MemoryEfficientCrossAttention
|
optimized_attention = attention_xformers
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
print("Using pytorch cross attention")
|
print("Using pytorch cross attention")
|
||||||
CrossAttention = CrossAttentionPytorch
|
optimized_attention = attention_pytorch
|
||||||
else:
|
else:
|
||||||
if args.use_split_cross_attention:
|
if args.use_split_cross_attention:
|
||||||
print("Using split optimization for cross attention")
|
print("Using split optimization for cross attention")
|
||||||
CrossAttention = CrossAttentionDoggettx
|
optimized_attention = attention_split
|
||||||
else:
|
else:
|
||||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
CrossAttention = CrossAttentionBirchSan
|
optimized_attention = attention_sub_quad
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = default(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
|
||||||
|
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||||
|
|
||||||
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
if value is not None:
|
||||||
|
v = self.to_v(value)
|
||||||
|
del value
|
||||||
|
else:
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
out = optimized_attention(q, k, v, self.heads, mask)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
@ -6,7 +6,6 @@ import numpy as np
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
from ..attention import MemoryEfficientCrossAttention
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
@ -352,15 +351,6 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x+out
|
return x+out
|
||||||
|
|
||||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
|
||||||
def forward(self, x, context=None, mask=None):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
|
||||||
out = super().forward(x, context=context, mask=mask)
|
|
||||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
|
|
||||||
return x + out
|
|
||||||
|
|
||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
||||||
@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
|||||||
return MemoryEfficientAttnBlock(in_channels)
|
return MemoryEfficientAttnBlock(in_channels)
|
||||||
elif attn_type == "vanilla-pytorch":
|
elif attn_type == "vanilla-pytorch":
|
||||||
return MemoryEfficientAttnBlockPytorch(in_channels)
|
return MemoryEfficientAttnBlockPytorch(in_channels)
|
||||||
elif type == "memory-efficient-cross-attn":
|
|
||||||
attn_kwargs["query_dim"] = in_channels
|
|
||||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
|
||||||
elif attn_type == "none":
|
elif attn_type == "none":
|
||||||
return nn.Identity(in_channels)
|
return nn.Identity(in_channels)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user