mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Add SetUnionControlNetType to set the type of the union controlnet model.
This commit is contained in:
@@ -45,6 +45,7 @@ class ControlBase:
|
||||
self.timestep_range = None
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
@@ -90,6 +91,7 @@ class ControlBase:
|
||||
c.compression_ratio = self.compression_ratio
|
||||
c.upscale_algorithm = self.upscale_algorithm
|
||||
c.latent_format = self.latent_format
|
||||
c.extra_args = self.extra_args.copy()
|
||||
c.vae = self.vae
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
@@ -135,6 +137,10 @@ class ControlBase:
|
||||
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||
return out
|
||||
|
||||
def set_extra_arg(self, argument, value=None):
|
||||
self.extra_args[argument] = value
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
@@ -191,7 +197,7 @@ class ControlNet(ControlBase):
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
|
Reference in New Issue
Block a user