mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-28 08:46:35 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
commit
d53479a197
108
.github/workflows/release-webhook.yml
vendored
Normal file
108
.github/workflows/release-webhook.yml
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
3
.github/workflows/stable-release.yml
vendored
3
.github/workflows/stable-release.yml
vendored
@ -102,5 +102,4 @@ jobs:
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
prerelease: true
|
||||
make_latest: false
|
||||
draft: true
|
||||
|
@ -66,6 +66,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
|
@ -151,6 +151,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
@ -710,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@ -721,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
x = x + d * dt
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
s = t + h * r
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
alpha_s = sigmas[i] * lambda_s.exp()
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
||||
s_ = t_fn(sd)
|
||||
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
|
||||
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
|
||||
lambda_s_1_ = sd.log().neg()
|
||||
h_ = lambda_s_1_ - lambda_s
|
||||
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
|
||||
if eta > 0 and s_noise > 0:
|
||||
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
|
||||
t_next_ = t_fn(sd)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
|
||||
lambda_t_ = sd.log().neg()
|
||||
h_ = lambda_t_ - lambda_s
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
|
||||
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
||||
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
|
||||
return x
|
||||
|
||||
|
||||
@ -1435,14 +1447,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@ -1450,12 +1463,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
@ -1466,32 +1485,36 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
@ -195,20 +195,50 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
for ref in ref_latents:
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
|
@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = True):
|
||||
def __init__(self, sample: bool = False):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
|
||||
@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
yield from ()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
log = dict()
|
||||
posterior = DiagonalGaussianDistribution(z)
|
||||
if self.sample:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
log["kl_loss"] = kl_loss
|
||||
return z, log
|
||||
return z, None
|
||||
|
||||
|
||||
class AbstractAutoencoder(torch.nn.Module):
|
||||
|
469
comfy/ldm/omnigen/omnigen2.py
Normal file
469
comfy/ldm/omnigen/omnigen2.py
Normal file
@ -0,0 +1,469 @@
|
||||
# Original code: https://github.com/VectorSpaceLab/OmniGen2
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x) * y
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class LuminaRMSNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None])
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class LuminaLayerNormContinuous(nn.Module):
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
x = self.norm(x) * (1 + emb)[:, None, :]
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h1, h2 = self.linear_1(x), self.linear_3(x)
|
||||
return self.linear_2(swiglu(h1, h2))
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
|
||||
self.caption_embedder = nn.Sequential(
|
||||
operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
|
||||
operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
||||
time_embed = self.timestep_embedder(timestep_proj)
|
||||
caption_embed = self.caption_embedder(text_hidden_states)
|
||||
return time_embed, caption_embed
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kv_heads = kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
|
||||
nn.Dropout(0.0)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
query = query.view(batch_size, -1, self.heads, self.dim_head)
|
||||
key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
value = value.view(batch_size, -1, self.kv_heads, self.dim_head)
|
||||
|
||||
query = self.norm_q(query)
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if self.kv_heads < self.heads:
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2TransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_kv_heads,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.feed_forward = LuminaFeedForward(
|
||||
dim=dim,
|
||||
inner_dim=4 * dim,
|
||||
multiple_of=multiple_of,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if modulation:
|
||||
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OmniGen2RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.axes_lens = axes_lens
|
||||
self.patch_size = patch_size
|
||||
self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)
|
||||
|
||||
def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
|
||||
p = self.patch_size
|
||||
|
||||
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
||||
|
||||
max_seq_len = max(seq_lengths)
|
||||
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
||||
|
||||
pe_shift = cap_seq_len
|
||||
pe_shift_len = cap_seq_len
|
||||
|
||||
if ref_img_sizes[i] is not None:
|
||||
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
||||
H, W = ref_img_size
|
||||
ref_H_tokens, ref_W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
||||
|
||||
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
||||
pe_shift_len += ref_img_len
|
||||
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // p, W // p
|
||||
|
||||
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
||||
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
||||
|
||||
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
||||
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
||||
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
|
||||
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
cap_freqs_cis_shape[1] = encoder_seq_len
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
ref_img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
ref_img_freqs_cis_shape[1] = max_ref_img_len
|
||||
ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
||||
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
||||
|
||||
return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
|
||||
class OmniGen2Transformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
text_feat_dim: int = 1024,
|
||||
timestep_scale: float = 1.0,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=text_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(
|
||||
hidden_size, num_attention_heads, num_kv_heads,
|
||||
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
|
||||
) for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
|
||||
|
||||
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
||||
batch_size = len(hidden_states)
|
||||
p = self.patch_size
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
||||
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
|
||||
ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
|
||||
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
||||
else:
|
||||
ref_img_sizes = [None for _ in range(batch_size)]
|
||||
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
||||
|
||||
flat_ref_img_hidden_states = None
|
||||
if ref_image_hidden_states is not None:
|
||||
imgs = []
|
||||
for ref_img in ref_image_hidden_states:
|
||||
B, C, H, W = ref_img.size()
|
||||
ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
imgs.append(ref_img)
|
||||
flat_ref_img_hidden_states = torch.cat(imgs, dim=1)
|
||||
|
||||
img = hidden_states
|
||||
B, C, H, W = img.size()
|
||||
flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
||||
|
||||
return (
|
||||
flat_hidden_states, flat_ref_img_hidden_states,
|
||||
None, None,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
)
|
||||
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
|
||||
batch_size = len(hidden_states)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
if ref_image_hidden_states is not None:
|
||||
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
||||
image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
shift = 0
|
||||
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
||||
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
for layer in self.ref_image_refiner:
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
|
||||
|
||||
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
timestep = 1.0 - timesteps
|
||||
text_hidden_states = context
|
||||
text_attention_mask = attention_mask
|
||||
ref_image_hidden_states = ref_latents
|
||||
device = hidden_states.device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
|
||||
rotary_emb, encoder_seq_lengths, seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes, device,
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
attention_mask = None
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.patch_size
|
||||
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]
|
||||
|
||||
return -output
|
@ -41,6 +41,7 @@ import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -815,6 +816,7 @@ class PixArt(BaseModel):
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
try:
|
||||
@ -875,8 +877,23 @@ class Flux(BaseModel):
|
||||
guidance = kwargs.get("guidance", 3.5)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
@ -1230,3 +1247,33 @@ class ACEStep(BaseModel):
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
@ -459,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "omnigen2"
|
||||
dit_config["axes_dim_rope"] = [40, 40, 40]
|
||||
dit_config["axes_lens"] = [1024, 1664, 1664]
|
||||
dit_config["ffn_dim_multiplier"] = None
|
||||
dit_config["hidden_size"] = 2520
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["multiple_of"] = 256
|
||||
dit_config["norm_eps"] = 1e-05
|
||||
dit_config["num_attention_heads"] = 21
|
||||
dit_config["num_kv_heads"] = 7
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["num_refiner_layers"] = 2
|
||||
dit_config["out_channels"] = None
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["text_feat_dim"] = 2048
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
@ -1318,6 +1318,13 @@ def supports_fp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
@ -1251,13 +1251,13 @@ class SchedulerHandler(NamedTuple):
|
||||
use_ms: bool = True
|
||||
|
||||
SCHEDULER_HANDLERS = {
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
|
||||
"exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
|
||||
"sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
|
||||
"simple": SchedulerHandler(simple_scheduler),
|
||||
"ddim_uniform": SchedulerHandler(ddim_scheduler),
|
||||
"beta": SchedulerHandler(beta_scheduler),
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||
}
|
||||
|
12
comfy/sd.py
12
comfy/sd.py
@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -754,6 +755,7 @@ class CLIPType(Enum):
|
||||
HIDREAM = 14
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -773,6 +775,7 @@ class TEModel(Enum):
|
||||
LLAMA3_8 = 7
|
||||
T5_XXL_OLD = 8
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@ -793,6 +796,8 @@ def detect_te_model(sd):
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
return TEModel.QWEN25_3B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@ -1160,7 +1168,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
|
||||
@ -1168,7 +1176,7 @@ def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
|
@ -482,7 +482,8 @@ class SDTokenizer:
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
self.end_token = empty[0]
|
||||
if has_end_token:
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
|
@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -1181,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
class Omnigen2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "omnigen2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Omnigen2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@ -24,6 +24,24 @@ class Llama2Config:
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 2
|
||||
max_position_embeddings: int = 128000
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -40,6 +58,7 @@ class Gemma2_2B_Config:
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@ -98,9 +117,9 @@ class Attention(nn.Module):
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
ops = ops or nn
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_3BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
|
44
comfy/text_encoders/omnigen2.py
Normal file
44
comfy/text_encoders/omnigen2.py
Normal file
@ -0,0 +1,44 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
import os
|
||||
|
||||
|
||||
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
|
||||
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
|
||||
class Qwen25_3BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Omnigen2Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class Omnigen2TEModel_(Omnigen2Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Omnigen2TEModel_
|
151388
comfy/text_encoders/qwen25_tokenizer/merges.txt
Normal file
151388
comfy/text_encoders/qwen25_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
241
comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
Normal file
241
comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
Normal file
@ -0,0 +1,241 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"151643": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151644": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151645": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151646": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151647": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151648": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151649": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151650": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151651": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151652": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151653": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151654": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151655": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151656": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151657": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151658": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151659": {
|
||||
"content": "<|fim_prefix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151660": {
|
||||
"content": "<|fim_middle|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151661": {
|
||||
"content": "<|fim_suffix|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151662": {
|
||||
"content": "<|fim_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151663": {
|
||||
"content": "<|repo_name|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151664": {
|
||||
"content": "<|file_sep|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>"
|
||||
],
|
||||
"bos_token": null,
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"errors": "replace",
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"processor_class": "Qwen2_5_VLProcessor",
|
||||
"split_special_tokens": false,
|
||||
"tokenizer_class": "Qwen2Tokenizer",
|
||||
"unk_token": null
|
||||
}
|
1
comfy/text_encoders/qwen25_tokenizer/vocab.json
Normal file
1
comfy/text_encoders/qwen25_tokenizer/vocab.json
Normal file
File diff suppressed because one or more lines are too long
@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
return values.contiguous()
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
q = self.q(x)
|
||||
|
@ -2,6 +2,7 @@ import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
import latent_preview
|
||||
import torch
|
||||
import comfy.utils
|
||||
@ -480,6 +481,46 @@ class SamplerDPMAdaptative:
|
||||
"s_noise":s_noise })
|
||||
return (sampler, )
|
||||
|
||||
|
||||
class SamplerER_SDE(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
|
||||
"max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
|
||||
"eta": (
|
||||
IO.FLOAT,
|
||||
{"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
|
||||
),
|
||||
"s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.SAMPLER,)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, solver_type, max_stage, eta, s_noise):
|
||||
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
||||
eta = 0
|
||||
s_noise = 0
|
||||
|
||||
def reverse_time_sde_noise_scaler(x):
|
||||
return x ** (eta + 1)
|
||||
|
||||
if solver_type == "ER-SDE":
|
||||
# Use the default one in sample_er_sde()
|
||||
noise_scaler = None
|
||||
else:
|
||||
noise_scaler = reverse_time_sde_noise_scaler
|
||||
|
||||
sampler_name = "er_sde"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
||||
return (sampler,)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@ -609,8 +650,14 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
positive_cond = self.conds.get("positive", None)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.cfg2, 1.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg1, 1.0):
|
||||
middle_cond = None
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@ -781,6 +828,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SamplerER_SDE": SamplerER_SDE,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
|
26
comfy_extras/nodes_edit_model.py
Normal file
26
comfy_extras/nodes_edit_model.py
Normal file
@ -0,0 +1,26 @@
|
||||
import node_helpers
|
||||
|
||||
|
||||
class ReferenceLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
},
|
||||
"optional": {"latent": ("LATENT", ),}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "advanced/conditioning/edit_models"
|
||||
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
|
||||
|
||||
def append(self, conditioning, latent=None):
|
||||
if latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ReferenceLatent": ReferenceLatent,
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
|
||||
class CLIPTextEncodeFlux:
|
||||
@classmethod
|
||||
@ -56,8 +57,52 @@ class FluxDisableGuidance:
|
||||
return (c, )
|
||||
|
||||
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
class FluxKontextImageScale:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"image": ("IMAGE", ),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "scale"
|
||||
|
||||
CATEGORY = "advanced/conditioning/flux"
|
||||
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
|
||||
|
||||
def scale(self, image):
|
||||
width = image.shape[2]
|
||||
height = image.shape[1]
|
||||
aspect_ratio = width / height
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||
return (image, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||
"FluxGuidance": FluxGuidance,
|
||||
"FluxDisableGuidance": FluxDisableGuidance,
|
||||
"FluxKontextImageScale": FluxKontextImageScale,
|
||||
}
|
||||
|
@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["t_embedding_norm."] = argument
|
||||
|
||||
|
||||
for i in range(28):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||
CATEGORY = "advanced/model_merging/model_specific"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
arg_dict = { "model1": ("MODEL",),
|
||||
"model2": ("MODEL",)}
|
||||
|
||||
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
|
||||
arg_dict["pos_embedder."] = argument
|
||||
arg_dict["x_embedder."] = argument
|
||||
arg_dict["t_embedder."] = argument
|
||||
arg_dict["t_embedding_norm."] = argument
|
||||
|
||||
|
||||
for i in range(36):
|
||||
arg_dict["blocks.{}.".format(i)] = argument
|
||||
|
||||
arg_dict["final_layer."] = argument
|
||||
|
||||
return {"required": arg_dict}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSD1": ModelMergeSD1,
|
||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||
@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
||||
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
||||
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
||||
"ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
|
||||
"ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import comfy.sampler_helpers
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import math
|
||||
|
||||
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
@ -69,8 +70,23 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
empty_cond = self.conds.get("empty_negative_prompt", None)
|
||||
|
||||
(noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
|
||||
comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
|
||||
if model_options.get("disable_cfg1_optimization", False) == False:
|
||||
if math.isclose(self.neg_scale, 0.0):
|
||||
negative_cond = None
|
||||
if math.isclose(self.cfg, 1.0):
|
||||
empty_cond = None
|
||||
|
||||
conds = [positive_cond, negative_cond, empty_cond]
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
|
||||
|
||||
# Apply pre_cfg_functions since sampling_function() is skipped
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": self.cfg, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": self.inner_model, "model_options": model_options}
|
||||
out = fn(args)
|
||||
|
||||
noise_pred_pos, noise_pred_neg, noise_pred_empty = out
|
||||
cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)
|
||||
|
||||
# normally this would be done in cfg_function, but we skipped
|
||||
@ -82,6 +98,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
|
||||
"denoised": cfg_result,
|
||||
"cond": positive_cond,
|
||||
"uncond": negative_cond,
|
||||
"cond_scale": self.cfg,
|
||||
"model": self.inner_model,
|
||||
"uncond_denoised": noise_pred_neg,
|
||||
"cond_denoised": noise_pred_pos,
|
||||
|
71
comfy_extras/nodes_tcfg.py
Normal file
71
comfy_extras/nodes_tcfg.py
Normal file
@ -0,0 +1,71 @@
|
||||
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
|
||||
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
||||
"""Drop tangential components from uncond score to align with cond score."""
|
||||
# (B, 1, ...)
|
||||
batch_num = cond_score.shape[0]
|
||||
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
|
||||
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
|
||||
|
||||
# Score matrix A (B, 2, ...)
|
||||
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
|
||||
try:
|
||||
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
|
||||
except RuntimeError:
|
||||
# Fallback to CPU
|
||||
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
|
||||
|
||||
# Drop the tangential components
|
||||
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
|
||||
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
|
||||
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
||||
|
||||
|
||||
class TCFG(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.MODEL,)
|
||||
RETURN_NAMES = ("patched_model",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
|
||||
def tangential_damping_cfg(args):
|
||||
# Assume [cond, uncond, ...]
|
||||
x = args["input"]
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
# Skip when either cond or uncond is None
|
||||
return conds_out
|
||||
cond_pred = conds_out[0]
|
||||
uncond_pred = conds_out[1]
|
||||
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
|
||||
uncond_pred_td = x - uncond_td
|
||||
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TCFG": TCFG,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TCFG": "Tangential Damping CFG",
|
||||
}
|
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.41"
|
||||
__version__ = "0.3.43"
|
||||
|
14
main.py
14
main.py
@ -55,6 +55,9 @@ def apply_custom_paths():
|
||||
|
||||
|
||||
def execute_prestartup_script():
|
||||
if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
|
||||
return
|
||||
|
||||
def execute_script(script_path):
|
||||
module_name = os.path.splitext(script_path)[0]
|
||||
try:
|
||||
@ -66,9 +69,6 @@ def execute_prestartup_script():
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
return False
|
||||
|
||||
if args.disable_all_custom_nodes:
|
||||
return
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
@ -81,6 +81,9 @@ def execute_prestartup_script():
|
||||
|
||||
script_path = os.path.join(module_path, "prestartup_script.py")
|
||||
if os.path.exists(script_path):
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@ -276,7 +279,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
|
||||
nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
9
nodes.py
9
nodes.py
@ -920,7 +920,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -930,7 +930,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -2187,6 +2187,9 @@ def init_external_custom_nodes():
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
@ -2280,6 +2283,8 @@ def init_builtin_extra_nodes():
|
||||
"nodes_ace.py",
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.41"
|
||||
version = "0.3.43"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
@ -1,13 +1,13 @@
|
||||
comfyui-frontend-package==1.22.2
|
||||
comfyui-workflow-templates==0.1.29
|
||||
comfyui-embedded-docs==0.2.2
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.31
|
||||
comfyui-embedded-docs==0.2.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
torchaudio
|
||||
numpy>=1.25.0
|
||||
einops
|
||||
transformers>=4.28.1
|
||||
transformers>=4.37.2
|
||||
tokenizers>=0.13.3
|
||||
sentencepiece
|
||||
safetensors>=0.4.2
|
||||
|
Loading…
x
Reference in New Issue
Block a user