mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Support SSD1B model and make it easier to support asymmetric unets.
This commit is contained in:
@@ -27,7 +27,6 @@ class ControlNet(nn.Module):
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
@@ -52,6 +51,7 @@ class ControlNet(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
):
|
||||
@@ -79,10 +79,7 @@ class ControlNet(nn.Module):
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
@@ -90,18 +87,16 @@ class ControlNet(nn.Module):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
transformer_depth = transformer_depth[:]
|
||||
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
@@ -180,11 +175,14 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
operations=operations
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@@ -201,9 +199,9 @@ class ControlNet(nn.Module):
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@@ -223,11 +221,13 @@ class ControlNet(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -245,7 +245,7 @@ class ControlNet(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@@ -253,12 +253,15 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@@ -267,9 +270,11 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||
self._feature_size += ch
|
||||
|
||||
|
Reference in New Issue
Block a user