This node is only useful if someone trains the kontext model to properly use multiple reference images via the index method. The default is the offset method, which feeds the multiple images as if they were stitched together into one image; that method works with the current flux kontext model.
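As a rough illustration of the difference, below is a minimal sketch that mirrors the ref_latents loop in Flux.forward() further down: the index method gives every reference image its own value on the first positional-id axis, while the offset method keeps index 1 for all of them and shifts their row/column ids so they sit next to (or below) the main image. The helper name and the plain (height, width) tuples are made up for illustration and are not part of the model file.

# Minimal sketch (not part of the model file): mirrors the ref_latents loop
# in Flux.forward(). ref_shapes are plain (height, width) tuples in latent
# pixels; the function name is illustrative only.
def reference_placement(ref_shapes, method="offset"):
    placements = []
    h = 0
    w = 0
    index = 0
    for rh, rw in ref_shapes:
        if method == "index":
            # index method: each reference gets its own id on the first axis
            index += 1
            h_offset = 0
            w_offset = 0
        else:
            # offset method: all references share index 1 and are shifted in
            # row/column id space, as if stitched onto the main image
            index = 1
            h_offset = 0
            w_offset = 0
            if rh + h > rw + w:
                w_offset = w
            else:
                h_offset = h
            h = max(h, rh + h_offset)
            w = max(w, rw + w_offset)
        placements.append((index, h_offset, w_offset))
    return placements

# Two 64x64 reference latents:
#   offset method -> [(1, 0, 0), (1, 64, 0)]  (second ref placed below the first)
#   index  method -> [(1, 0, 0), (2, 0, 0)]   (each ref on its own index)
print(reference_placement([(64, 64), (64, 64)], method="offset"))
print(reference_placement([(64, 64), (64, 64)], method="index"))
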
#Original code can be found on: https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit

from .layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    patch_size: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

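    # Note: in_channels/out_channels below are per-token sizes, i.e. the latent
    # channel count multiplied by patch_size**2, and the per-head dim
    # (hidden_size // num_heads) must equal sum(axes_dim) used by EmbedND.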
    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels * params.patch_size * params.patch_size
        self.out_channels = params.out_channels * params.patch_size * params.patch_size
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

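    # forward_orig runs the core transformer pass: project img/txt tokens into
    # hidden_size, fold timestep (and optional guidance) embeddings into vec,
    # run the double-stream blocks on separate img/txt streams, concatenate the
    # two streams, run the single-stream blocks, then project back to patch
    # space with the final layer.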
    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control = None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:

        if y is None:
            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)

        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
        txt = self.txt_in(txt)

        if img_ids is not None:
            ids = torch.cat((txt_ids, img_ids), dim=1)
            pe = self.pe_embedder(ids)
        else:
            pe = None

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

            if img.dtype == torch.float16:
                img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)

        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1] :, ...] += add

        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
        return img

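    # process_img patchifies a latent into (c * patch_size**2)-dim tokens and
    # builds its 3-axis position ids: axis 0 holds the reference index, axes 1
    # and 2 hold the patch row/column, shifted by h_offset/w_offset so that
    # reference latents can be placed alongside the main image in RoPE space.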
    def process_img(self, x, index=0, h_offset=0, w_offset=0):
        bs, c, h, w = x.shape
        patch_size = self.patch_size
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)

        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
        w_offset = ((w_offset + (patch_size // 2)) // patch_size)

        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 0] = img_ids[:, :, 1] + index
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

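    # forward patchifies x, appends any reference latents as extra tokens with
    # their own position ids (see the offset/index handling below), calls
    # forward_orig, then keeps only the first img_tokens tokens and unpatchifies
    # back to the original resolution.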
    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
        bs, c, h_orig, w_orig = x.shape
        patch_size = self.patch_size

        h_len = ((h_orig + (patch_size // 2)) // patch_size)
        w_len = ((w_orig + (patch_size // 2)) // patch_size)
        img, img_ids = self.process_img(x)
        img_tokens = img.shape[1]
        if ref_latents is not None:
            h = 0
            w = 0
            index = 0
            index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
            for ref in ref_latents:
                if index_ref_method:
                    index += 1
                    h_offset = 0
                    w_offset = 0
                else:
                    index = 1
                    h_offset = 0
                    w_offset = 0
                    if ref.shape[-2] + h > ref.shape[-1] + w:
                        w_offset = w
                    else:
                        h_offset = h
                    h = max(h, ref.shape[-2] + h_offset)
                    w = max(w, ref.shape[-1] + w_offset)

                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                img = torch.cat([img, kontext], dim=1)
                img_ids = torch.cat([img_ids, kontext_ids], dim=1)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        out = out[:, :img_tokens]
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]