diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ee29251b9..0029a4987 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -15,13 +15,14 @@
     You should have received a copy of the GNU General Public License
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
-
+from __future__ import annotations
 import torch
 from enum import Enum
 import math
 import os
 import logging
+import copy
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
@@ -36,7 +37,7 @@
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
 import comfy.cldm.dit_embedder
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 if TYPE_CHECKING:
     from comfy.hooks import HookGroup
@@ -76,7 +77,7 @@ class ControlBase:
         self.compression_ratio = 8
         self.upscale_algorithm = 'nearest-exact'
         self.extra_args = {}
-        self.previous_controlnet = None
+        self.previous_controlnet: Union[ControlBase, None] = None
         self.extra_conds = []
         self.strength_type = StrengthType.CONSTANT
         self.concat_mask = False
@@ -84,6 +85,7 @@ class ControlBase:
         self.extra_concat = None
         self.extra_hooks: HookGroup = None
         self.preprocess_image = lambda a: a
+        self.multigpu_clones: dict[torch.device, ControlBase] = {}

     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
@@ -117,10 +119,33 @@ class ControlBase:

     def get_models(self):
         out = []
+        for device_cnet in self.multigpu_clones.values():
+            out += device_cnet.get_models()
         if self.previous_controlnet is not None:
             out += self.previous_controlnet.get_models()
         return out

+    def get_models_only_self(self):
+        'Calls get_models, but temporarily sets previous_controlnet to None.'
+        try:
+            orig_previous_controlnet = self.previous_controlnet
+            self.previous_controlnet = None
+            return self.get_models()
+        finally:
+            self.previous_controlnet = orig_previous_controlnet
+
+    def get_instance_for_device(self, device):
+        'Returns instance of this Control object intended for selected device.'
+        return self.multigpu_clones.get(device, self)
+
+    def deepclone_multigpu(self, load_device, autoregister=False):
+        '''
+        Create a deep clone of this Control object with its model(s) assigned to another device.
+
+        When autoregister is set to True, the deep clone is also added to the multigpu_clones dict.
+        '''
+        raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu function.")
+
     def get_extra_hooks(self):
         out = []
         if self.extra_hooks is not None:
@@ -129,7 +154,7 @@ class ControlBase:
             out += self.previous_controlnet.get_extra_hooks()
         return out

-    def copy_to(self, c):
+    def copy_to(self, c: ControlBase):
         c.cond_hint_original = self.cond_hint_original
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
@@ -280,6 +305,14 @@ class ControlNet(ControlBase):
         self.copy_to(c)
         return c

+    def deepclone_multigpu(self, load_device, autoregister=False):
+        c = self.copy()
+        c.control_model = copy.deepcopy(c.control_model)
+        c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+        if autoregister:
+            self.multigpu_clones[load_device] = c
+        return c
+
     def get_models(self):
         out = super().get_models()
         out.append(self.control_model_wrapped)
@@ -809,6 +842,14 @@ class T2IAdapter(ControlBase):
         c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
         self.copy_to(c)
         return c
+
+    def deepclone_multigpu(self, load_device, autoregister=False):
+        c = self.copy()
+        c.t2i_model = copy.deepcopy(c.t2i_model)
+        c.device = load_device
+        if autoregister:
+            self.multigpu_clones[load_device] = c
+        return c

 def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
     compression_ratio = 8
diff --git a/comfy/samplers.py b/comfy/samplers.py
index cf97b9820..27d875709 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+
+import comfy.model_management
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
 from typing import TYPE_CHECKING, Callable, NamedTuple
@@ -427,7 +429,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
         cond_or_uncond = []
         uuids = []
         area = []
-        control = None
+        control: ControlBase = None
         patches = None
         for x in to_batch:
             o = x
@@ -473,7 +475,8 @@
             c['transformer_options'] = transformer_options

             if control is not None:
-                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+                device_control = control.get_instance_for_device(device)
+                c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

             if 'model_function_wrapper' in model_options:
                 output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
@@ -799,6 +802,8 @@ def pre_run_control(model, conds):
         percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
         if 'control' in x:
             x['control'].pre_run(model, percent_to_timestep_function)
+            for device_cnet in x['control'].multigpu_clones.values():
+                device_cnet.pre_run(model, percent_to_timestep_function)

 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
@@ -1080,6 +1085,48 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
                 wc_list[i] = wc_list[i].to(cast)

+def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher):
+    '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
+    multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
+    if len(multigpu_models) == 0:
+        return
+    extra_devices = [x.load_device for x in multigpu_models]
+    # handle controlnets
+    controlnets: set[ControlBase] = set()
+    for k in conds:
+        for kk in conds[k]:
+            if 'control' in kk:
+                controlnets.add(kk['control'])
+    if len(controlnets) > 0:
+        # first, unload all controlnet clones
+        for cnet in list(controlnets):
+            cnet_models = cnet.get_models()
+            for cm in cnet_models:
+                comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
+
+        # next, make sure each controlnet has a deepclone for all relevant devices
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                for device in extra_devices:
+                    if device not in curr_cnet.multigpu_clones:
+                        curr_cnet.deepclone_multigpu(device, autoregister=True)
+                curr_cnet = curr_cnet.previous_controlnet
+        # since all device clones are now present, recreate the linked list for cloned cnets per device
+        for cnet in controlnets:
+            curr_cnet = cnet
+            while curr_cnet is not None:
+                prev_cnet = curr_cnet.previous_controlnet
+                for device in extra_devices:
+                    device_cnet = curr_cnet.get_instance_for_device(device)
+                    prev_device_cnet = None
+                    if prev_cnet is not None:
+                        prev_device_cnet = prev_cnet.get_instance_for_device(device)
+                    device_cnet.set_previous_controlnet(prev_device_cnet)
+                curr_cnet = prev_cnet
+    # TODO: handle gligen
+

 class CFGGuider:
     def __init__(self, model_patcher: ModelPatcher):
         self.model_patcher = model_patcher
@@ -1122,6 +1169,7 @@ class CFGGuider:
         return self.inner_model.process_latent_out(samples.to(torch.float32))

     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+        preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher)
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device
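
What the controlnet.py half of this diff builds is a per-device clone registry: every `ControlBase` keeps a `multigpu_clones` dict keyed by device, `deepclone_multigpu` populates it, and `get_instance_for_device` falls back to `self` when no clone is registered. Below is a minimal, self-contained sketch of that pattern; `ToyControl` and the `"cuda:*"` device strings are hypothetical stand-ins, not ComfyUI code.

```python
import copy

class ToyControl:
    """Hypothetical stand-in for ControlBase, showing only the clone registry."""
    def __init__(self, name, device="cuda:0"):
        self.name = name
        self.device = device
        self.previous_controlnet = None
        self.multigpu_clones = {}  # device -> ToyControl

    def get_instance_for_device(self, device):
        # Same contract as ControlBase.get_instance_for_device:
        # return the registered clone, or self when none exists.
        return self.multigpu_clones.get(device, self)

    def deepclone_multigpu(self, load_device, autoregister=False):
        # The real implementations deep-copy only the control model and
        # re-wrap it for load_device; a full deepcopy keeps this sketch short.
        c = copy.deepcopy(self)
        c.device = load_device
        c.multigpu_clones = {}  # a clone does not carry its own registry
        if autoregister:
            self.multigpu_clones[load_device] = c
        return c

cnet = ToyControl("canny")
cnet.deepclone_multigpu("cuda:1", autoregister=True)
assert cnet.get_instance_for_device("cuda:1").device == "cuda:1"
assert cnet.get_instance_for_device("cuda:0") is cnet  # no clone -> fall back to self
```

The fallback to `self` is what lets `_calc_cond_batch_multigpu` stay a two-line change: `control.get_instance_for_device(device)` resolves to the original object on its home device and to a registered clone everywhere else.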
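The subtler step in `preprocess_multigpu_conds` is the second pass: after cloning, a clone's `previous_controlnet` does not yet point at the clone of its predecessor on the same device, so the linked list has to be rebuilt per device. A sketch of that walk, reusing `ToyControl` from the previous snippet (`relink_chain_for_devices` is a hypothetical name mirroring the re-linking loop in the diff):

```python
def relink_chain_for_devices(cnet, extra_devices):
    # Walk the original previous_controlnet chain and point each device's
    # clone at the matching clone of its predecessor, as the second loop
    # in preprocess_multigpu_conds does via set_previous_controlnet.
    curr = cnet
    while curr is not None:
        prev = curr.previous_controlnet
        for device in extra_devices:
            device_curr = curr.get_instance_for_device(device)
            device_prev = prev.get_instance_for_device(device) if prev is not None else None
            device_curr.previous_controlnet = device_prev
        curr = prev

# Two chained controls, cloned for a second device, then re-linked.
a, b = ToyControl("a"), ToyControl("b")
a.previous_controlnet = b
for node in (a, b):
    node.deepclone_multigpu("cuda:1", autoregister=True)
relink_chain_for_devices(a, ["cuda:1"])
assert a.get_instance_for_device("cuda:1").previous_controlnet is b.get_instance_for_device("cuda:1")
```

Without this pass, a clone running on `cuda:1` would chain back into the original device's objects, defeating the per-device separation the registry establishes.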