Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-13 21:16:09 +00:00)
Some fixes/cleanups to pixart code.
Commented out the masking related code because it is never used in this implementation.
@@ -46,32 +46,33 @@ class MultiHeadCrossAttention(nn.Module):
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         k, v = kv.unbind(2)
 
-        # TODO: xformers needs separate mask logic here
-        if model_management.xformers_enabled():
-            attn_bias = None
-            if mask is not None:
-                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
-        else:
-            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
-            attn_mask = None
-            if mask is not None and len(mask) > 1:
-                # Create equivalent of xformer diagonal block mask, still only correct for square masks
-                # But depth doesn't matter as tensors can expand in that dimension
-                attn_mask_template = torch.ones(
-                    [q.shape[2] // B, mask[0]],
-                    dtype=torch.bool,
-                    device=q.device
-                )
-                attn_mask = torch.block_diag(attn_mask_template)
-
-                # create a mask on the diagonal for each mask in the batch
-                for _ in range(B - 1):
-                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
-
-            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
-
-        x = x.view(B, -1, C)
+        assert mask is None  # TODO?
+        # # TODO: xformers needs separate mask logic here
+        # if model_management.xformers_enabled():
+        #     attn_bias = None
+        #     if mask is not None:
+        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
+        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # else:
+        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        #     attn_mask = None
+        #     mask = torch.ones(())
+        #     if mask is not None and len(mask) > 1:
+        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
+        #         # But depth doesn't matter as tensors can expand in that dimension
+        #         attn_mask_template = torch.ones(
+        #             [q.shape[2] // B, mask[0]],
+        #             dtype=torch.bool,
+        #             device=q.device
+        #         )
+        #         attn_mask = torch.block_diag(attn_mask_template)
+        #
+        #         # create a mask on the diagonal for each mask in the batch
+        #         for _ in range(B - 1):
+        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
+
+        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
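For context, a minimal standalone sketch (not part of the commit) of what the now commented-out branch was doing: it used torch.block_diag to tile one all-True block per batch element along the diagonal, so the queries of each image could only attend to that image's own conditioning tokens. The batch size and token counts below are illustrative values, not taken from the model.

import torch

B = 3              # illustrative batch size
img_tokens = 4     # query tokens per image (q.shape[2] // B in the original code)
cond_tokens = 5    # conditioning tokens per image (mask[0] in the original code)

# One all-True block per sample...
attn_mask_template = torch.ones([img_tokens, cond_tokens], dtype=torch.bool)

# ...stacked along the diagonal, one block per batch element.
attn_mask = torch.block_diag(attn_mask_template)
for _ in range(B - 1):
    attn_mask = torch.block_diag(attn_mask, attn_mask_template)

print(attn_mask.shape)  # torch.Size([12, 15]); block i is True only for sample i's cond tokens

The new path skips all of this: mask is asserted to be None and the flattened (B, -1, C) tensors go straight through optimized_attention with mask=None.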
@@ -155,9 +156,9 @@ class AttentionKVCompress(nn.Module):
             k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
             v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
 
-        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
-        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
-        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
 
         if mask is not None:
             raise NotImplementedError("Attn mask logic not added for self attention")
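On the dropped .to(dtype) calls, a small standalone sketch (assuming q, k and v are already in the working dtype at this point, which is my reading of the cleanup rather than something stated in the commit): reshape only splits the channel dimension into heads and never changes dtype, and .to() with a matching dtype returns the tensor unchanged, so the cast was a no-op in that case. Shapes below are illustrative.

import torch

B, N, C, num_heads = 2, 16, 64, 8
q = torch.randn(B, N, C, dtype=torch.float16)

q_heads = q.reshape(B, N, num_heads, C // num_heads)
print(q_heads.shape)                          # torch.Size([2, 16, 8, 8])
print(q_heads.dtype)                          # torch.float16 -- reshape never changes dtype
assert q_heads.to(torch.float16) is q_heads   # .to() with the same dtype/device returns self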
@@ -209,9 +210,9 @@ class T2IFinalLayer(nn.Module):
 
     def forward(self, x, t):
         dtype = x.dtype
-        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x.to(dtype))
+        x = self.linear(x)
         return x
 
 
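For the T2IFinalLayer change, a standalone sketch of the dtype behaviour (sizes are illustrative, and t2i_modulate is reproduced here for illustration as the PixArt-style x * (1 + scale) + shift): the fp32 scale_shift_table used to promote shift/scale, and therefore the modulated x, to fp32, which is presumably why the extra x.to(dtype) was needed before the final linear; casting the table to x's dtype and device up front keeps everything in x's dtype.

import torch

def t2i_modulate(x, shift, scale):
    # PixArt-style modulation, as used by T2IFinalLayer
    return x * (1 + scale) + shift

B, N, hidden = 2, 16, 32                           # illustrative sizes
x = torch.randn(B, N, hidden, dtype=torch.float16)
t = torch.randn(B, hidden, dtype=torch.float16)
scale_shift_table = torch.randn(2, hidden)         # fp32, like the nn.Parameter in the module

# Old behaviour: fp32 table + fp16 t promotes shift/scale (and the modulated x) to fp32,
# hence the x.to(dtype) before self.linear.
shift32, scale32 = (scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
print(t2i_modulate(x, shift32, scale32).dtype)     # torch.float32

# New behaviour: cast the table to x's dtype/device first, so everything stays in fp16
# and the extra cast before the linear can be dropped.
shift, scale = (scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
print(t2i_modulate(x, shift, scale).dtype)         # torch.float16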