Added context window support to core sampling code (#9238)

* Added initial support for basic context windows - in progress

* Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher

* Made context windows compatible with different dimensions; works for WAN, but results are bad

* Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py

* Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code

* Made dim slicing cleaner

* Add Wan Context WIndows node for testing

* Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found

* Moved some code around between node_context_windows.py and context_windows.py

* Change manual context window nodes names/ids

* Added callbacks to IndexListContexHandler

* Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying

* Make get_resized_cond more robust for various dim sizes

* Fix typo

* Another small fix
This commit is contained in:
Jedrzej Kosinski
2025-08-13 18:33:05 -07:00
committed by GitHub
parent c991a5da65
commit e4f7ea105f
5 changed files with 639 additions and 5 deletions

View File

@@ -16,6 +16,7 @@ import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import scipy.stats
import numpy
@@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks