Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-13 04:55:53 +00:00)
Implement Linear hypernetworks.
Add a HypernetworkLoader node to use hypernetworks.
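For orientation: a "Linear" hypernetwork here is a set of small residual stacks of nn.Linear layers that rewrite the conditioning tensor before it reaches the to_k/to_v projections of the cross-attention layers. A minimal sketch of such a module follows; the class name, layer widths and multiplier are illustrative assumptions, not the exact structure added by this commit.

import torch.nn as nn

class HypernetworkModule(nn.Module):
    # Illustrative sketch only: one such module per conditioning width
    # (one for the key path, one for the value path).
    def __init__(self, dim, multiplier=1.0):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim * 2)
        self.linear2 = nn.Linear(dim * 2, dim)
        self.multiplier = multiplier

    def forward(self, context):
        # Residual form: if the learned delta is zero the module is a no-op,
        # so attaching an untrained hypernetwork degrades gracefully.
        return context + self.multiplier * self.linear2(self.linear1(context))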
@@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
-        value = self.to_v(context)
+        if value is not None:
+            value = self.to_v(value)
+        else:
+            value = self.to_v(context)
         del context, x
 
         query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
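The same edit is repeated for every attention backend below: when a separate value tensor is passed, to_v projects it instead of context, so keys and values can be computed from two differently transformed conditioning tensors; omitting value keeps the old behaviour. A toy call pattern (hypernet_k and hypernet_v are placeholder callables, not names from this commit):

# unchanged callers: keys and values both come from `context`
out = attn(x, context=cond)

# new option: feed the value projection a separately transformed tensor,
# e.g. the same conditioning run through a second hypernetwork module
out = attn(x, context=hypernet_k(cond), value=hypernet_v(cond))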
@@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q_in = self.to_q(x)
         context = default(context, x)
         k_in = self.to_k(context)
-        v_in = self.to_v(context)
+        if value is not None:
+            v_in = self.to_v(value)
+            del value
+        else:
+            v_in = self.to_v(context)
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
@@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(
@@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(
@@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
             transformer_patches = {}
 
         n = self.norm1(x)
+        if self.disable_self_attn:
+            context_attn1 = context
+        else:
+            context_attn1 = None
+        value_attn1 = None
+
+        if "attn1_patch" in transformer_patches:
+            patch = transformer_patches["attn1_patch"]
+            if context_attn1 is None:
+                context_attn1 = n
+            value_attn1 = context_attn1
+            for p in patch:
+                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
-            n = self.attn1(n, context=context if self.disable_self_attn else None)
+            n = self.attn1(n, context=context_attn1, value=value_attn1)
 
         x += n
         if "middle_patch" in transformer_patches:
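The patch callbacks consumed above follow the convention p(current_index, n, context, value) -> (n, context, value). A hedged sketch of what a hypernetwork-style callback could look like under that convention (the dict-of-modules lookup and the make_hypernetwork_patch name are assumptions for illustration):

def make_hypernetwork_patch(hypernet_modules):
    # hypernet_modules: {context_width: (k_module, v_module)}, e.g. pairs of
    # the HypernetworkModule sketched earlier (illustrative structure only).
    def patch(current_index, n, context, value):
        dim = context.shape[-1]
        if dim in hypernet_modules:
            k_mod, v_mod = hypernet_modules[dim]
            # keys will be projected from the patched context, values from the
            # separately patched tensor (see the `value=` argument threaded
            # through the attention classes above)
            context = k_mod(context)
            value = v_mod(value)
        return n, context, value
    return patch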
@@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
                 x = p(current_index, x)
 
         n = self.norm2(x)
-        n = self.attn2(n, context=context)
+
+        context_attn2 = context
+        value_attn2 = None
+        if "attn2_patch" in transformer_patches:
+            patch = transformer_patches["attn2_patch"]
+            value_attn2 = context_attn2
+            for p in patch:
+                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+
+        n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         x += n
         x = self.ff(self.norm3(x)) + x
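How a HypernetworkLoader node ends up populating transformer_patches is not part of this diff; only the "attn1_patch"/"attn2_patch" keys and the callback signature are visible here. A hedged sketch of the wiring, reusing the callback sketched above (the surrounding plumbing is an assumption):

patch = make_hypernetwork_patch(hypernet_modules)
transformer_patches = {
    # self-attention: context/value default to the normed hidden states n
    "attn1_patch": [patch],
    # cross-attention: context/value come from the text conditioning
    "attn2_patch": [patch],
}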