Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 03:07:07 +00:00)
Improve error handling for multigpu threads
commit 44e053c26d
parent 9726eac475
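For context: by default, an exception raised inside a threading.Thread target is only reported through threading.excepthook; Thread.join() still returns normally and the dispatching code continues without the worker's output. A minimal standalone sketch of that failure mode (names here are illustrative, not from this codebase):

import threading

def worker():
    # an uncaught exception here is reported via threading.excepthook...
    raise RuntimeError("worker failed")

t = threading.Thread(target=worker)
t.start()
t.join()  # ...but join() returns normally
print("the caller would silently continue without the worker's result")

This commit closes that gap for the multigpu sampling threads by recording the exception alongside the per-batch results and re-raising it on the thread that joins.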
@@ -3,7 +3,7 @@ from __future__ import annotations
 import comfy.model_management
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, Callable, NamedTuple, Any
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
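The added Any import supports the switch in the next hunk from collections.namedtuple to a typing.NamedTuple class, which allows annotated fields and a per-field default so success-path callers can omit the new error field. A rough equivalent, outside this codebase:

from typing import Any, NamedTuple
import collections

# old style: untyped, positional fields
OldResult = collections.namedtuple('OldResult', ['output', 'error'])

# new style: annotated fields with a default for the error slot
class NewResult(NamedTuple):
    output: Any
    error: Exception = None

print(NewResult(output=42))             # NewResult(output=42, error=None)
print(NewResult(None, ValueError("x"))) # error carried explicitly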
@@ -428,74 +428,85 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
         if batched_to_run_length >= conds_per_device:
             index_device += 1
 
-    thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
+    class thread_result(NamedTuple):
+        output: Any
+        mult: Any
+        area: Any
+        batch_chunks: int
+        cond_or_uncond: Any
+        error: Exception = None
+
     def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
-        model_current: BaseModel = model_options["multigpu_clones"][device].model
-        # run every hooked_to_run separately
-        with torch.no_grad():
-            for hooks, to_batch in batch_tuple:
-                input_x = []
-                mult = []
-                c = []
-                cond_or_uncond = []
-                uuids = []
-                area = []
-                control: ControlBase = None
-                patches = None
-                for x in to_batch:
-                    o = x
-                    p = o[0]
-                    input_x.append(p.input_x)
-                    mult.append(p.mult)
-                    c.append(p.conditioning)
-                    area.append(p.area)
-                    cond_or_uncond.append(o[1])
-                    uuids.append(p.uuid)
-                    control = p.control
-                    patches = p.patches
-
-                batch_chunks = len(cond_or_uncond)
-                input_x = torch.cat(input_x).to(device)
-                c = cond_cat(c, device=device)
-                timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
-
-                transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
-                if 'transformer_options' in model_options:
-                    transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
-                                                                                     model_options['transformer_options'],
-                                                                                     copy_dict1=False)
-
-                if patches is not None:
-                    # TODO: replace with merge_nested_dicts function
-                    if "patches" in transformer_options:
-                        cur_patches = transformer_options["patches"].copy()
-                        for p in patches:
-                            if p in cur_patches:
-                                cur_patches[p] = cur_patches[p] + patches[p]
-                            else:
-                                cur_patches[p] = patches[p]
-                        transformer_options["patches"] = cur_patches
-                    else:
-                        transformer_options["patches"] = patches
-
-                transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-                transformer_options["uuids"] = uuids[:]
-                transformer_options["sigmas"] = timestep
-                transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
-                transformer_options["multigpu_thread_device"] = device
-
-                cast_transformer_options(transformer_options, device=device)
-                c['transformer_options'] = transformer_options
-
-                if control is not None:
-                    device_control = control.get_instance_for_device(device)
-                    c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
-
-                if 'model_function_wrapper' in model_options:
-                    output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
-                else:
-                    output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
-                results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+        try:
+            model_current: BaseModel = model_options["multigpu_clones"][device].model
+            # run every hooked_to_run separately
+            with torch.no_grad():
+                for hooks, to_batch in batch_tuple:
+                    input_x = []
+                    mult = []
+                    c = []
+                    cond_or_uncond = []
+                    uuids = []
+                    area = []
+                    control: ControlBase = None
+                    patches = None
+                    for x in to_batch:
+                        o = x
+                        p = o[0]
+                        input_x.append(p.input_x)
+                        mult.append(p.mult)
+                        c.append(p.conditioning)
+                        area.append(p.area)
+                        cond_or_uncond.append(o[1])
+                        uuids.append(p.uuid)
+                        control = p.control
+                        patches = p.patches
+
+                    batch_chunks = len(cond_or_uncond)
+                    input_x = torch.cat(input_x).to(device)
+                    c = cond_cat(c, device=device)
+                    timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
+
+                    transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
+                    if 'transformer_options' in model_options:
+                        transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+                                                                                         model_options['transformer_options'],
+                                                                                         copy_dict1=False)
+
+                    if patches is not None:
+                        # TODO: replace with merge_nested_dicts function
+                        if "patches" in transformer_options:
+                            cur_patches = transformer_options["patches"].copy()
+                            for p in patches:
+                                if p in cur_patches:
+                                    cur_patches[p] = cur_patches[p] + patches[p]
+                                else:
+                                    cur_patches[p] = patches[p]
+                            transformer_options["patches"] = cur_patches
+                        else:
+                            transformer_options["patches"] = patches
+
+                    transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+                    transformer_options["uuids"] = uuids[:]
+                    transformer_options["sigmas"] = timestep
+                    transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
+                    transformer_options["multigpu_thread_device"] = device
+
+                    cast_transformer_options(transformer_options, device=device)
+                    c['transformer_options'] = transformer_options
+
+                    if control is not None:
+                        device_control = control.get_instance_for_device(device)
+                        c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+
+                    if 'model_function_wrapper' in model_options:
+                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
+                    else:
+                        output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+                    results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+        except Exception as e:
+            results.append(thread_result(None, None, None, None, None, error=e))
+            raise
 
 
     results: list[thread_result] = []
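The shape of the change, reduced to a self-contained sketch (function and variable names are illustrative, not the ones used above): each worker wraps its whole body in try/except, appends a sentinel result carrying the exception, and re-raises so the worker's own traceback is still reported; the dispatcher joins every thread and surfaces the first recorded error.

import threading
from typing import Any, NamedTuple

class WorkResult(NamedTuple):
    value: Any = None
    error: Exception = None

def run_batch(batch: list[int], results: list[WorkResult]) -> None:
    try:
        results.append(WorkResult(value=sum(batch)))
    except Exception as e:
        results.append(WorkResult(error=e))
        raise  # keep the traceback visible on the worker thread as well

results: list[WorkResult] = []
threads = [threading.Thread(target=run_batch, args=(b, results)) for b in ([1, 2], [3, 4])]
for t in threads:
    t.start()
for t in threads:
    t.join()
for value, error in results:
    if error is not None:
        raise error  # propagate the worker failure on the joining thread
    print(value)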
@@ -508,7 +519,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
     for thread in threads:
         thread.join()
 
-    for output, mult, area, batch_chunks, cond_or_uncond in results:
+    for output, mult, area, batch_chunks, cond_or_uncond, error in results:
+        if error is not None:
+            raise error
         for o in range(batch_chunks):
             cond_index = cond_or_uncond[o]
             a = area[o]
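An aside on the design space, not something this code does: Python's concurrent.futures gives the same propagation out of the box, since Future.result() re-raises a worker's exception on the calling thread. A minimal illustration:

from concurrent.futures import ThreadPoolExecutor

def work(x: int) -> int:
    if x < 0:
        raise ValueError("bad input")
    return x * 2

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(work, x) for x in (1, -1)]
    for f in futures:
        print(f.result())  # re-raises ValueError from the failing worker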