Add ControlNet support.

This commit is contained in:
comfyanonymous
2023-02-16 10:38:08 -05:00
parent bc69fb5245
commit 4efa67fa12
9 changed files with 580 additions and 63 deletions

View File

@@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
out = self.diffusion_model(x, t, control=control)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
out = self.diffusion_model(xc, t, control=control)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
@@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc)
out = self.scripted_diffusion_model(x, t, cc, control=control)
else:
out = self.diffusion_model(x, t, context=cc)
out = self.diffusion_model(x, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
out = self.diffusion_model(xc, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm)
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
out = self.diffusion_model(x, t, y=cc, control=control)
else:
raise NotImplementedError()