mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-15 05:57:57 +00:00
Add support for GLIGEN textbox model.
This commit is contained in:
@@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
current_index = None
|
||||
if "current_index" in transformer_options:
|
||||
current_index = transformer_options["current_index"]
|
||||
if "patches" in transformer_options:
|
||||
transformer_patches = transformer_options["patches"]
|
||||
else:
|
||||
transformer_patches = {}
|
||||
|
||||
n = self.norm1(x)
|
||||
if "tomesd" in transformer_options:
|
||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||
@@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
||||
|
||||
x += n
|
||||
if "middle_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_patch"]
|
||||
for p in patch:
|
||||
x = p(current_index, x)
|
||||
|
||||
n = self.norm2(x)
|
||||
n = self.attn2(n, context=context)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
if current_index is not None:
|
||||
transformer_options["current_index"] += 1
|
||||
return x
|
||||
|
||||
|
||||
|
@@ -782,6 +782,8 @@ class UNetModel(nn.Module):
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["current_index"] = 0
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
|
Reference in New Issue
Block a user