mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Add PixArt model support (#6055)
* PixArt initial version * PixArt Diffusers convert logic * pos_emb and interpolation logic * Reduce duplicate code * Formatting * Use optimized attention * Edit empty token logic * Basic PixArt LoRA support * Fix aspect ratio logic * PixArtAlpha text encode with conds * Use same detection key logic for PixArt diffusers
This commit is contained in:
382
comfy/ldm/pixart/blocks.py
Normal file
382
comfy/ldm/pixart/blocks.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers.ops
|
||||
if int((xformers.__version__).split(".")[2]) >= 28:
|
||||
block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
|
||||
else:
|
||||
block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
|
||||
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
|
||||
super(MultiHeadCrossAttention, self).__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_model // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, cond, mask=None):
|
||||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
# TODO: xformers needs separate mask logic here
|
||||
if model_management.xformers_enabled():
|
||||
attn_bias = None
|
||||
if mask is not None:
|
||||
attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
|
||||
else:
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
attn_mask = None
|
||||
if mask is not None and len(mask) > 1:
|
||||
# Create equivalent of xformer diagonal block mask, still only correct for square masks
|
||||
# But depth doesn't matter as tensors can expand in that dimension
|
||||
attn_mask_template = torch.ones(
|
||||
[q.shape[2] // B, mask[0]],
|
||||
dtype=torch.bool,
|
||||
device=q.device
|
||||
)
|
||||
attn_mask = torch.block_diag(attn_mask_template)
|
||||
|
||||
# create a mask on the diagonal for each mask in the batch
|
||||
for _ in range(B - 1):
|
||||
attn_mask = torch.block_diag(attn_mask, attn_mask_template)
|
||||
|
||||
x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
|
||||
|
||||
x = x.view(B, -1, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionKVCompress(nn.Module):
|
||||
"""Multi-head Attention block with KV token compression and qk norm."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool: If True, add a learnable bias to query, key, value.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1 and sampling == 'conv':
|
||||
# Avg Conv Init.
|
||||
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
|
||||
# self.sr.weight.data.fill_(1/sr_ratio**2)
|
||||
# self.sr.bias.data.zero_()
|
||||
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = nn.Identity()
|
||||
self.k_norm = nn.Identity()
|
||||
|
||||
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
|
||||
if sampling is None or scale_factor == 1:
|
||||
return tensor
|
||||
B, N, C = tensor.shape
|
||||
|
||||
if sampling == 'uniform_every':
|
||||
return tensor[:, ::scale_factor], int(N // scale_factor)
|
||||
|
||||
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
||||
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
|
||||
new_N = new_H * new_W
|
||||
|
||||
if sampling == 'ave':
|
||||
tensor = F.interpolate(
|
||||
tensor, scale_factor=1 / scale_factor, mode='nearest'
|
||||
).permute(0, 2, 3, 1)
|
||||
elif sampling == 'uniform':
|
||||
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
|
||||
elif sampling == 'conv':
|
||||
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
|
||||
tensor = self.norm(tensor)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return tensor.reshape(B, new_N, C).contiguous(), new_N
|
||||
|
||||
def forward(self, x, mask=None, HW=None, block_id=None):
|
||||
B, N, C = x.shape # 2 4096 1152
|
||||
new_N = N
|
||||
if HW is None:
|
||||
H = W = int(N ** 0.5)
|
||||
else:
|
||||
H, W = HW
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
|
||||
q, k, v = qkv.unbind(2)
|
||||
dtype = q.dtype
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# KV compression
|
||||
if self.sr_ratio > 1:
|
||||
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
|
||||
|
||||
if mask is not None:
|
||||
raise NotImplementedError("Attn mask logic not added for self attention")
|
||||
|
||||
# This is never called at the moment
|
||||
# attn_bias = None
|
||||
# if mask is not None:
|
||||
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
|
||||
|
||||
# attention 2
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
|
||||
|
||||
x = x.view(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x, t):
|
||||
dtype = x.dtype
|
||||
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x.to(dtype))
|
||||
return x
|
||||
|
||||
|
||||
class MaskFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_decoder(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.outdim = hidden_size
|
||||
|
||||
def forward(self, s, bs):
|
||||
if s.ndim == 1:
|
||||
s = s[:, None]
|
||||
assert s.ndim == 2
|
||||
if s.shape[0] != bs:
|
||||
s = s.repeat(bs//s.shape[0], 1)
|
||||
assert s.shape[0] == bs
|
||||
b, dims = s.shape[0], s.shape[1]
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = timestep_embedding(s, self.frequency_embedding_size)
|
||||
s_emb = self.mlp(s_freq.to(s.dtype))
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
use_cfg_embedding = dropout_prob > 0
|
||||
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.y_proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
assert caption.shape[2:] == self.y_embedding.shape
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class CaptionEmbedderDoubleBr(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
|
||||
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, global_caption, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return global_caption, caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
assert caption.shape[2: ] == self.y_embedding.shape
|
||||
global_caption = caption.mean(dim=2).squeeze()
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
|
||||
y_embed = self.proj(global_caption)
|
||||
return y_embed, caption
|
201
comfy/ldm/pixart/pixart.py
Normal file
201
comfy/ldm/pixart/pixart.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import PatchEmbed, TimestepEmbedder, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
class PixArtBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
|
||||
self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
self.sampling = sampling
|
||||
self.sr_ratio = sr_ratio
|
||||
|
||||
def forward(self, x, y, t, mask=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArt(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=1.0,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.depth = depth
|
||||
|
||||
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
num_patches = self.x_embedder.num_patches
|
||||
self.base_size = input_size // self.patch_size
|
||||
# Will use fixed sin-cos embedding:
|
||||
self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length
|
||||
)
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
self.kv_compress_config = kv_compress_config
|
||||
if kv_compress_config is None:
|
||||
self.kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
input_size=(input_size // patch_size, input_size // patch_size),
|
||||
sampling=self.kv_compress_config['sampling'],
|
||||
sr_ratio=int(
|
||||
self.kv_compress_config['scale_factor']
|
||||
) if i in self.kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
|
||||
def forward_raw(self, x, t, y, mask=None, data_info=None):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) tensor of class labels
|
||||
"""
|
||||
x = x.to(self.dtype)
|
||||
timestep = t.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
pos_embed = self.pos_embed.to(self.dtype)
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, 1, L, D)
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens) # (N, T, D)
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, y=None, **kwargs):
|
||||
"""
|
||||
Forward pass that adapts comfy input to original forward function
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
timesteps: (N,) tensor of diffusion timesteps
|
||||
context: (N, 1, 120, C) conditioning
|
||||
y: extra conditioning.
|
||||
"""
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_raw(
|
||||
x = x.to(self.dtype),
|
||||
t = timesteps.to(self.dtype),
|
||||
y = context.to(self.dtype),
|
||||
)
|
||||
|
||||
## only return EPS
|
||||
out = out.to(torch.float)
|
||||
eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
|
||||
return eps
|
||||
|
||||
def unpatchify(self, x):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
246
comfy/ldm/pixart/pixartms.py
Normal file
246
comfy/ldm/pixart/pixartms.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp
|
||||
from .pixart import PixArt, get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(PixArt):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
Reference in New Issue
Block a user