diff --git a/comfy/samplers.py b/comfy/samplers.py
index 260527661..90cce078d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import comfy.model_management
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, Callable, NamedTuple, Any
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
     from comfy.model_base import BaseModel
@@ -428,74 +428,85 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
             if batched_to_run_length >= conds_per_device:
                 index_device += 1
 
-    thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond'])
+    class thread_result(NamedTuple):
+        output: Any
+        mult: Any
+        area: Any
+        batch_chunks: int
+        cond_or_uncond: Any
+        error: Exception = None
+
     def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
-        model_current: BaseModel = model_options["multigpu_clones"][device].model
-        # run every hooked_to_run separately
-        with torch.no_grad():
-            for hooks, to_batch in batch_tuple:
-                input_x = []
-                mult = []
-                c = []
-                cond_or_uncond = []
-                uuids = []
-                area = []
-                control: ControlBase = None
-                patches = None
-                for x in to_batch:
-                    o = x
-                    p = o[0]
-                    input_x.append(p.input_x)
-                    mult.append(p.mult)
-                    c.append(p.conditioning)
-                    area.append(p.area)
-                    cond_or_uncond.append(o[1])
-                    uuids.append(p.uuid)
-                    control = p.control
-                    patches = p.patches
+        try:
+            model_current: BaseModel = model_options["multigpu_clones"][device].model
+            # run every hooked_to_run separately
+            with torch.no_grad():
+                for hooks, to_batch in batch_tuple:
+                    input_x = []
+                    mult = []
+                    c = []
+                    cond_or_uncond = []
+                    uuids = []
+                    area = []
+                    control: ControlBase = None
+                    patches = None
+                    for x in to_batch:
+                        o = x
+                        p = o[0]
+                        input_x.append(p.input_x)
+                        mult.append(p.mult)
+                        c.append(p.conditioning)
+                        area.append(p.area)
+                        cond_or_uncond.append(o[1])
+                        uuids.append(p.uuid)
+                        control = p.control
+                        patches = p.patches
 
-                batch_chunks = len(cond_or_uncond)
-                input_x = torch.cat(input_x).to(device)
-                c = cond_cat(c, device=device)
-                timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
+                    batch_chunks = len(cond_or_uncond)
+                    input_x = torch.cat(input_x).to(device)
+                    c = cond_cat(c, device=device)
+                    timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
 
-                transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
-                if 'transformer_options' in model_options:
-                    transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
-                                                                                     model_options['transformer_options'],
-                                                                                     copy_dict1=False)
+                    transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
+                    if 'transformer_options' in model_options:
+                        transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+                                                                                         model_options['transformer_options'],
+                                                                                         copy_dict1=False)
 
-                if patches is not None:
-                    # TODO: replace with merge_nested_dicts function
-                    if "patches" in transformer_options:
-                        cur_patches = transformer_options["patches"].copy()
-                        for p in patches:
-                            if p in cur_patches:
-                                cur_patches[p] = cur_patches[p] + patches[p]
-                            else:
-                                cur_patches[p] = patches[p]
-                        transformer_options["patches"] = cur_patches
+                    if patches is not None:
+                        # TODO: replace with merge_nested_dicts function
+                        if "patches" in transformer_options:
+                            cur_patches = transformer_options["patches"].copy()
+                            for p in patches:
+                                if p in cur_patches:
+                                    cur_patches[p] = cur_patches[p] + patches[p]
+                                else:
+                                    cur_patches[p] = patches[p]
+                            transformer_options["patches"] = cur_patches
+                        else:
+                            transformer_options["patches"] = patches
+
+                    transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+                    transformer_options["uuids"] = uuids[:]
+                    transformer_options["sigmas"] = timestep
+                    transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
+                    transformer_options["multigpu_thread_device"] = device
+
+                    cast_transformer_options(transformer_options, device=device)
+                    c['transformer_options'] = transformer_options
+
+                    if control is not None:
+                        device_control = control.get_instance_for_device(device)
+                        c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+
+                    if 'model_function_wrapper' in model_options:
+                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
                     else:
-                        transformer_options["patches"] = patches
-
-                transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-                transformer_options["uuids"] = uuids[:]
-                transformer_options["sigmas"] = timestep
-                transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
-                transformer_options["multigpu_thread_device"] = device
-
-                cast_transformer_options(transformer_options, device=device)
-                c['transformer_options'] = transformer_options
-
-                if control is not None:
-                    device_control = control.get_instance_for_device(device)
-                    c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
-
-                if 'model_function_wrapper' in model_options:
-                    output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
-                else:
-                    output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
-                results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+                        output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+                    results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+        except Exception as e:
+            results.append(thread_result(None, None, None, None, None, error=e))
+            raise
 
     results: list[thread_result] = []
 
@@ -508,7 +519,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
     for thread in threads:
         thread.join()
 
-    for output, mult, area, batch_chunks, cond_or_uncond in results:
+    for output, mult, area, batch_chunks, cond_or_uncond, error in results:
+        if error is not None:
+            raise error
         for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]