diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0029a4987..31227ae31 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -64,6 +64,18 @@ class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 +class ControlIsolation: + '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.''' + def __init__(self, control: ControlBase): + self.control = control + self.orig_previous_controlnet = control.previous_controlnet + + def __enter__(self): + self.control.previous_controlnet = None + + def __exit__(self, *args): + self.control.previous_controlnet = self.orig_previous_controlnet + class ControlBase: def __init__(self): self.cond_hint_original = None @@ -112,7 +124,9 @@ class ControlBase: def cleanup(self): if self.previous_controlnet is not None: self.previous_controlnet.cleanup() - + for device_cnet in self.multigpu_clones.values(): + with ControlIsolation(device_cnet): + device_cnet.cleanup() self.cond_hint = None self.extra_concat = None self.timestep_range = None @@ -120,19 +134,15 @@ class ControlBase: def get_models(self): out = [] for device_cnet in self.multigpu_clones.values(): - out += device_cnet.get_models() + out += device_cnet.get_models_only_self() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out def get_models_only_self(self): 'Calls get_models, but temporarily sets previous_controlnet to None.' - try: - orig_previous_controlnet = self.previous_controlnet - self.previous_controlnet = None + with ControlIsolation(self): return self.get_models() - finally: - self.previous_controlnet = orig_previous_controlnet def get_instance_for_device(self, device): 'Returns instance of this Control object intended for selected device.'