Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-12 04:27:21 +00:00)
Make flux work with optimized_attention_override
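This commit threads a transformer_options dict through the whole flux forward path: the Flux model forward passes it into every DoubleStreamBlock and SingleStreamBlock call (including the blocks_replace wrappers), the blocks forward it to the flux attention() helper, and the helper hands it to optimized_attention. That is what lets an attention override carried in transformer_options reach the attention calls inside the flux blocks. Below is a minimal sketch of the dispatch idea only; it is not the actual comfy.ldm.modules.attention implementation, and the "optimized_attention_override" key plus the override signature are assumptions taken from the commit title.

    import torch

    def optimized_attention(q, k, v, heads, mask=None, skip_reshape=False, transformer_options={}):
        # Sketch only (assumed key name): if the caller registered an override in
        # transformer_options, delegate the whole attention call to it.
        override = transformer_options.get("optimized_attention_override")
        if override is not None:
            return override(q, k, v, heads, mask=mask, skip_reshape=skip_reshape)
        # Illustrative fallback; the real function dispatches to the selected
        # backend (PyTorch SDPA, xformers, flash attention, ...).
        if not skip_reshape:
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            q, k, v = (t.reshape(t.shape[0], t.shape[1], heads, -1).transpose(1, 2) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2)  # (batch, seq, heads, dim_head)
        return out.reshape(out.shape[0], out.shape[1], heads * out.shape[-1])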
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
             attn = attention(torch.cat((img_q, txt_q), dim=2),
                              torch.cat((img_k, txt_k), dim=2),
                              torch.cat((img_v, txt_v), dim=2),
-                             pe=pe, mask=attn_mask)
+                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
 
             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
             attn = attention(torch.cat((txt_q, img_q), dim=2),
                              torch.cat((txt_k, img_k), dim=2),
                              torch.cat((txt_v, img_v), dim=2),
-                             pe=pe, mask=attn_mask)
+                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
 
             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
         mod, _ = self.modulation(vec)
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)

diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
 
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
     q_shape = q.shape
     k_shape = k.shape
 
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
 
     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
     return x
 
 
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
@@ -135,14 +135,16 @@ class Flux(nn.Module):
                                                    txt=args["txt"],
                                                    vec=args["vec"],
                                                    pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"))
+                                                   attn_mask=args.get("attn_mask"),
+                                                   transformer_options=args.get("transformer_options"))
                     return out
 
                 out = blocks_replace[("double_block", i)]({"img": img,
                                                            "txt": txt,
                                                            "vec": vec,
                                                            "pe": pe,
-                                                           "attn_mask": attn_mask},
+                                                           "attn_mask": attn_mask,
+                                                           "transformer_options": transformer_options},
                                                           {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
@@ -151,7 +153,8 @@ class Flux(nn.Module):
                                  txt=txt,
                                  vec=vec,
                                  pe=pe,
-                                 attn_mask=attn_mask)
+                                 attn_mask=attn_mask,
+                                 transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -172,17 +175,19 @@ class Flux(nn.Module):
                     out["img"] = block(args["img"],
                                        vec=args["vec"],
                                        pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"))
+                                       attn_mask=args.get("attn_mask"),
+                                       transformer_options=args.get("transformer_options"))
                     return out
 
                 out = blocks_replace[("single_block", i)]({"img": img,
                                                            "vec": vec,
                                                            "pe": pe,
-                                                           "attn_mask": attn_mask},
+                                                           "attn_mask": attn_mask,
+                                                           "transformer_options": transformer_options},
                                                           {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
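For context, one plausible way the plumbing above would be used; the registration path is not part of this diff, so model_options["transformer_options"] as the entry point and the override's exact signature are assumptions, not the confirmed API.

    import torch

    # Hypothetical override with the same call shape the flux math helper uses:
    # optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=...)
    def my_attention(q, k, v, heads, mask=None, skip_reshape=True, **kwargs):
        # With skip_reshape=True, q/k/v arrive as (batch, heads, seq, dim_head).
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Return the usual (batch, seq, heads * dim_head) layout.
        return out.transpose(1, 2).reshape(q.shape[0], -1, heads * q.shape[-1])

    # Assumed wiring: the sampler copies model_options["transformer_options"] into the
    # transformer_options dict that Flux.forward now forwards to every block.
    model_options = {"transformer_options": {"optimized_attention_override": my_attention}}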