mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Made HunyuanVideo work with optimized_attention_override
This commit is contained in:
@@ -78,13 +78,13 @@ class TokenRefinerBlock(nn.Module):
|
|||||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, c, mask):
|
def forward(self, x, c, mask, transformer_options={}):
|
||||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
|
||||||
norm_x = self.norm1(x)
|
norm_x = self.norm1(x)
|
||||||
qkv = self.self_attn.qkv(norm_x)
|
qkv = self.self_attn.qkv(norm_x)
|
||||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||||
|
|
||||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||||
@@ -115,14 +115,14 @@ class IndividualTokenRefiner(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, c, mask):
|
def forward(self, x, c, mask, transformer_options={}):
|
||||||
m = None
|
m = None
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||||
m = m + m.transpose(2, 3)
|
m = m + m.transpose(2, 3)
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, c, m)
|
x = block(x, c, m, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -150,6 +150,7 @@ class TokenRefiner(nn.Module):
|
|||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
mask,
|
mask,
|
||||||
|
transformer_options={},
|
||||||
):
|
):
|
||||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||||
# m = mask.float().unsqueeze(-1)
|
# m = mask.float().unsqueeze(-1)
|
||||||
@@ -158,7 +159,7 @@ class TokenRefiner(nn.Module):
|
|||||||
|
|
||||||
c = t + self.c_embedder(c.to(x.dtype))
|
c = t + self.c_embedder(c.to(x.dtype))
|
||||||
x = self.input_embedder(x)
|
x = self.input_embedder(x)
|
||||||
x = self.individual_token_refiner(x, c, mask)
|
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class HunyuanVideo(nn.Module):
|
class HunyuanVideo(nn.Module):
|
||||||
@@ -267,7 +268,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||||
|
|
||||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
@@ -285,14 +286,14 @@ class HunyuanVideo(nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@@ -307,13 +308,13 @@ class HunyuanVideo(nn.Module):
|
|||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
|
Reference in New Issue
Block a user