Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
comfyanonymous
2023-06-22 13:03:50 -04:00
parent 9fccf4aa03
commit f87ec10a97
16 changed files with 754 additions and 289 deletions

@@ -600,7 +600,7 @@ class SpatialTransformer(nn.Module):
                  use_checkpoint=True, dtype=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
-            context_dim = [context_dim]
+            context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels, dtype=dtype)
@@ -630,7 +630,7 @@ class SpatialTransformer(nn.Module):
     def forward(self, x, context=None, transformer_options={}):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
-            context = [context]
+            context = [context] * len(self.transformer_blocks)
         b, c, h, w = x.shape
         x_in = x
         x = self.norm(x)
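
The two hunks work together: __init__ now expands a single context_dim into one entry per transformer block (depth), and forward broadcasts a single conditioning tensor into a matching per-block list, which is what allows a SpatialTransformer with depth greater than one (as in the SDXL UNet) to hand each block its own cross-attention context. Below is a minimal, self-contained sketch of that fan-out pattern; TinySpatialTransformer and TinyCrossAttnBlock are simplified, hypothetical stand-ins, not ComfyUI's actual modules.

# Minimal sketch (hypothetical names, not ComfyUI's real classes) of the
# per-block context fan-out implemented by the two hunks above.
import torch
import torch.nn as nn

class TinyCrossAttnBlock(nn.Module):
    # Stand-in for a BasicTransformerBlock: a single cross-attention step.
    def __init__(self, dim, context_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(context_dim, dim * 2)

    def forward(self, x, context):
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        attn = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        return x + attn @ v

class TinySpatialTransformer(nn.Module):
    def __init__(self, dim, depth, context_dim):
        super().__init__()
        # Mirrors the first hunk: a scalar context_dim becomes one entry per block.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.blocks = nn.ModuleList(TinyCrossAttnBlock(dim, d) for d in context_dim)

    def forward(self, x, context):
        # Mirrors the second hunk: a single context tensor is broadcast to a list
        # with one entry per block, so each block can receive its own conditioning.
        if not isinstance(context, list):
            context = [context] * len(self.blocks)
        for i, block in enumerate(self.blocks):
            x = block(x, context=context[i])
        return x

# Usage: one shared conditioning tensor is fanned out to every block (depth=2 here).
tokens = torch.randn(1, 16, 64)   # (batch, image tokens, dim)
cond = torch.randn(1, 77, 128)    # (batch, context tokens, context_dim)
out = TinySpatialTransformer(64, depth=2, context_dim=128)(tokens, cond)
print(out.shape)                  # torch.Size([1, 16, 64])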