T2I adapter SDXL.

This commit is contained in:
comfyanonymous
2023-08-22 14:38:34 -04:00
parent f2a7cc9121
commit 85fde89d7f
2 changed files with 48 additions and 10 deletions

View File

@@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
mid = None
if self.t2i_model.xl == True:
mid = control_input[-1:]
control_input = control_input[:-1]
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
@@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
if len(down_opts) > 0:
use_conv = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
xl = False
if cin == 256:
xl = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else:
return None
model_ad.load_state_dict(t2i_data)
return T2IAdapter(model_ad, cin // 64)
missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
print("t2i missing", missing)
if len(unexpected) > 0:
print("t2i unexpected", unexpected)
return T2IAdapter(model_ad, model_ad.input_channels)
class StyleModel: