mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 20:17:30 +00:00
Support base SDXL and SDXL refiner models.
Large refactor of the model detection and loading code.
This commit is contained in:
@@ -600,7 +600,7 @@ class SpatialTransformer(nn.Module):
|
||||
use_checkpoint=True, dtype=None):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
context_dim = [context_dim] * depth
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels, dtype=dtype)
|
||||
@@ -630,7 +630,7 @@ class SpatialTransformer(nn.Module):
|
||||
def forward(self, x, context=None, transformer_options={}):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
if not isinstance(context, list):
|
||||
context = [context]
|
||||
context = [context] * len(self.transformer_blocks)
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
|
Reference in New Issue
Block a user