Remove windows line endings. (#8866)

comfyanonymous 2025-07-10 23:37:51 -07:00 committed by GitHub
parent 8f05fb48ea
commit 938d3e8216
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 298 additions and 298 deletions


@@ -1,256 +1,256 @@
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn

from .blocks import (
    t2i_modulate,
    CaptionEmbedder,
    AttentionKVCompress,
    MultiHeadCrossAttention,
    T2IFinalLayer,
    SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch


def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
    grid_h, grid_w = torch.meshgrid(
        torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
        torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
        indexing='ij'
    )
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
    return emb


class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(
            hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
            dtype=dtype, device=device, operations=operations
        )
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        x = x + self.cross_attn(x, y, mask)
        x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


### Core PixArt Model ###
class PixArtMS(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        nn.Module.__init__(self)
        self.dtype = dtype
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.hidden_size = hidden_size
        self.depth = depth

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.t_block = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
        )
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations,
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
            dtype=dtype, device=device, operations=operations,
        )

        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
            self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)

        # For fixed sin-cos embedding:
        # num_patches = (input_size // patch_size) * (input_size // patch_size)
        # self.base_size = input_size // self.patch_size
        # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                sampling=kv_compress_config['sampling'],
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(
            hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) conditioning
        ar: (N, 1): aspect ratio
        cs: (N, 2) size conditioning for height/width
        """
        B, C, H, W = x.shape
        c_res = (H + W) // 2
        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly
            pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)

        pos_embed = get_2d_sincos_pos_embed_torch(
            self.hidden_size,
            h=(H // self.patch_size),
            w=(W // self.patch_size),
            pe_interpolation=pe_interpolation,
            base_size=((round(c_res / 64) * 64) // self.patch_size),
            device=x.device,
            dtype=x.dtype,
        ).unsqueeze(0)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep, x.dtype)  # (N, D)

        if self.micro_conditioning and (c_size is not None and c_ar is not None):
            bs = x.shape[0]
            c_size = self.csize_embedder(c_size, bs)  # (N, D)
            c_ar = self.ar_embedder(c_ar, bs)  # (N, D)
            t = t + torch.cat([c_size, c_ar], dim=1)

        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = None
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        for block in self.blocks:
            x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, H, W)  # (N, out_channels, H, W)

        return x

    def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
        B, C, H, W = x.shape

        # Fallback for missing microconds
        if self.micro_conditioning:
            if c_size is None:
                c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)

            if c_ar is None:
                c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)

        ## only return EPS
        if self.pred_sigma:
            return out[:, :self.in_channels]
        return out

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = h // self.patch_size
        w = w // self.patch_size
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
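
Note (not part of the changed file): a minimal self-contained sketch of two shape conventions used above, the adaLN-Zero split in PixArtMSBlock.forward and the token-to-latent regrouping in PixArtMS.unpatchify. All sizes are arbitrary example values, not ComfyUI defaults.

import torch

# (1) adaLN-Zero split: the learned (6, hidden) scale_shift_table plus the
#     projected timestep embedding is chunked into shift/scale/gate terms for
#     the self-attention and MLP branches.
N, hidden = 2, 16
table = torch.randn(6, hidden) / hidden ** 0.5
t6 = torch.randn(N, 6 * hidden)            # shape of the t_block output: (N, 6*hidden)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
    (table[None] + t6.reshape(N, 6, -1)).chunk(6, dim=1)
print(shift_msa.shape)                     # torch.Size([2, 1, 16])

# (2) unpatchify bookkeeping: (N, T, p*p*C) tokens from the final layer are
#     regrouped into an (N, C, H, W) latent via the same reshape/einsum.
C, H, W, p = 8, 32, 48, 2
h, w = H // p, W // p
tokens = torch.randn(N, h * w, p * p * C)
x = tokens.reshape(N, h, w, p, p, C)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, C, h * p, w * p)
print(imgs.shape)                          # torch.Size([2, 8, 32, 48]) -> (N, out_channels, H, W)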


@@ -1,42 +1,42 @@
import os

from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens

from transformers import T5TokenizerFast


class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
        # PixArt expects the negative to be all pad tokens
        special_tokens = special_tokens.copy()
        special_tokens.pop("end")
        return gen_empty_tokens(special_tokens, *args, **kwargs)


class PixArtT5XXL(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)


class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)  # no padding


class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return PixArtTEModel_
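
Note (not part of the changed file): a small standalone sketch of the pattern pixart_te() relies on, with no ComfyUI imports. The factory closes over loader-time settings and returns a class whose constructor applies them as defaults without mutating the caller's option dict. All names below are illustrative, not ComfyUI APIs.

class TextEncoder:                                   # stand-in for PixArtT5XXL
    def __init__(self, device="cpu", dtype=None, model_options={}):
        self.device, self.dtype, self.model_options = device, dtype, model_options


def make_te(dtype_default=None, scaled_fp8=None):    # analogous to pixart_te()
    class TEModel_(TextEncoder):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = model_options.copy() # do not mutate the caller's dict
                model_options["scaled_fp8"] = scaled_fp8
            if dtype is None:
                dtype = dtype_default                # fall back to the baked-in dtype
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return TEModel_


TE = make_te(dtype_default="fp16", scaled_fp8=True)
te = TE()                                            # instantiated later, like the class pixart_te() returns
print(te.dtype, te.model_options)                    # fp16 {'scaled_fp8': True}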