From bff60b5cfc10d1b037a95746226ac6698dc3e373 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Aug 2025 20:03:22 -0400 Subject: [PATCH 01/12] ComfyUI version 0.3.48 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 20a2e892..7b29e338 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.47" +__version__ = "0.3.48" diff --git a/pyproject.toml b/pyproject.toml index 244fdd23..256677fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.47" +version = "0.3.48" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 3dfefc88d00bde744b729b073058a57e149cddc1 Mon Sep 17 00:00:00 2001 From: Johnpaul Chiwetelu <49923152+Myestery@users.noreply.github.com> Date: Sat, 2 Aug 2025 03:02:06 +0100 Subject: [PATCH 02/12] API for Recently Used Items (#8792) * feat: add file creation time to model file metadata and user file info * fix linting --- app/model_manager.py | 21 ++++++++++++++++----- app/user_manager.py | 4 +++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/app/model_manager.py b/app/model_manager.py index 74d942fb..ab36bca7 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -130,10 +130,21 @@ class ModelFileManager: for file_name in filenames: try: - relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) - result.append(relative_path) - except: - logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") + full_path = os.path.join(dirpath, file_name) + relative_path = os.path.relpath(full_path, directory) + + # Get file metadata + file_info = { + "name": relative_path, + "pathIndex": pathIndex, + "modified": os.path.getmtime(full_path), # Add modification time + "created": os.path.getctime(full_path), # Add creation time + "size": os.path.getsize(full_path) # Add file size + } + result.append(file_info) + + except Exception as e: + logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") continue for d in subdirs: @@ -144,7 +155,7 @@ class ModelFileManager: logging.warning(f"Warning: Unable to access {path}. 
Skipping this path.") continue - return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() + return result, dirs, time.perf_counter() def get_model_previews(self, filepath: str) -> list[str | BytesIO]: dirname = os.path.dirname(filepath) diff --git a/app/user_manager.py b/app/user_manager.py index d31da5b9..0ec3e46e 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -20,13 +20,15 @@ class FileInfo(TypedDict): path: str size: int modified: int + created: int def get_file_info(path: str, relative_to: str) -> FileInfo: return { "path": os.path.relpath(path, relative_to).replace(os.sep, '/'), "size": os.path.getsize(path), - "modified": os.path.getmtime(path) + "modified": os.path.getmtime(path), + "created": os.path.getctime(path) } From fbcc23945dc377c8623bbee6132f15a93ac0c84a Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sun, 3 Aug 2025 02:15:29 +0800 Subject: [PATCH 03/12] Update template to 0.1.47 (#9153) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3828c5b9..ffa7dce6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.45 +comfyui-workflow-templates==0.1.47 comfyui-embedded-docs==0.2.4 torch torchsde From 5f582a97572e87ebfa655d379e8c8f7611c0249f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 Aug 2025 12:00:13 -0700 Subject: [PATCH 04/12] Make sure all the conds are on the right device. (#9151) --- comfy/model_base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b797894..3ff8106d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -106,10 +106,12 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) -def convert_tensor(extra, dtype): +def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = extra.to(dtype=dtype, device=device) + else: + extra = extra.to(device=device) return extra @@ -174,15 +176,16 @@ class BaseModel(torch.nn.Module): context = context.to(dtype) extra_conds = {} + device = xc.device for o in kwargs: extra = kwargs[o] if hasattr(extra, "dtype"): - extra = convert_tensor(extra, dtype) + extra = convert_tensor(extra, dtype, device) elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(convert_tensor(ext, dtype)) + ex.append(convert_tensor(ext, dtype, device)) extra = ex extra_conds[o] = extra From 13aaa66ec21c397240a9b972d818430b39112588 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 2 Aug 2025 12:09:23 -0700 Subject: [PATCH 05/12] Make sure context is on the right device. 
(#9154) --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3ff8106d..4556ee13 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -171,12 +171,12 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype xc = xc.to(dtype) + device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype) + context = context.to(dtype=dtype, device=device) extra_conds = {} - device = xc.device for o in kwargs: extra = kwargs[o] From aebac221937b511d46fe601656acdc753435b849 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 3 Aug 2025 04:08:11 -0700 Subject: [PATCH 06/12] Cleanup. (#9160) --- comfy/controlnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9a47b86f..6ed8bd75 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -43,7 +43,6 @@ if TYPE_CHECKING: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From 182f90b5eca2baa25474223759039925b286d562 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:11:53 -0700 Subject: [PATCH 07/12] Lower cond vram use by casting at the same time as device transfer. (#9159) --- comfy/conds.py | 14 +++++++------- comfy/model_base.py | 6 +++--- comfy/samplers.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 2af2a43a..f2564e7e 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -10,8 +10,8 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) - def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + def process_cond(self, batch_size, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size)) def can_concat(self, other): if self.cond.shape != other.cond.shape: @@ -29,14 +29,14 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): - def process_cond(self, batch_size, device, area, **kwargs): + def process_cond(self, batch_size, area, **kwargs): data = self.cond if area is not None: dims = len(area) // 2 for i in range(dims): data = data.narrow(i + 2, area[i + dims], area[i]) - return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size)) class CONDCrossAttn(CONDRegular): @@ -73,7 +73,7 @@ class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): return self._copy_with(self.cond) def can_concat(self, other): @@ -92,10 +92,10 @@ class CONDList(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): out = [] for c in self.cond: - out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + out.append(comfy.utils.repeat_to_batch_size(c, batch_size)) return self._copy_with(out) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4556ee13..3a9c031e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -109,9 +109,9 @@ def 
model_sampling(model_config, model_type): def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype=dtype, device=device) + extra = comfy.model_management.cast_to_device(extra, device, dtype) else: - extra = extra.to(device=device) + extra = comfy.model_management.cast_to_device(extra, device, None) return extra @@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module): device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype=dtype, device=device) + context = comfy.model_management.cast_to_device(context, device, dtype) extra_conds = {} for o in kwargs: diff --git a/comfy/samplers.py b/comfy/samplers.py index e93d2a31..ad2f40cd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in): conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area) hooks = conds.get('hooks', None) control = conds.get('control', None) From 140ffc7fdc53e810030f060e421c1f528c2d2ab9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:28:12 -0700 Subject: [PATCH 08/12] Fix broken controlnet from last PR. (#9167) --- comfy/controlnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6ed8bd75..988acdb5 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -28,6 +28,7 @@ import comfy.model_detection import comfy.model_patcher import comfy.ops import comfy.latent_formats +import comfy.model_base import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -264,12 +265,12 @@ class ControlNet(ControlBase): for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: - extra[c] = temp.to(dtype) + extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra) return self.control_merge(control, control_prev, output_dtype=None) def copy(self): From 7991341e89cab521441641505ac4b0eea292a829 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:02:40 -0700 Subject: [PATCH 09/12] Various fixes for broken things from earlier PR. 
(#9168) --- comfy/model_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3a9c031e..f9591f29 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -401,7 +401,7 @@ class SD21UNCLIP(BaseModel): unclip_conditioning = kwargs.get("unclip_conditioning", None) device = kwargs["device"] if unclip_conditioning is None: - return torch.zeros((1, self.adm_channels)) + return torch.zeros((1, self.adm_channels), device=device) else: return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"] + return args["pooled_output"].to(device=args["device"]) class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): @@ -615,9 +615,11 @@ class IP2P: if image is None: image = torch.zeros_like(noise) + else: + image = image.to(device=device) if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) @@ -696,7 +698,7 @@ class StableCascade_B(BaseModel): #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) - out["effnet"] = comfy.conds.CONDRegular(prior) + out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device)) out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out @@ -1161,10 +1163,10 @@ class WAN21_Vace(WAN21): vace_frames_out = [] for j in range(len(vace_frames)): - vf = vace_frames[j].clone() + vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True) for i in range(0, vf.shape[1], 16): vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) - vf = torch.cat([vf, mask[j]], dim=1) + vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1) vace_frames_out.append(vf) vace_frames = torch.stack(vace_frames_out, dim=1) From 84f9759424ccbd8de710960c79f0f1d28eef2776 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:20:12 -0700 Subject: [PATCH 10/12] Add some warnings and prevent crash when cond devices don't match. 
(#9169) --- comfy/conds.py | 7 +++++++ comfy/model_base.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy/conds.py b/comfy/conds.py index f2564e7e..5af3e93e 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +import logging class CONDRegular: @@ -16,6 +17,9 @@ class CONDRegular: def can_concat(self, other): if self.cond.shape != other.cond.shape: return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device, skipping concat.") + return False return True def concat(self, others): @@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular): diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device: skipping concat.") + return False return True def concat(self, others): diff --git a/comfy/model_base.py b/comfy/model_base.py index f9591f29..2db81e24 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"].to(device=args["device"]) + return args["pooled_output"] class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): From 03895dea7c4a6cc93fa362cd11ca450217d74b13 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:33:04 -0700 Subject: [PATCH 11/12] Fix another issue with the PR. (#9170) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2db81e24..a0668643 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -162,7 +162,7 @@ class BaseModel(torch.nn.Module): xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([xc] + [c_concat], dim=1) + xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn dtype = self.get_dtype() From c012400240d4867cd63a45220eb791b91ad47617 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 19:53:25 -0700 Subject: [PATCH 12/12] Initial support for qwen image model. 
(#9179) --- comfy/ldm/qwen_image/model.py | 399 ++++++++++++++++++++++++++++++ comfy/model_base.py | 12 + comfy/model_detection.py | 7 +- comfy/sd.py | 12 +- comfy/supported_models.py | 32 ++- comfy/text_encoders/llama.py | 26 ++ comfy/text_encoders/qwen_image.py | 71 ++++++ nodes.py | 2 +- 8 files changed, 557 insertions(+), 4 deletions(-) create mode 100644 comfy/ldm/qwen_image/model.py create mode 100644 comfy/text_encoders/qwen_image.py diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py new file mode 100644 index 00000000..ff631a60 --- /dev/null +++ b/comfy/ldm/qwen_image/model.py @@ -0,0 +1,399 @@ +# https://github.com/QwenLM/Qwen-Image (Apache 2.0) +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.flux.layers import EmbedND + + +class GELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device) + self.approximate = approximate + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = F.gelu(hidden_states, approximate=self.approximate) + return hidden_states + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + inner_dim=None, + bias: bool = True, + dtype=None, device=None, operations=None + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.net = nn.ModuleList([]) + self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations)) + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + dtype=dtype, + device=device, + operations=operations + ) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + return timesteps_emb + + +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + dim_head: int = 64, + heads: int = 8, + dropout: float = 0.0, + bias: bool = False, + eps: float = 1e-5, + out_bias: bool = True, + out_dim: int = None, + out_context_dim: int = None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None 
else dim_head * heads + self.inner_kv_dim = self.inner_dim + self.heads = heads + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.dropout = dropout + + # Q/K normalization + self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + # Image stream projections + self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Text stream projections + self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Output projections + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device), + nn.Dropout(dropout) + ]) + self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_txt = encoder_hidden_states.shape[1] + + img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) + img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) + img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + + txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + joint_query = apply_rotary_emb(joint_query, image_rotary_emb) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + + joint_query = joint_query.flatten(start_dim=2) + joint_key = joint_key.flatten(start_dim=2) + joint_value = joint_value.flatten(start_dim=2) + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + + txt_attn_output = joint_hidden_states[:, :seq_txt, :] + img_attn_output = joint_hidden_states[:, seq_txt:, :] + + img_attn_output = self.to_out[0](img_attn_output) + img_attn_output = self.to_out[1](img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output 
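# ---------------------------------------------------------------------------
# [Editor's note: annotation, not part of the patch] The Attention module
# above is the dual-stream joint attention used by Qwen Image: text and image
# tokens are projected by separate Q/K/V layers, concatenated along the
# sequence dimension (text first), attended jointly, then split back apart at
# the text sequence length. A minimal standalone sketch of that concat/split
# pattern follows; it omits the Q/K RMSNorm and rotary-embedding steps and
# uses plain scaled-dot-product attention with hypothetical batch/sequence
# shapes (heads=24, dim_head=128 are the model defaults above):
#
#   import torch
#   import torch.nn.functional as F
#
#   heads, dim_head = 24, 128
#   txt = torch.randn(1, 77, heads * dim_head)     # text stream (B, S_txt, D)
#   img = torch.randn(1, 1024, heads * dim_head)   # image stream (B, S_img, D)
#   seq_txt = txt.shape[1]
#
#   joint = torch.cat([txt, img], dim=1)           # (B, S_txt + S_img, D)
#   q = k = v = joint.unflatten(-1, (heads, dim_head)).transpose(1, 2)
#   out = F.scaled_dot_product_attention(q, k, v)  # joint attention
#   out = out.transpose(1, 2).flatten(start_dim=2)
#
#   txt_out, img_out = out[:, :seq_txt], out[:, seq_txt:]  # split back apart
# ---------------------------------------------------------------------------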
+ + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.attn = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=eps, + dtype=dtype, + device=device, + operations=operations, + ) + + def _modulate(self, x, mod_params): + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_mod_params = self.img_mod(temb) + txt_mod_params = self.txt_mod(temb) + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) + + img_normed = self.img_norm1(hidden_states) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + img_attn_output, txt_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + img_gate1 * img_attn_output + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + + return encoder_hidden_states, hidden_states + + +class LastLayer(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=False, + eps=1e-6, + bias=True, + dtype=None, device=None, operations=None + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device) + self.norm = operations.LayerNorm(embedding_dim, eps, 
elementwise_affine=False, bias=bias, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class QwenImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + image_model=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, + pooled_projection_dim=pooled_projection_dim, + dtype=dtype, + device=device, + operations=operations + ) + + self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device) + self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device) + self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations + ) + for _ in range(num_layers) + ]) + + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) + self.gradient_checkpointing = False + + def pos_embeds(self, x, context): + bs, c, t, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_start = round(max(h_len, w_len)) + txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + image_rotary_emb = self.pos_embeds(x, context) + + orig_shape = x.shape + hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = 
hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + return hidden_states.reshape(orig_shape) diff --git a/comfy/model_base.py b/comfy/model_base.py index a0668643..8a2d9cbe 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.qwen_image.model import comfy.model_management import comfy.patcher_extension @@ -1308,3 +1309,14 @@ class Omnigen2(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class QwenImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9fc1f42d..8b57ebd2 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["timestep_scale"] = 1000.0 return dit_config + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image + dit_config = {} + dit_config["image_model"] = "qwen_image" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') hidden_size = state_dict["x_embedder.bias"].shape[0] sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) - elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) diff --git a/comfy/sd.py b/comfy/sd.py index e0498e58..bb5d61fb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -47,6 
+47,7 @@ import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image import comfy.model_patcher import comfy.lora @@ -771,6 +772,7 @@ class CLIPType(Enum): CHROMA = 15 ACE = 16 OMNIGEN2 = 17 + QWEN_IMAGE = 18 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -791,6 +793,7 @@ class TEModel(Enum): T5_XXL_OLD = 8 GEMMA_2_2B = 9 QWEN25_3B = 10 + QWEN25_7B = 11 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -812,7 +815,11 @@ def detect_te_model(sd): if 'model.layers.0.post_feedforward_layernorm.weight' in sd: return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: - return TEModel.QWEN25_3B + weight = sd['model.layers.0.self_attn.k_proj.bias'] + if weight.shape[0] == 256: + return TEModel.QWEN25_3B + if weight.shape[0] == 512: + return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: return TEModel.LLAMA3_8 return None @@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer + elif te_model == TEModel.QWEN25_7B: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8f3f4652..880055bd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2 import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image from . import supported_models_base from . 
import latent_formats @@ -1229,7 +1230,36 @@ class Omnigen2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) +class QwenImage(supported_models_base.BASE): + unet_config = { + "image_model": "qwen_image", + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] + sampling_settings = { + "multiplier": 1.0, + "shift": 2.6, + } + + memory_usage_factor = 1.8 #TODO + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.QwenImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 7fbd0f60..1da6a0c9 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -43,6 +43,23 @@ class Qwen25_3BConfig: mlp_activation = "silu" qkv_bias = True +@dataclass +class Qwen25_7BVLI_Config: + vocab_size: int = 152064 + hidden_size: int = 3584 + intermediate_size: int = 18944 + num_hidden_layers: int = 28 + num_attention_heads: int = 28 + num_key_value_heads: int = 4 + max_position_embeddings: int = 128000 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = True + @dataclass class Gemma2_2B_Config: vocab_size: int = 256000 @@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen25_7BVLI(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, 
dtype, device, operations): + super().__init__() + config = Qwen25_7BVLI_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py new file mode 100644 index 00000000..ce5c9809 --- /dev/null +++ b/comfy/text_encoders/qwen_image.py @@ -0,0 +1,71 @@ +from transformers import Qwen2Tokenizer +from comfy import sd1_clip +import comfy.text_encoders.llama +import os +import torch +import numbers + +class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class QwenImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class QwenImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + out = out[:, template_end:] + + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") # 
attention mask is useless if no masked elements + + return out, pooled, extra + + +def te(dtype_llama=None, llama_scaled_fp8=None): + class QwenImageTEModel_(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/nodes.py b/nodes.py index da4a4636..9bedbcac 100644 --- a/nodes.py +++ b/nodes.py @@ -925,7 +925,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}),
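
[Editor's note, appended commentary, not part of the patch series]
Patch 12/12 wires the new text encoder into CLIPLoader through the added
"qwen_image" type. On the encoding side, QwenImageTEModel.encode_token_weights
strips the chat template from the encoder output: it locates the second
<|im_start|> token (id 151644) and, when the two tokens that follow are "user"
(id 872) and a newline (id 198), advances past them so only the user-prompt
embeddings remain. A minimal sketch of that indexing logic on a plain list of
token ids (the input sequence is hypothetical; the ids are taken from the
patch):

    IM_START = 151644  # <|im_start|>
    USER = 872         # "user"
    NEWLINE = 198      # "\n"

    def find_template_end(tokens):
        # index of the second <|im_start|>, mirroring encode_token_weights
        template_end = 0
        count_im_start = 0
        for i, tok in enumerate(tokens):
            if tok == IM_START and count_im_start < 2:
                template_end = i
                count_im_start += 1
        # skip "<|im_start|>user\n" when that is what actually follows
        if len(tokens) > template_end + 3:
            if tokens[template_end + 1] == USER and tokens[template_end + 2] == NEWLINE:
                template_end += 3
        return template_end

    # usage sketch: out = out[:, find_template_end(ids):] keeps the prompt part

The attention mask returned alongside the embeddings is dropped entirely when
no element is masked, matching the check at the end of encode_token_weights.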