Add support for GLIGEN textbox model.

This commit is contained in:
comfyanonymous
2023-04-19 09:36:19 -04:00
parent 472b1cc0d8
commit 3696d1699a
9 changed files with 491 additions and 28 deletions

View File

@@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
n = self.norm1(x)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
n = self.attn1(n, context=context if self.disable_self_attn else None)
x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
n = self.norm2(x)
n = self.attn2(n, context=context)
x += n
x = self.ff(self.norm3(x)) + x
if current_index is not None:
transformer_options["current_index"] += 1
return x

View File

@@ -782,6 +782,8 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"