mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-29 17:26:34 +00:00

Merge branch 'master' into worksplit-multigpu
This commit is contained in: commit d53479a197

.github/workflows/release-webhook.yml (vendored, new file, 108 lines)
@@ -0,0 +1,108 @@
name: Release Webhook

on:
  release:
    types: [published]

jobs:
  send-webhook:
    runs-on: ubuntu-latest
    steps:
      - name: Send release webhook
        env:
          WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
          WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
        run: |
          # Generate UUID for delivery ID
          DELIVERY_ID=$(uuidgen)
          HOOK_ID="release-webhook-$(date +%s)"

          # Create webhook payload matching GitHub release webhook format
          PAYLOAD=$(cat <<EOF
          {
            "action": "published",
            "release": {
              "id": ${{ github.event.release.id }},
              "node_id": "${{ github.event.release.node_id }}",
              "url": "${{ github.event.release.url }}",
              "html_url": "${{ github.event.release.html_url }}",
              "assets_url": "${{ github.event.release.assets_url }}",
              "upload_url": "${{ github.event.release.upload_url }}",
              "tag_name": "${{ github.event.release.tag_name }}",
              "target_commitish": "${{ github.event.release.target_commitish }}",
              "name": ${{ toJSON(github.event.release.name) }},
              "body": ${{ toJSON(github.event.release.body) }},
              "draft": ${{ github.event.release.draft }},
              "prerelease": ${{ github.event.release.prerelease }},
              "created_at": "${{ github.event.release.created_at }}",
              "published_at": "${{ github.event.release.published_at }}",
              "author": {
                "login": "${{ github.event.release.author.login }}",
                "id": ${{ github.event.release.author.id }},
                "node_id": "${{ github.event.release.author.node_id }}",
                "avatar_url": "${{ github.event.release.author.avatar_url }}",
                "url": "${{ github.event.release.author.url }}",
                "html_url": "${{ github.event.release.author.html_url }}",
                "type": "${{ github.event.release.author.type }}",
                "site_admin": ${{ github.event.release.author.site_admin }}
              },
              "tarball_url": "${{ github.event.release.tarball_url }}",
              "zipball_url": "${{ github.event.release.zipball_url }}",
              "assets": ${{ toJSON(github.event.release.assets) }}
            },
            "repository": {
              "id": ${{ github.event.repository.id }},
              "node_id": "${{ github.event.repository.node_id }}",
              "name": "${{ github.event.repository.name }}",
              "full_name": "${{ github.event.repository.full_name }}",
              "private": ${{ github.event.repository.private }},
              "owner": {
                "login": "${{ github.event.repository.owner.login }}",
                "id": ${{ github.event.repository.owner.id }},
                "node_id": "${{ github.event.repository.owner.node_id }}",
                "avatar_url": "${{ github.event.repository.owner.avatar_url }}",
                "url": "${{ github.event.repository.owner.url }}",
                "html_url": "${{ github.event.repository.owner.html_url }}",
                "type": "${{ github.event.repository.owner.type }}",
                "site_admin": ${{ github.event.repository.owner.site_admin }}
              },
              "html_url": "${{ github.event.repository.html_url }}",
              "clone_url": "${{ github.event.repository.clone_url }}",
              "git_url": "${{ github.event.repository.git_url }}",
              "ssh_url": "${{ github.event.repository.ssh_url }}",
              "url": "${{ github.event.repository.url }}",
              "created_at": "${{ github.event.repository.created_at }}",
              "updated_at": "${{ github.event.repository.updated_at }}",
              "pushed_at": "${{ github.event.repository.pushed_at }}",
              "default_branch": "${{ github.event.repository.default_branch }}",
              "fork": ${{ github.event.repository.fork }}
            },
            "sender": {
              "login": "${{ github.event.sender.login }}",
              "id": ${{ github.event.sender.id }},
              "node_id": "${{ github.event.sender.node_id }}",
              "avatar_url": "${{ github.event.sender.avatar_url }}",
              "url": "${{ github.event.sender.url }}",
              "html_url": "${{ github.event.sender.html_url }}",
              "type": "${{ github.event.sender.type }}",
              "site_admin": ${{ github.event.sender.site_admin }}
            }
          }
          EOF
          )

          # Generate HMAC-SHA256 signature
          SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)

          # Send webhook with required headers
          curl -X POST "$WEBHOOK_URL" \
            -H "Content-Type: application/json" \
            -H "X-GitHub-Event: release" \
            -H "X-GitHub-Delivery: $DELIVERY_ID" \
            -H "X-GitHub-Hook-ID: $HOOK_ID" \
            -H "X-Hub-Signature-256: sha256=$SIGNATURE" \
            -H "User-Agent: GitHub-Actions-Webhook/1.0" \
            -d "$PAYLOAD" \
            --fail --silent --show-error

          echo "✅ Release webhook sent successfully"
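The signature header is computed the same way GitHub signs its native webhooks: an HMAC-SHA256 of the raw body, hex-encoded and prefixed with "sha256=". A receiver can therefore validate deliveries with a standard constant-time comparison. A minimal receiver-side sketch, not part of this commit (verify_signature is an illustrative name):

import hashlib
import hmac

def verify_signature(payload: bytes, secret: str, signature_header: str) -> bool:
    # Recompute "sha256=<hex HMAC>" over the raw request body and compare in constant time.
    expected = "sha256=" + hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, signature_header)

# e.g. verify_signature(request_body, webhook_secret, headers["X-Hub-Signature-256"])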
.github/workflows/stable-release.yml (vendored, 3 changes)
@@ -102,5 +102,4 @@ jobs:
           file: ComfyUI_windows_portable_nvidia.7z
           tag: ${{ inputs.git_tag }}
           overwrite: true
-          prerelease: true
-          make_latest: false
+          draft: true
README.md
@@ -66,6 +66,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
 - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
+- Image Editing Models
+   - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
+   - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
 - Video Models
 - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
 - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
comfy/cli_args.py
@@ -151,6 +151,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
 
 parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
 parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
 parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
 
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
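How --whitelist-custom-nodes interacts with --disable-all-custom-nodes lives in the node loader, not in this hunk; the sketch below is a hypothetical illustration of the intended semantics (should_load and the folder names are invented for the example):

def should_load(node_folder: str, disable_all_custom_nodes: bool, whitelist: list) -> bool:
    # Whitelisted folders load even when custom nodes are globally disabled.
    if not disable_all_custom_nodes:
        return True
    return node_folder in whitelist

assert should_load("my_custom_node_pack", True, ["my_custom_node_pack"])      # hypothetical folder
assert not should_load("other_pack", True, ["my_custom_node_pack"])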
comfy/k_diffusion/sampling.py
@@ -710,6 +710,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
         # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
@@ -721,38 +722,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         if sigmas[i + 1] == 0:
-            # Euler method
-            d = to_d(x, sigmas[i], denoised)
-            dt = sigmas[i + 1] - sigmas[i]
-            x = x + d * dt
+            # Denoising step
+            x = denoised
         else:
             # DPM-Solver++
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            h = t_next - t
-            s = t + h * r
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            lambda_s_1 = lambda_s + r * h
             fac = 1 / (2 * r)
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            alpha_s = sigmas[i] * lambda_s.exp()
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
 
             # Step 1
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
-            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
+            lambda_s_1_ = sd.log().neg()
+            h_ = lambda_s_1_ - lambda_s
+            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
+            if eta > 0 and s_noise > 0:
+                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
             # Step 2
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
-            t_next_ = t_fn(sd)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
+            lambda_t_ = sd.log().neg()
+            h_ = lambda_t_ - lambda_s
             denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
-            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
     return x
 
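This rewrite moves DPM-Solver++ (SDE) from a hard-coded variance-exploding change of variables (t = -log sigma) to half-log-SNR space, so the same stepping works across model samplings. half_log_snr_to_sigma, sigma_to_half_log_snr, offset_first_sigma_for_snr and partial come from elsewhere in this commit. Assuming the usual definition lambda_t = log(alpha_t / sigma_t), a standalone sketch of the conversion:

import torch

def sigma_to_half_log_snr(sigma: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Half of the log signal-to-noise ratio: lambda_t = log(alpha_t / sigma_t).
    return alpha.log() - sigma.log()

sigma = torch.tensor([10.0, 1.0, 0.1])
# Variance-exploding parameterization (alpha_t = 1): lambda_t reduces to -log(sigma_t).
print(sigma_to_half_log_snr(sigma, torch.ones_like(sigma)))  # tensor([-2.3026, 0.0000, 2.3026])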
@@ -1435,14 +1447,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
         old_d = d
     return x
 
 
 @torch.no_grad()
 def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
     return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
 
 
 @torch.no_grad()
-def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
-    """
-    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
+    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
     Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
     """
     extra_args = {} if extra_args is None else extra_args
@@ -1450,12 +1463,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
-    def default_noise_scaler(sigma):
-        return sigma * ((sigma ** 0.3).exp() + 10.0)
-    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    def default_er_sde_noise_scaler(x):
+        return x * ((x ** 0.3).exp() + 10.0)
+
+    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
     num_integration_points = 200.0
     point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
 
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
+    er_lambdas = half_log_snrs.neg().exp()  # er_lambda_t = sigma_t / alpha_t
+
     old_denoised = None
     old_denoised_d = None
 
@@ -1466,32 +1485,36 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
         stage_used = min(max_stage, i + 1)
         if sigmas[i + 1] == 0:
             x = denoised
-        elif stage_used == 1:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
         else:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
+            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
+            alpha_s = sigmas[i] / er_lambda_s
+            alpha_t = sigmas[i + 1] / er_lambda_t
+            r_alpha = alpha_t / alpha_s
+            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
 
-            dt = sigmas[i + 1] - sigmas[i]
-            sigma_step_size = -dt / num_integration_points
-            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
-            scaled_pos = noise_scaler(sigma_pos)
+            # Stage 1 Euler
+            x = r_alpha * r * x + alpha_t * (1 - r) * denoised
 
-            # Stage 2
-            s = torch.sum(1 / scaled_pos) * sigma_step_size
-            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
-            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+            if stage_used >= 2:
+                dt = er_lambda_t - er_lambda_s
+                lambda_step_size = -dt / num_integration_points
+                lambda_pos = er_lambda_t + point_indice * lambda_step_size
+                scaled_pos = noise_scaler(lambda_pos)
 
-            if stage_used >= 3:
-                # Stage 3
-                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
-                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
-                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
-                old_denoised_d = denoised_d
+                # Stage 2
+                s = torch.sum(1 / scaled_pos) * lambda_step_size
+                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
+                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
 
-            if s_noise != 0 and sigmas[i + 1] > 0:
-                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+                if stage_used >= 3:
+                    # Stage 3
+                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
+                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
+                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
+                    old_denoised_d = denoised_d
+
+            if s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
         old_denoised = denoised
     return x
 
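Per the new inline comment, er_lambdas stores exp(-lambda_t) = sigma_t / alpha_t, so ER-SDE now integrates in sigma/alpha space rather than raw sigma space; for a variance-exploding model (alpha_t = 1) the two coincide. A quick numerical check under that assumption:

import torch

sigma = torch.tensor([10.0, 1.0, 0.1])
alpha = torch.ones_like(sigma)                 # variance-exploding parameterization
half_log_snrs = (alpha / sigma).log()          # lambda_t = log(alpha_t / sigma_t)
er_lambdas = half_log_snrs.neg().exp()         # er_lambda_t = sigma_t / alpha_t
print(torch.allclose(er_lambdas, sigma / alpha))  # True: reduces to plain sigma for VE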
comfy/ldm/flux/model.py
@@ -195,20 +195,50 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
 
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        bs, c, h_orig, w_orig = x.shape
+        patch_size = self.patch_size
+
+        h_len = ((h_orig + (patch_size // 2)) // patch_size)
+        w_len = ((w_orig + (patch_size // 2)) // patch_size)
+        img, img_ids = self.process_img(x)
+        img_tokens = img.shape[1]
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            for ref in ref_latents:
+                h_offset = 0
+                w_offset = 0
+                if ref.shape[-2] + h > ref.shape[-1] + w:
+                    w_offset = w
+                else:
+                    h_offset = h
+
+                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+                img = torch.cat([img, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+                h = max(h, ref.shape[-2] + h_offset)
+                w = max(w, ref.shape[-1] + w_offset)
+
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+        out = out[:, :img_tokens]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
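With this change, tokens from each reference latent are appended after the tokens of the latent being denoised, carry index=1 in their positional ids, and are sliced off again before unpatchifying. A shape-only sketch of that bookkeeping (dummy sizes, not the module itself):

import torch

img_tokens = 16                                        # tokens of the latent being denoised
kontext_tokens = 16                                    # tokens contributed by one reference latent
out = torch.randn(1, img_tokens + kontext_tokens, 64)  # stand-in for forward_orig() output
out = out[:, :img_tokens]                              # reference tokens are dropped from the result
print(out.shape)  # torch.Size([1, 16, 64])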
comfy/ldm/models/autoencoder.py
@@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
 import comfy.ops
 
 class DiagonalGaussianRegularizer(torch.nn.Module):
-    def __init__(self, sample: bool = True):
+    def __init__(self, sample: bool = False):
         super().__init__()
         self.sample = sample
 
@@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
         yield from ()
 
     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
-        log = dict()
         posterior = DiagonalGaussianDistribution(z)
         if self.sample:
             z = posterior.sample()
         else:
             z = posterior.mode()
-        kl_loss = posterior.kl()
-        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
-        log["kl_loss"] = kl_loss
-        return z, log
+        return z, None
 
 
 class AbstractAutoencoder(torch.nn.Module):
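With sample=False (the new default) the regularizer returns the posterior mode, i.e. the predicted mean, and no longer computes the KL term, which only mattered for training. A sketch of the two paths, assuming the standard DiagonalGaussianDistribution where z packs mean and logvar:

import torch

mean, logvar = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
mode = mean                                                       # posterior.mode()
sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)  # posterior.sample()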
comfy/ldm/omnigen/omnigen2.py (new file, 469 lines)
@@ -0,0 +1,469 @@
# Original code: https://github.com/VectorSpaceLab/OmniGen2

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit


def apply_rotary_emb(x, freqs_cis):
    if x.shape[1] == 0:
        return x

    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x.shape).to(dtype=x.dtype)


def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return F.silu(x) * y


class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class LuminaRMSNormZero(nn.Module):
    def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp


class LuminaLayerNormContinuous(nn.Module):
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        x = self.norm(x) * (1 + emb)[:, None, :]
        if self.linear_2 is not None:
            x = self.linear_2(x)
        return x


class LuminaFeedForward(nn.Module):
    def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
        self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
        self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(swiglu(h1, h2))


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
        self.caption_embedder = nn.Sequential(
            operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
            operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed


class Attention(nn.Module):
    def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)

        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
            operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
            nn.Dropout(0.0)
        )

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        query = query.view(batch_size, -1, self.heads, self.dim_head)
        key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
        value = value.view(batch_size, -1, self.kv_heads, self.dim_head)

        query = self.norm_q(query)
        key = self.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.kv_heads < self.heads:
            key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
            value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)

        hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
        hidden_states = self.to_out[0](hidden_states)
        return hidden_states


class OmniGen2TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            dtype=dtype, device=device, operations=operations,
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            dtype=dtype, device=device, operations=operations
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
        else:
            self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

        self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.modulation:
            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        return hidden_states


class OmniGen2RotaryPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size
        self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)

    def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
        p = self.patch_size

        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = encoder_seq_len
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        ref_img_freqs_cis_shape = list(freqs_cis.shape)
        ref_img_freqs_cis_shape[1] = max_ref_img_len
        ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths


class OmniGen2Transformer2DModel(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype

        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        self.context_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        self.layers = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_layers)
        ])

        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
        )

        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        batch_size = len(hidden_states)
        p = self.patch_size

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
            ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        flat_ref_img_hidden_states = None
        if ref_image_hidden_states is not None:
            imgs = []
            for ref_img in ref_image_hidden_states:
                B, C, H, W = ref_img.size()
                ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
                imgs.append(ref_img)
            flat_ref_img_hidden_states = torch.cat(imgs, dim=1)

        img = hidden_states
        B, C, H, W = img.size()
        flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        return (
            flat_hidden_states, flat_ref_img_hidden_states,
            None, None,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        )

    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
        batch_size = len(hidden_states)

        hidden_states = self.x_embedder(hidden_states)
        if ref_image_hidden_states is not None:
            ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
            image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)

            for i in range(batch_size):
                shift = 0
                for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                    ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
                    shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        if ref_image_hidden_states is not None:
            for layer in self.ref_image_refiner:
                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)

            hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)

        return hidden_states

    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
        B, C, H, W = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        _, _, H_padded, W_padded = hidden_states.shape
        timestep = 1.0 - timesteps
        text_hidden_states = context
        text_attention_mask = attention_mask
        ref_image_hidden_states = ref_latents
        device = hidden_states.device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )

        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        img_len = hidden_states.shape[1]
        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
        )

        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
        attention_mask = None

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        hidden_states = self.norm_out(hidden_states, temb)

        p = self.patch_size
        output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W]

        return -output
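Attention.forward above implements grouped-query attention: with 21 query heads over 7 key/value heads (hidden_size 2520 / 21 = 120 per head, per the detection config later in this commit), each kv head is shared by heads // kv_heads query heads via repeat_interleave. A small shape check:

import torch

heads, kv_heads, dim_head, seq = 21, 7, 120, 4
key = torch.randn(1, kv_heads, seq, dim_head)
key = key.repeat_interleave(heads // kv_heads, dim=1)  # each kv head serves 3 query heads
print(key.shape)  # torch.Size([1, 21, 4, 120])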
comfy/model_base.py
@@ -41,6 +41,7 @@ import comfy.ldm.hunyuan3d.model
 import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.ace.model
+import comfy.ldm.omnigen.omnigen2
 
 import comfy.model_management
 import comfy.patcher_extension
@@ -815,6 +816,7 @@ class PixArt(BaseModel):
 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
         super().__init__(model_config, model_type, device=device, unet_model=unet_model)
+        self.memory_usage_factor_conds = ("ref_latents",)
 
     def concat_cond(self, **kwargs):
         try:
@@ -875,8 +877,23 @@ class Flux(BaseModel):
         guidance = kwargs.get("guidance", 3.5)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
         return out
 
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
+
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1230,3 +1247,33 @@ class ACEStep(BaseModel):
         out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
         out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out
+
+
+class Omnigen2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
+        self.memory_usage_factor_conds = ("ref_latents",)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            if torch.numel(attention_mask) != attention_mask.sum():
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
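extra_conds_shapes advertises the reference latents to the memory estimator as a single flattened [1, 16, total_elements // 16] shape. A worked example of that arithmetic:

import math
import torch

ref_latents = [torch.zeros(1, 16, 64, 64), torch.zeros(1, 16, 32, 32)]
shape = [1, 16, sum(math.prod(a.size()) for a in ref_latents) // 16]
print(shape)  # [1, 16, 5120]  (65536 + 16384 elements, divided by the 16 channels)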
comfy/model_detection.py
@@ -459,6 +459,26 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         return dit_config
 
+    if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
+        dit_config = {}
+        dit_config["image_model"] = "omnigen2"
+        dit_config["axes_dim_rope"] = [40, 40, 40]
+        dit_config["axes_lens"] = [1024, 1664, 1664]
+        dit_config["ffn_dim_multiplier"] = None
+        dit_config["hidden_size"] = 2520
+        dit_config["in_channels"] = 16
+        dit_config["multiple_of"] = 256
+        dit_config["norm_eps"] = 1e-05
+        dit_config["num_attention_heads"] = 21
+        dit_config["num_kv_heads"] = 7
+        dit_config["num_layers"] = 32
+        dit_config["num_refiner_layers"] = 2
+        dit_config["out_channels"] = None
+        dit_config["patch_size"] = 2
+        dit_config["text_feat_dim"] = 2048
+        dit_config["timestep_scale"] = 1000.0
+        return dit_config
+
 if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
     return None
 
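Detection works by probing for architecture-specific keys under the checkpoint's diffusion-model prefix. A minimal illustration of the pattern with a fake state dict (the prefix shown is the usual one for full checkpoints):

state_dict_keys = {"model.diffusion_model.time_caption_embed.timestep_embedder.linear_1.bias"}
key_prefix = "model.diffusion_model."
is_omnigen2 = '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys
print(is_omnigen2)  # True -> the Omnigen2 dit_config above is returned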
comfy/model_management.py
@@ -1318,6 +1318,13 @@ def supports_fp8_compute(device=None):
 
     return True
 
+def extended_fp16_support():
+    # TODO: check why some models work with fp16 on newer torch versions but not on older
+    if torch_version_numeric < (2, 7):
+        return False
+
+    return True
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
@@ -1251,13 +1251,13 @@ class SchedulerHandler(NamedTuple):
     use_ms: bool = True

 SCHEDULER_HANDLERS = {
-    "normal": SchedulerHandler(normal_scheduler),
+    "simple": SchedulerHandler(simple_scheduler),
+    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
     "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
     "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
-    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
-    "simple": SchedulerHandler(simple_scheduler),
     "ddim_uniform": SchedulerHandler(ddim_scheduler),
     "beta": SchedulerHandler(beta_scheduler),
+    "normal": SchedulerHandler(normal_scheduler),
     "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
     "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
 }
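The reorder matters because dict insertion order is what downstream consumers see: if the scheduler name list is derived from these keys (e.g. something like SCHEDULER_NAMES = list(SCHEDULER_HANDLERS) — an assumption, that line is not shown in this diff), "simple" now becomes the first, and therefore default, scheduler offered:

    SCHEDULER_HANDLERS = {"simple": None, "sgm_uniform": None, "karras": None}  # toy stand-in
    SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
    print(SCHEDULER_NAMES[0])  # "simple"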
comfy/sd.py (12 lines changed)

@@ -44,6 +44,7 @@ import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.hidream
 import comfy.text_encoders.ace
+import comfy.text_encoders.omnigen2

 import comfy.model_patcher
 import comfy.lora
@@ -754,6 +755,7 @@ class CLIPType(Enum):
     HIDREAM = 14
     CHROMA = 15
     ACE = 16
+    OMNIGEN2 = 17


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -773,6 +775,7 @@ class TEModel(Enum):
     LLAMA3_8 = 7
     T5_XXL_OLD = 8
     GEMMA_2_2B = 9
+    QWEN25_3B = 10

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -793,6 +796,8 @@ def detect_te_model(sd):
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
         return TEModel.GEMMA_2_2B
+    if 'model.layers.0.self_attn.k_proj.bias' in sd:
+        return TEModel.QWEN25_3B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         return TEModel.LLAMA3_8
     return None
@@ -894,6 +899,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                                                                         clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+        elif te_model == TEModel.QWEN25_3B:
+            clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:
@@ -1160,7 +1168,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
-        logging.info("left over keys in unet: {}".format(left_over))
+        logging.info("left over keys in diffusion model: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)


@@ -1168,7 +1176,7 @@ def load_diffusion_model(unet_path, model_options={}):
     sd = comfy.utils.load_torch_file(unet_path)
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
-        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
+        logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model
@@ -482,7 +482,8 @@ class SDTokenizer:
         if end_token is not None:
             self.end_token = end_token
         else:
-            self.end_token = empty[0]
+            if has_end_token:
+                self.end_token = empty[0]

         if pad_token is not None:
             self.pad_token = pad_token
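This guards tokenizers that have no end token at all: previously self.end_token was always borrowed from the empty prompt even when has_end_token was False. The Qwen2.5 tokenizer added later in this diff depends on the new behavior, since it is constructed with has_end_token=False and pads with token 151643 instead of appending an end token.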
@@ -18,6 +18,7 @@ import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.ace
+import comfy.text_encoders.omnigen2

 from . import supported_models_base
 from . import latent_formats
@@ -1181,6 +1182,41 @@ class ACEStep(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)

+class Omnigen2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "omnigen2",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 2.6,
+    }
+
+    memory_usage_factor = 1.65 #TODO
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        if comfy.model_management.extended_fp16_support():
+            self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Omnigen2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

 models += [SVD_img2vid]
@@ -24,6 +24,24 @@ class Llama2Config:
     head_dim = 128
     rms_norm_add = False
     mlp_activation = "silu"
+    qkv_bias = False
+
+@dataclass
+class Qwen25_3BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 11008
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 2
+    max_position_embeddings: int = 128000
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = True

 @dataclass
 class Gemma2_2B_Config:
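The Qwen2.5-3B config is a grouped-query attention (GQA) layout, and the detection key used in comfy/sd.py ('model.layers.0.self_attn.k_proj.bias') works precisely because qkv_bias = True is unique to it among these configs. The head arithmetic, as a quick check:

    hidden_size = 2048
    num_attention_heads = 16
    num_key_value_heads = 2
    head_dim = 128

    assert num_attention_heads * head_dim == 2048  # query projection maps hidden_size -> 2048
    assert num_key_value_heads * head_dim == 256   # k/v projections map hidden_size -> 256
    # each of the 2 KV heads is shared by 16 / 2 = 8 query heads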
@@ -40,6 +58,7 @@ class Gemma2_2B_Config:
     head_dim = 256
     rms_norm_add = True
     mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False

 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -98,9 +117,9 @@ class Attention(nn.Module):
         self.inner_size = self.num_heads * self.head_dim

         ops = ops or nn
-        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
-        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
-        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
+        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

     def forward(
@@ -320,6 +339,14 @@ class Llama2(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Qwen25_3B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen25_3BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Gemma2_2B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
comfy/text_encoders/omnigen2.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+from transformers import Qwen2Tokenizer
+from comfy import sd1_clip
+import comfy.text_encoders.llama
+import os
+
+
+class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
+        self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
+
+
+class Qwen25_3BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Omnigen2Model(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None):
+    class Omnigen2TEModel_(Omnigen2Model):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Omnigen2TEModel_
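A minimal usage sketch (assuming the qwen25_tokenizer files are present next to the module, as the path construction above expects): the tokenizer wraps every prompt in the Qwen chat template before tokenizing, so the text encoder always sees an <|im_start|>system / <|im_start|>user conversation rather than the bare prompt.

    tok = Omnigen2Tokenizer()
    # "a cat wearing a hat" is wrapped into the chat template first:
    # <|im_start|>system\nYou are a helpful assistant ...<|im_end|>\n<|im_start|>user\na cat wearing a hat<|im_end|>\n
    tokens = tok.tokenize_with_weights("a cat wearing a hat")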
comfy/text_encoders/qwen25_tokenizer/merges.txt (new file, 151388 lines)
File diff suppressed because it is too large

comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json (new file, 241 lines)

@@ -0,0 +1,241 @@
+{
+  "add_bos_token": false,
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "151643": {
+      "content": "<|endoftext|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151644": {
+      "content": "<|im_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151645": {
+      "content": "<|im_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151646": {
+      "content": "<|object_ref_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151647": {
+      "content": "<|object_ref_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151648": {
+      "content": "<|box_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151649": {
+      "content": "<|box_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151650": {
+      "content": "<|quad_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151651": {
+      "content": "<|quad_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151652": {
+      "content": "<|vision_start|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151653": {
+      "content": "<|vision_end|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151654": {
+      "content": "<|vision_pad|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151655": {
+      "content": "<|image_pad|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151656": {
+      "content": "<|video_pad|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151657": {
+      "content": "<tool_call>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151658": {
+      "content": "</tool_call>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151659": {
+      "content": "<|fim_prefix|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151660": {
+      "content": "<|fim_middle|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151661": {
+      "content": "<|fim_suffix|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151662": {
+      "content": "<|fim_pad|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151663": {
+      "content": "<|repo_name|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151664": {
+      "content": "<|file_sep|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "151665": {
+      "content": "<|img|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151666": {
+      "content": "<|endofimg|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151667": {
+      "content": "<|meta|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151668": {
+      "content": "<|endofmeta|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [
+    "<|im_start|>",
+    "<|im_end|>",
+    "<|object_ref_start|>",
+    "<|object_ref_end|>",
+    "<|box_start|>",
+    "<|box_end|>",
+    "<|quad_start|>",
+    "<|quad_end|>",
+    "<|vision_start|>",
+    "<|vision_end|>",
+    "<|vision_pad|>",
+    "<|image_pad|>",
+    "<|video_pad|>"
+  ],
+  "bos_token": null,
+  "chat_template": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- 'You are a helpful assistant.' }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|im_end|>",
+  "errors": "replace",
+  "extra_special_tokens": {},
+  "model_max_length": 131072,
+  "pad_token": "<|endoftext|>",
+  "processor_class": "Qwen2_5_VLProcessor",
+  "split_special_tokens": false,
+  "tokenizer_class": "Qwen2Tokenizer",
+  "unk_token": null
+}
comfy/text_encoders/qwen25_tokenizer/vocab.json (new file, 1 line)
File diff suppressed because one or more lines are too long

@@ -146,7 +146,7 @@ class T5Attention(torch.nn.Module):
         )
         values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
-        return values
+        return values.contiguous()

     def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         q = self.q(x)
@@ -2,6 +2,7 @@ import math
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
 import latent_preview
 import torch
 import comfy.utils
@@ -480,6 +481,46 @@ class SamplerDPMAdaptative:
                                                               "s_noise":s_noise })
         return (sampler, )

+
+class SamplerER_SDE(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
+                "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
+                "eta": (
+                    IO.FLOAT,
+                    {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
+                ),
+                "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
+            }
+        }
+
+    RETURN_TYPES = (IO.SAMPLER,)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, solver_type, max_stage, eta, s_noise):
+        if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
+            eta = 0
+            s_noise = 0
+
+        def reverse_time_sde_noise_scaler(x):
+            return x ** (eta + 1)
+
+        if solver_type == "ER-SDE":
+            # Use the default one in sample_er_sde()
+            noise_scaler = None
+        else:
+            noise_scaler = reverse_time_sde_noise_scaler
+
+        sampler_name = "er_sde"
+        sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
+        return (sampler,)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
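A quick usage sketch (hypothetical wiring, not from the diff): because the node just builds a sampler object via comfy.samplers.ksampler, it can also be driven directly from Python, e.g. to get the deterministic ODE variant:

    node = SamplerER_SDE()
    (sampler,) = node.get_sampler("ODE", max_stage=3, eta=1.0, s_noise=1.0)
    # solver_type "ODE" forces eta = 0 and s_noise = 0 internally, so
    # reverse_time_sde_noise_scaler(x) = x ** 1 leaves the noise scale untouched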
@@ -609,8 +650,14 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
     def predict_noise(self, x, timestep, model_options={}, seed=None):
         negative_cond = self.conds.get("negative", None)
         middle_cond = self.conds.get("middle", None)
+        positive_cond = self.conds.get("positive", None)
+        if model_options.get("disable_cfg1_optimization", False) == False:
+            if math.isclose(self.cfg2, 1.0):
+                negative_cond = None
+            if math.isclose(self.cfg1, 1.0):
+                middle_cond = None

-        out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
+        out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
         return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1

 class DualCFGGuider:
@@ -781,6 +828,7 @@ NODE_CLASS_MAPPINGS = {
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
+    "SamplerER_SDE": SamplerER_SDE,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,
comfy_extras/nodes_edit_model.py (new file, 26 lines)

@@ -0,0 +1,26 @@
+import node_helpers
+
+
+class ReferenceLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             },
+                "optional": {"latent": ("LATENT", ),}
+                }
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "advanced/conditioning/edit_models"
+    DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
+
+    def append(self, conditioning, latent=None):
+        if latent is not None:
+            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
+        return (conditioning, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "ReferenceLatent": ReferenceLatent,
+}
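Chaining works because conditioning_set_values is called with append=True, so each ReferenceLatent extends the "reference_latents" list instead of replacing it. A sketch of two references chained in code (hypothetical tensors):

    node = ReferenceLatent()
    (cond,) = node.append(cond, latent={"samples": ref_latent_1})
    (cond,) = node.append(cond, latent={"samples": ref_latent_2})
    # cond now carries reference_latents == [ref_latent_1, ref_latent_2]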
@@ -1,4 +1,5 @@
 import node_helpers
+import comfy.utils

 class CLIPTextEncodeFlux:
     @classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
         return (c, )


+PREFERED_KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+
+class FluxKontextImageScale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"image": ("IMAGE", ),
+                             },
+                }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "scale"
+
+    CATEGORY = "advanced/conditioning/flux"
+    DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
+
+    def scale(self, image):
+        width = image.shape[2]
+        height = image.shape[1]
+        aspect_ratio = width / height
+        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
+        image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
+        return (image, )
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
     "FluxGuidance": FluxGuidance,
     "FluxDisableGuidance": FluxDisableGuidance,
+    "FluxKontextImageScale": FluxKontextImageScale,
 }
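A worked example of the selection rule (illustrative input size, not from the diff): for a 1920x1080 image the aspect ratio is 1920/1080 ≈ 1.778, and the closest entry by |aspect - w/h| is (1392, 752) with 1392/752 ≈ 1.851, so the image is lanczos-resized to 1392x752:

    aspect_ratio = 1920 / 1080  # ~1.778
    _, w, h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
    print(w, h)  # 1392 752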
@@ -268,6 +268,52 @@ class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):

         return {"required": arg_dict}

+class ModelMergeCosmosPredict2_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+        for i in range(28):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
+class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+        for i in range(36):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -281,4 +327,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeCosmos7B": ModelMergeCosmos7B,
     "ModelMergeCosmos14B": ModelMergeCosmos14B,
     "ModelMergeWAN2_1": ModelMergeWAN2_1,
+    "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
+    "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
 }
@@ -4,6 +4,7 @@ import comfy.sampler_helpers
 import comfy.samplers
 import comfy.utils
 import node_helpers
+import math

 def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
     pos = noise_pred_pos - noise_pred_nocond
@@ -69,8 +70,23 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
         negative_cond = self.conds.get("negative", None)
         empty_cond = self.conds.get("empty_negative_prompt", None)

-        (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \
-            comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options)
+        if model_options.get("disable_cfg1_optimization", False) == False:
+            if math.isclose(self.neg_scale, 0.0):
+                negative_cond = None
+            if math.isclose(self.cfg, 1.0):
+                empty_cond = None
+
+        conds = [positive_cond, negative_cond, empty_cond]
+
+        out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
+
+        # Apply pre_cfg_functions since sampling_function() is skipped
+        for fn in model_options.get("sampler_pre_cfg_function", []):
+            args = {"conds":conds, "conds_out": out, "cond_scale": self.cfg, "timestep": timestep,
+                    "input": x, "sigma": timestep, "model": self.inner_model, "model_options": model_options}
+            out = fn(args)
+
+        noise_pred_pos, noise_pred_neg, noise_pred_empty = out
         cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg)

         # normally this would be done in cfg_function, but we skipped
@@ -82,6 +98,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
             "denoised": cfg_result,
             "cond": positive_cond,
             "uncond": negative_cond,
+            "cond_scale": self.cfg,
             "model": self.inner_model,
             "uncond_denoised": noise_pred_neg,
             "cond_denoised": noise_pred_pos,
comfy_extras/nodes_tcfg.py (new file, 71 lines)

@@ -0,0 +1,71 @@
+# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
+
+import torch
+
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
+
+
+def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
+    """Drop tangential components from uncond score to align with cond score."""
+    # (B, 1, ...)
+    batch_num = cond_score.shape[0]
+    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
+    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
+
+    # Score matrix A (B, 2, ...)
+    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
+    try:
+        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
+    except RuntimeError:
+        # Fallback to CPU
+        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
+
+    # Drop the tangential components
+    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
+    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
+    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
+
+
+class TCFG(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "model": (IO.MODEL, {}),
+            }
+        }
+
+    RETURN_TYPES = (IO.MODEL,)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/guidance"
+    DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
+
+    def patch(self, model):
+        m = model.clone()
+
+        def tangential_damping_cfg(args):
+            # Assume [cond, uncond, ...]
+            x = args["input"]
+            conds_out = args["conds_out"]
+            if len(conds_out) <= 1 or None in args["conds"][:2]:
+                # Skip when either cond or uncond is None
+                return conds_out
+            cond_pred = conds_out[0]
+            uncond_pred = conds_out[1]
+            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
+            uncond_pred_td = x - uncond_td
+            return [cond_pred, uncond_pred_td] + conds_out[2:]
+
+        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
+        return (m,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "TCFG": TCFG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "TCFG": "Tangential Damping CFG",
+}
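Geometrically, the SVD of the stacked (uncond, cond) rows yields v1, the principal right-singular direction shared by both scores; projecting the uncond score onto v1 keeps its component along that direction and discards the tangential remainder. A toy check of the projection step (illustrative vectors only):

    import torch

    uncond = torch.tensor([[3.0, 1.0]]).reshape(1, 1, -1)
    cond = torch.tensor([[2.0, 0.0]]).reshape(1, 1, -1)
    _, _, Vh = torch.linalg.svd(torch.cat((uncond, cond), dim=1), full_matrices=False)
    v1 = Vh[:, 0:1, :]
    uncond_td = (uncond @ v1.transpose(-2, -1)) * v1
    # uncond_td now lies along v1; the component orthogonal to v1 has been dropped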
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.41"
+__version__ = "0.3.43"
main.py (14 lines changed)

@@ -55,6 +55,9 @@ def apply_custom_paths():


 def execute_prestartup_script():
+    if args.disable_all_custom_nodes and len(args.whitelist_custom_nodes) == 0:
+        return
+
     def execute_script(script_path):
         module_name = os.path.splitext(script_path)[0]
         try:
@@ -66,9 +69,6 @@ def execute_prestartup_script():
             logging.error(f"Failed to execute startup-script: {script_path} / {e}")
         return False

-    if args.disable_all_custom_nodes:
-        return
-
     node_paths = folder_paths.get_folder_paths("custom_nodes")
     for custom_node_path in node_paths:
         possible_modules = os.listdir(custom_node_path)
@@ -81,6 +81,9 @@ def execute_prestartup_script():

         script_path = os.path.join(module_path, "prestartup_script.py")
         if os.path.exists(script_path):
+            if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
+                logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
+                continue
             time_before = time.perf_counter()
             success = execute_script(script_path)
             node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
@@ -276,7 +279,10 @@ def start_comfyui(asyncio_loop=None):
     prompt_server = server.PromptServer(asyncio_loop)

     hook_breaker_ac10a0.save_functions()
-    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
+    nodes.init_extra_nodes(
+        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
+        init_api_nodes=not args.disable_api_nodes
+    )
     hook_breaker_ac10a0.restore_functions()

     cuda_malloc_warning()
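Taken together, the main.py changes let a whitelist override the global disable switch: prestartup scripts are now skipped per-module rather than wholesale, and init_custom_nodes stays on whenever args.whitelist_custom_nodes is non-empty (presumably populated by a --whitelist-custom-nodes CLI flag; the argument definition itself is not shown in this diff).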
nodes.py (9 lines changed)

@@ -920,7 +920,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
                              },
                 "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -930,7 +930,7 @@ class CLIPLoader:

     CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"
+    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"

     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -2187,6 +2187,9 @@ def init_external_custom_nodes():
             module_path = os.path.join(custom_node_path, possible_module)
             if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
             if module_path.endswith(".disabled"): continue
+            if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
+                logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
+                continue
             time_before = time.perf_counter()
             success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))
@@ -2280,6 +2283,8 @@ def init_builtin_extra_nodes():
         "nodes_ace.py",
         "nodes_string.py",
         "nodes_camera_trajectory.py",
+        "nodes_edit_model.py",
+        "nodes_tcfg.py"
     ]

     import_failed = []
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.41"
+version = "0.3.43"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,13 +1,13 @@
-comfyui-frontend-package==1.22.2
-comfyui-workflow-templates==0.1.29
-comfyui-embedded-docs==0.2.2
+comfyui-frontend-package==1.23.4
+comfyui-workflow-templates==0.1.31
+comfyui-embedded-docs==0.2.3
 torch
 torchsde
 torchvision
 torchaudio
 numpy>=1.25.0
 einops
-transformers>=4.28.1
+transformers>=4.37.2
 tokenizers>=0.13.3
 sentencepiece
 safetensors>=0.4.2