mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Added context window support to core sampling code (#9238)
* Added initial support for basic context windows - in progress * Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher * Made context windows compatible with different dimensions; works for WAN, but results are bad * Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py * Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code * Made dim slicing cleaner * Add Wan Context WIndows node for testing * Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found * Moved some code around between node_context_windows.py and context_windows.py * Change manual context window nodes names/ids * Added callbacks to IndexListContexHandler * Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying * Make get_resized_cond more robust for various dim sizes * Fix typo * Another small fix
This commit is contained in:
@@ -16,6 +16,7 @@ import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
|
||||
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
|
||||
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
|
||||
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
|
||||
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
|
Reference in New Issue
Block a user