mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Set the seed in the SDE samplers to make them more reproducible.
This commit is contained in:
@@ -77,7 +77,7 @@ class BatchedBrownianTree:
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@@ -85,7 +85,7 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
@@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
@@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
|
Reference in New Issue
Block a user