mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 21:45:06 +00:00
Add a way to pass options to the transformers blocks.
This commit is contained in:
@@ -78,7 +78,7 @@ class DDIMSampler(object):
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
denoise_function=None,
|
||||
cond_concat=None,
|
||||
extra_args=None,
|
||||
to_zero=True,
|
||||
end_step=None,
|
||||
**kwargs
|
||||
@@ -101,7 +101,7 @@ class DDIMSampler(object):
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=denoise_function,
|
||||
cond_concat=cond_concat,
|
||||
extra_args=extra_args,
|
||||
to_zero=to_zero,
|
||||
end_step=end_step
|
||||
)
|
||||
@@ -174,7 +174,7 @@ class DDIMSampler(object):
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=None,
|
||||
cond_concat=None
|
||||
extra_args=None
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@@ -185,7 +185,7 @@ class DDIMSampler(object):
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
|
||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
@@ -225,7 +225,7 @@ class DDIMSampler(object):
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
|
||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
@@ -249,11 +249,11 @@ class DDIMSampler(object):
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None, denoise_function=None, cond_concat=None):
|
||||
dynamic_threshold=None, denoise_function=None, extra_args=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if denoise_function is not None:
|
||||
model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
|
||||
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
|
||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
|
@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
self.conditioning_key = conditioning_key
|
||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t, control=control)
|
||||
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'concat':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t, control=control)
|
||||
out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'crossattn':
|
||||
if not self.sequential_cross_attn:
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
# TorchScript changes names of the arguments
|
||||
# with argument cc defined as context=cc scripted model will produce
|
||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
||||
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
|
||||
else:
|
||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
||||
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
||||
out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'hybrid-adm':
|
||||
assert c_adm is not None
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'crossattn-adm':
|
||||
assert c_adm is not None
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
||||
out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
Reference in New Issue
Block a user