Implement beta sampling scheduler.

It is based on: https://arxiv.org/abs/2407.12173

Add "beta" to the list of schedulers and the BetaSamplingScheduler node.
This commit is contained in:
comfyanonymous
2024-07-19 17:44:56 -04:00
parent 011b11d8d7
commit 6ab8cad22e
2 changed files with 37 additions and 1 deletions

View File

@@ -111,6 +111,25 @@ class SDTurboScheduler:
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
class BetaSamplingScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
"beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps, alpha, beta):
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
return (sigmas, )
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@@ -638,6 +657,7 @@ NODE_CLASS_MAPPINGS = {
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerEulerAncestral": SamplerEulerAncestral,