Add support for GLIGEN textbox model.

This commit is contained in:
comfyanonymous
2023-04-19 09:36:19 -04:00
parent 472b1cc0d8
commit 3696d1699a
9 changed files with 491 additions and 28 deletions

View File

@@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
control = None
if 'control' in cond[1]:
control = cond[1]['control']
return (input_x, mult, conditionning, area, control)
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
@@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
@@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
@@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
@@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
c['transformer_options'] = model_options['transformer_options']
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
transformer_options["patches"] = patches
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
@@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_control_net_to_equal_area(conds, uncond):
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
uncond_cnets = []
@@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
cond_cnets.append(x[1]['control'])
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
uncond_cnets.append(x[1]['control'])
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
else:
uncond_other.append((x, t))
@@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None:
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
n['control'] = cond_cnets[x]
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
else:
n = o[1].copy()
n['control'] = cond_cnets[x]
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
@@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -466,7 +498,8 @@ class KSampler:
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast