T2I adapter SDXL.

This commit is contained in:
comfyanonymous
2023-08-22 14:38:34 -04:00
parent f2a7cc9121
commit 85fde89d7f
2 changed files with 48 additions and 10 deletions

View File

@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
class Adapter(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
super(Adapter, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8)
unshuffle = 8
resblock_no_downsample = []
resblock_downsample = [3, 2, 1]
self.xl = xl
if self.xl:
unshuffle = 16
resblock_no_downsample = [1]
resblock_downsample = [2]
self.input_channels = cin // (unshuffle * unshuffle)
self.unshuffle = nn.PixelUnshuffle(unshuffle)
self.channels = channels
self.nums_rb = nums_rb
self.body = []
for i in range(len(channels)):
for j in range(nums_rb):
if (i != 0) and (j == 0):
if (i in resblock_downsample) and (j == 0):
self.body.append(
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
elif (i in resblock_no_downsample) and (j == 0):
self.body.append(
ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
else:
self.body.append(
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
@@ -128,8 +141,16 @@ class Adapter(nn.Module):
for j in range(self.nums_rb):
idx = i * self.nums_rb + j
x = self.body[idx](x)
features.append(None)
features.append(None)
if self.xl:
features.append(None)
if i == 0:
features.append(None)
features.append(None)
if i == 2:
features.append(None)
else:
features.append(None)
features.append(None)
features.append(x)
return features
@@ -243,10 +264,14 @@ class extractor(nn.Module):
class Adapter_light(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
super(Adapter_light, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8)
unshuffle = 8
self.unshuffle = nn.PixelUnshuffle(unshuffle)
self.input_channels = cin // (unshuffle * unshuffle)
self.channels = channels
self.nums_rb = nums_rb
self.body = []
self.xl = False
for i in range(len(channels)):
if i == 0:
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))