Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
This commit is contained in:
comfyanonymous
2023-06-22 13:03:50 -04:00
parent 9fccf4aa03
commit f87ec10a97
16 changed files with 754 additions and 289 deletions

View File

@@ -229,7 +229,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
@@ -460,8 +460,7 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(model, conds, batch_size, device):
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
for t in range(len(conds)):
x = conds[t]
adm_out = None
@@ -469,7 +468,11 @@ def encode_adm(model, conds, batch_size, device):
adm_out = x[1]["adm"]
else:
params = x[1].copy()
params["width"] = params.get("width", width * 8)
params["height"] = params.get("height", height * 8)
params["prompt_type"] = params.get("prompt_type", prompt_type)
adm_out = model.encode_adm(device=device, **params)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
@@ -580,8 +583,8 @@ class KSampler:
precision_scope = contextlib.nullcontext
if self.model.is_adm():
positive = encode_adm(self.model, positive, noise.shape[0], self.device)
negative = encode_adm(self.model, negative, noise.shape[0], self.device)
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}