Clean up percent start/end and make controlnets work with sigmas.

This commit is contained in:
comfyanonymous
2023-10-31 22:14:32 -04:00
parent a268a574fa
commit 7c0f255de1
3 changed files with 26 additions and 9 deletions

View File

@@ -82,6 +82,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
def model_sampling(model_config, model_type):
if model_type == ModelType.EPS:
c = EPS
@@ -126,7 +129,7 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
extra_conds = {}
for o in kwargs: