Improve error handling for multigpu threads

Author: Jedrzej Kosinski
Date:   2025-06-24 00:48:51 -05:00
parent 9726eac475
commit 44e053c26d


@@ -3,7 +3,7 @@ from __future__ import annotations
 import comfy.model_management
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, Callable, NamedTuple, Any
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
@@ -428,74 +428,85 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
         if batched_to_run_length >= conds_per_device:
             index_device += 1
-    thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
+    class thread_result(NamedTuple):
+        output: Any
+        mult: Any
+        area: Any
+        batch_chunks: int
+        cond_or_uncond: Any
+        error: Exception = None
     def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
-        model_current: BaseModel = model_options["multigpu_clones"][device].model
-        # run every hooked_to_run separately
-        with torch.no_grad():
-            for hooks, to_batch in batch_tuple:
-                input_x = []
-                mult = []
-                c = []
-                cond_or_uncond = []
-                uuids = []
-                area = []
-                control: ControlBase = None
-                patches = None
-                for x in to_batch:
-                    o = x
-                    p = o[0]
-                    input_x.append(p.input_x)
-                    mult.append(p.mult)
-                    c.append(p.conditioning)
-                    area.append(p.area)
-                    cond_or_uncond.append(o[1])
-                    uuids.append(p.uuid)
-                    control = p.control
-                    patches = p.patches
-                batch_chunks = len(cond_or_uncond)
-                input_x = torch.cat(input_x).to(device)
-                c = cond_cat(c, device=device)
-                timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
-                transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
-                if 'transformer_options' in model_options:
-                    transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
-                                                                                     model_options['transformer_options'],
-                                                                                     copy_dict1=False)
-                if patches is not None:
-                    # TODO: replace with merge_nested_dicts function
-                    if "patches" in transformer_options:
-                        cur_patches = transformer_options["patches"].copy()
-                        for p in patches:
-                            if p in cur_patches:
-                                cur_patches[p] = cur_patches[p] + patches[p]
-                            else:
-                                cur_patches[p] = patches[p]
-                        transformer_options["patches"] = cur_patches
-                    else:
-                        transformer_options["patches"] = patches
-                transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-                transformer_options["uuids"] = uuids[:]
-                transformer_options["sigmas"] = timestep
-                transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
-                transformer_options["multigpu_thread_device"] = device
-                cast_transformer_options(transformer_options, device=device)
-                c['transformer_options'] = transformer_options
-                if control is not None:
-                    device_control = control.get_instance_for_device(device)
-                    c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
-                if 'model_function_wrapper' in model_options:
-                    output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
-                else:
-                    output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
-                results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+        try:
+            model_current: BaseModel = model_options["multigpu_clones"][device].model
+            # run every hooked_to_run separately
+            with torch.no_grad():
+                for hooks, to_batch in batch_tuple:
+                    input_x = []
+                    mult = []
+                    c = []
+                    cond_or_uncond = []
+                    uuids = []
+                    area = []
+                    control: ControlBase = None
+                    patches = None
+                    for x in to_batch:
+                        o = x
+                        p = o[0]
+                        input_x.append(p.input_x)
+                        mult.append(p.mult)
+                        c.append(p.conditioning)
+                        area.append(p.area)
+                        cond_or_uncond.append(o[1])
+                        uuids.append(p.uuid)
+                        control = p.control
+                        patches = p.patches
+                    batch_chunks = len(cond_or_uncond)
+                    input_x = torch.cat(input_x).to(device)
+                    c = cond_cat(c, device=device)
+                    timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
+                    transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
+                    if 'transformer_options' in model_options:
+                        transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+                                                                                         model_options['transformer_options'],
+                                                                                         copy_dict1=False)
+                    if patches is not None:
+                        # TODO: replace with merge_nested_dicts function
+                        if "patches" in transformer_options:
+                            cur_patches = transformer_options["patches"].copy()
+                            for p in patches:
+                                if p in cur_patches:
+                                    cur_patches[p] = cur_patches[p] + patches[p]
+                                else:
+                                    cur_patches[p] = patches[p]
+                            transformer_options["patches"] = cur_patches
+                        else:
+                            transformer_options["patches"] = patches
+                    transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+                    transformer_options["uuids"] = uuids[:]
+                    transformer_options["sigmas"] = timestep
+                    transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
+                    transformer_options["multigpu_thread_device"] = device
+                    cast_transformer_options(transformer_options, device=device)
+                    c['transformer_options'] = transformer_options
+                    if control is not None:
+                        device_control = control.get_instance_for_device(device)
+                        c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+                    if 'model_function_wrapper' in model_options:
+                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
+                    else:
+                        output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+                    results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+        except Exception as e:
+            results.append(thread_result(None, None, None, None, None, error=e))
+            raise
     results: list[thread_result] = []
@@ -508,7 +519,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
     for thread in threads:
         thread.join()
 
-    for output, mult, area, batch_chunks, cond_or_uncond in results:
+    for output, mult, area, batch_chunks, cond_or_uncond, error in results:
+        if error is not None:
+            raise error
         for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
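
The pattern behind this change, sketched below as a minimal standalone example (the names ThreadResult, _worker, and run_all are hypothetical and not part of this commit): each worker thread catches its own exception, records it in the shared results list, and re-raises; the joining thread then checks the recorded error after join() and raises it, so a failure in one worker surfaces in the caller instead of dying inside the thread.

    # Minimal sketch of propagating worker-thread errors to the joining thread.
    # Names (ThreadResult, _worker, run_all) are hypothetical, not ComfyUI code.
    import threading
    from typing import Any, NamedTuple, Optional

    class ThreadResult(NamedTuple):
        output: Any
        error: Optional[Exception] = None

    def _worker(job: int, results: list[ThreadResult]) -> None:
        try:
            if job < 0:
                raise ValueError(f"bad job: {job}")
            results.append(ThreadResult(output=job * job))
        except Exception as e:
            # Record the failure so the joining thread can re-raise it,
            # instead of the exception being swallowed inside this thread.
            results.append(ThreadResult(output=None, error=e))
            raise

    def run_all(jobs: list[int]) -> list[Any]:
        results: list[ThreadResult] = []
        threads = [threading.Thread(target=_worker, args=(j, results)) for j in jobs]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        outputs = []
        for output, error in results:
            if error is not None:
                raise error  # surface the first recorded worker failure
            outputs.append(output)
        return outputs

    if __name__ == "__main__":
        print(run_all([1, 2, 3]))  # [1, 4, 9]

Sharing a plain list between threads works here because list.append is atomic under CPython's GIL, which mirrors how the commit's results list is passed to each _handle_batch thread and only read after all joins complete.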