SEEDS: update noise decomposition and refactor (#9633)

- Update the decomposition to reflect interval dependency
- Extract phi computations into functions
- Use torch.lerp for interpolation
This commit is contained in:
chaObserv
2025-08-31 12:01:45 +08:00
committed by GitHub
parent fe442fac2e
commit 32a627bf1f

View File

@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas return sigmas
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
return torch.expm1(h)
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
return (torch.expm1(h) - h) / h
@torch.no_grad() @torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad() @torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
""" """
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0 inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
fac = 1 / (2 * r)
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0: if sigmas[i + 1] == 0:
x = denoised x = denoised
else: continue
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp() h = lambda_t - lambda_s
alpha_t = sigmas[i + 1] * lambda_t.exp() h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
sigma_s_1 = sigma_fn(lambda_s_1)
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
if inject_noise: alpha_t = sigmas[i + 1] * lambda_t.exp()
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1 # Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
if inject_noise: if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2 # Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2 denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if inject_noise: if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x return x
@torch.no_grad() @torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267 arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
""" """
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0 inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0: if sigmas[i + 1] == 0:
x = denoised x = denoised
else: continue
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp() h = lambda_t - lambda_s
alpha_s_2 = sigma_s_2 * lambda_s_2.exp() h_eta = h * (eta + 1)
alpha_t = sigmas[i + 1] * lambda_t.exp() lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
if inject_noise: alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
# 0 < r_1 < r_2 < 1 alpha_t = sigmas[i + 1] * lambda_t.exp()
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1 # Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
if inject_noise: if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2 # Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
if inject_noise: a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) if inject_noise:
segment_factor = (r_1 - r_2) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3 # Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) b3 = ei_h_phi_2(-h_eta) / r_2
if inject_noise: b1 = ei_h_phi_1(-h_eta) - b3
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
if inject_noise:
segment_factor = (r_2 - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x return x