Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-08-04 20:23:19 -07:00
commit b4f559b34d
16 changed files with 613 additions and 35 deletions

View File

@ -130,10 +130,21 @@ class ModelFileManager:
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

View File

@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}

View File

@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
import logging
class CONDRegular:
@ -10,12 +11,15 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
@ -29,14 +33,14 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@ -92,10 +99,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)

View File

@ -29,6 +29,7 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@ -44,7 +45,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@ -300,12 +300,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):

View File

@ -0,0 +1,399 @@
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
inner_dim=None,
bias: bool = True,
dtype=None, device=None, operations=None
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim,
dtype=dtype,
device=device,
operations=operations
)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
return timesteps_emb
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
dim_head: int = 64,
heads: int = 8,
dropout: float = 0.0,
bias: bool = False,
eps: float = 1e-5,
out_bias: bool = True,
out_dim: int = None,
out_context_dim: int = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim
self.heads = heads
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.dropout = dropout
# Q/K normalization
self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
# Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout)
])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
img_attn_output = self.to_out[0](img_attn_output)
img_attn_output = self.to_out[1](img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class LastLayer(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine=False,
eps=1e-6,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
num_layers: int = 60,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
dtype=dtype,
device=device,
operations=operations
)
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_layers)
])
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
orig_shape = x.shape
hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)

View File

@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.model_management
import comfy.patcher_extension
@ -106,10 +107,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = comfy.model_management.cast_to_device(extra, device, dtype)
else:
extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
@ -160,7 +163,7 @@ class BaseModel(torch.nn.Module):
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
@ -169,20 +172,21 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype)
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
@ -398,7 +402,7 @@ class SD21UNCLIP(BaseModel):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
return torch.zeros((1, self.adm_channels), device=device)
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@ -612,9 +616,11 @@ class IP2P:
if image is None:
image = torch.zeros_like(noise)
else:
image = image.to(device=device)
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
@ -693,7 +699,7 @@ class StableCascade_B(BaseModel):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
@ -1158,10 +1164,10 @@ class WAN21_Vace(WAN21):
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)
vace_frames = torch.stack(vace_frames_out, dim=1)
@ -1303,3 +1309,14 @@ class Omnigen2(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

View File

@ -92,7 +92,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)

View File

@ -47,6 +47,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.model_patcher
import comfy.lora
@ -771,6 +772,7 @@ class CLIPType(Enum):
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -791,6 +793,7 @@ class TEModel(Enum):
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -812,7 +815,11 @@ def detect_te_model(sd):
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:

View File

@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
from . import supported_models_base
from . import latent_formats
@ -1229,7 +1230,36 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
sampling_settings = {
"multiplier": 1.0,
"shift": 2.6,
}
memory_usage_factor = 1.8 #TODO
unet_extra_config = {}
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.QwenImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -43,6 +43,23 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
hidden_size: int = 3584
intermediate_size: int = 18944
num_hidden_layers: int = 28
num_attention_heads: int = 28
num_key_value_heads: int = 4
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
vocab_size: int = 256000
@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -0,0 +1,71 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
import torch
import numbers
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask") # attention mask is useless if no masked elements
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.47"
__version__ = "0.3.48"

View File

@ -925,7 +925,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.47"
version = "0.3.48"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.45
comfyui-workflow-templates==0.1.47
comfyui-embedded-docs==0.2.4
torch
torchsde