From 953b906f63bda2d27870b6da394d0087ea29c146 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Tue, 2 Sep 2025 00:45:59 -0700
Subject: [PATCH] Implement Sortblock for single cond usage

---
 comfy_extras/nodes_sortblock.py | 501 +++++++++++++++++---------------
 1 file changed, 269 insertions(+), 232 deletions(-)

diff --git a/comfy_extras/nodes_sortblock.py b/comfy_extras/nodes_sortblock.py
index bbcd13da1..69b9ec477 100644
--- a/comfy_extras/nodes_sortblock.py
+++ b/comfy_extras/nodes_sortblock.py
@@ -1,15 +1,24 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING, Union
-from scipy.sparse.linalg._dsolve.linsolve import is_pydata_spmatrix
 from comfy_api.latest import io, ComfyExtension
 import comfy.patcher_extension
 import logging
 import torch
+import math
 import comfy.model_patcher
 
 if TYPE_CHECKING:
     from uuid import UUID
 
+
+def prepare_noise_wrapper(executor, *args, **kwargs):
+    try:
+        return executor(*args, **kwargs)
+    finally:
+        sb_holder: SortblockHolder = executor.class_obj.model_options["transformer_options"]["sortblock"]
+        sb_holder.step_count += 1
+
+
 def outer_sample_wrapper(executor, *args, **kwargs):
     try:
         logging.info("Sortblock: inside outer_sample!")
@@ -17,302 +26,330 @@ def outer_sample_wrapper(executor, *args, **kwargs):
         orig_model_options = guider.model_options
         guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
         # clone and prepare timesteps
-        guider.model_options["transformer_options"]["sortblock"] = guider.model_options["transformer_options"]["sortblock"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
+        sb_holder = guider.model_options["transformer_options"]["sortblock"]
+        guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
         sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"]
-        logging.info(f"Sortblock: enabled - threshold: {sb_holder.reuse_threshold}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
+        logging.info(f"Sortblock: enabled - predict_ratio: {sb_holder.predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
         return executor(*args, **kwargs)
     finally:
         sb_holder = guider.model_options["transformer_options"]["sortblock"]
-        sb_holder.print_block_info(0)
-
-        # import plotly.express as px
-        # fig = px.line(x=list(range(len(sb_holder.blocks))), y=[getattr(block, "__block_cache").cumulative_change_rate for block in sb_holder.blocks])
-
-
+        logging.info(f"Sortblock: final step count: {sb_holder.step_count}")
         sb_holder.reset()
         guider.model_options = orig_model_options
 
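+# NOTE: rough call order per sampling step (single cond assumed, per the
+# commit subject): prepare_noise_wrapper increments sb_holder.step_count after
+# each PREDICT_NOISE call; model_forward_wrapper (DIFFUSION_MODEL) then sees
+# the updated count on the next step and refreshes the recompute/predict
+# policy before the wrapped block forwards run. outer_sample_wrapper brackets
+# the whole run and resets all Sortblock state when sampling finishes.
+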
transformer_options["total_single_block"] = len(executor.class_obj.single_blocks) - # save the original forwards on the blocks - logging.info(f"Sortblock: preparing {transformer_options['total_double_block']} double blocks and {transformer_options['total_single_block']} single blocks") - for block in executor.class_obj.double_blocks: - prepare_block(block, sb_holder) - for block in executor.class_obj.single_blocks: - prepare_block(block, sb_holder) - try: - return executor(*args, **kwargs) - finally: + logging.info(f"Sortblock: inside model {executor.class_obj.__class__.__name__}") + # TODO: generalize for other models + # these won't stick around past this step; should store on sb_holder instead + logging.info(f"Sortblock: preparing {len(executor.class_obj.double_blocks)} double blocks and {len(executor.class_obj.single_blocks)} single blocks") + if hasattr(executor.class_obj, "double_blocks"): + for block in executor.class_obj.double_blocks: + prepare_block(block, sb_holder) + if hasattr(executor.class_obj, "single_blocks"): + for block in executor.class_obj.single_blocks: + prepare_block(block, sb_holder) + if hasattr(executor.class_obj, "blocks"): + for block in executor.class_obj.block: + prepare_block(block, sb_holder) + + # when 0: Initialization(1) + if sb_holder.step_modulus == 0: + logging.info(f"Sortblock: for step {sb_holder.step_count}, all blocks are marked for recomputation") + # all features are computed, input-outputs changes for all DiT blocks are stored for relative step 'k' + sb_holder.activated_steps.append(sb_holder.step_count) + for block in sb_holder.all_blocks: + cache: BlockCache = block.__block_cache + cache.mark_recompute() + + # all block operations are performed in forward pass of model + to_return = executor(*args, **kwargs) + + # when 1: Select DiT blocks(4) + if sb_holder.step_modulus == 1: + logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction") + predict_ratio = 1.0 - sb_holder.predict_ratio + for block_type, blocks in sb_holder.blocks_per_type.items(): + sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity) + threshold_index = int(len(sorted_blocks) * predict_ratio) + # blocks with lower similarity are marked for recomputation + for block in sorted_blocks[:threshold_index]: + cache: BlockCache = block.__block_cache + cache.mark_recompute() + # blocks with higher similarity are marked for prediction + for block in sorted_blocks[threshold_index:]: + cache: BlockCache = block.__block_cache + cache.mark_predict() + logging.info(f"Sortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for recomputation and {len(sorted_blocks[threshold_index:])} blocks for prediction") + + if sb_holder.initial_step: + sb_holder.initial_step = False + return to_return + +def block_forward_factory(func, block): + def block_forward_wrapper(*args, **kwargs): + transformer_options: dict[str] = kwargs.get("transformer_options") sb_holder: SortblockHolder = transformer_options["sortblock"] - # do double blocks - total_double_block = len(executor.class_obj.double_blocks) - total_single_block = len(executor.class_obj.single_blocks) - perform_sortblock(sb_holder.blocks[:total_double_block]) - perform_sortblock(sb_holder.blocks[total_double_block:]) + cache: BlockCache = block.__block_cache + # make sure stream count is properly set for this block if sb_holder.initial_step: - sb_holder.initial_step = False + sb_holder.add_to_blocks_per_type(block, 
+
+def block_forward_factory(func, block):
+    def block_forward_wrapper(*args, **kwargs):
+        transformer_options: dict[str] = kwargs.get("transformer_options")
+        sb_holder: SortblockHolder = transformer_options["sortblock"]
+        cache: BlockCache = block.__block_cache
+        # make sure stream count is properly set for this block
+        if sb_holder.initial_step:
+            sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
+            cache.block_index = transformer_options['block'][1]
+            cache.stream_count = transformer_options['block'][2]
+        # do sortblock stuff
+        if cache.recompute and sb_holder.step_modulus != 1:
+            # clone relevant inputs
+            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=True)
+            # get block outputs
+            # NOTE: output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
+            output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
+            # perform derivative approximation
+            cache.derivative_approximation(sb_holder, output_raw, orig_inputs)
+            # if step_modulus is 0, input-output changes for DiT block are stored
+            if sb_holder.step_modulus == 0:
+                cache.cache_previous_residual(output_raw, orig_inputs)
+        else:
+            # if not to recompute, predict features for current timestep
+            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=False)
+            # when 1: Linear Prediction(2)
+            # if step_modulus is 1, store block residuals as 'current' after applying taylor_formula
+            if sb_holder.step_modulus == 1:
+                cache.cache_current_residual(sb_holder)
+            # based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
+            # input-output changes for all DiT blocks are stored for relative step 'k+1'
+            output_raw = cache.apply_linear_prediction(sb_holder, orig_inputs)
+
+        # when 1: Identify Changes(3)
+        if sb_holder.step_modulus == 1:
+            # compare the stored previous residual with the predicted current residual to measure how much each block changed
+            cache.calculate_cosine_similarity()
+
+        return output_raw
+    return block_forward_wrapper
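+
+# Summary of block_forward_wrapper behavior per step_modulus (derived from the
+# branches above):
+#   0: recompute every block, refresh Taylor factors, store each block's
+#      residual as 'previous'
+#   1: predict every block, store the predicted residual as 'current', then
+#      score blocks by cosine similarity between the two residuals
+#   2+: recompute blocks marked for recomputation, predict the rest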
+
 def perform_sortblock(blocks: list):
-    candidate_blocks = []
-    for block in blocks:
-        cache: BlockCache = getattr(block, "__block_cache")
-        cache.allowed_to_skip = False
-        if cache.want_to_skip:
-            candidate_blocks.append(block)
-    if len(candidate_blocks) > 0:
-        percentage_to_skip = 1.0
-        candidate_blocks.sort(key=lambda x: getattr(x, "__block_cache").cumulative_change_rate)
-        blocks_to_skip = int(len(candidate_blocks) * percentage_to_skip)
-        for block in candidate_blocks[:blocks_to_skip]:
-            cache: BlockCache = getattr(block, "__block_cache")
-            cache.allowed_to_skip = True
-
-
+    ...
 
 def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
-    sb_holder.add_block(block)
+    sb_holder.add_to_all_blocks(block)
     block.__original_forward = block.forward
     block.forward = block_forward_factory(block.__original_forward, block)
     block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
-
 
 def clean_block(block):
     block.forward = block.__original_forward
     del block.__original_forward
     del block.__block_cache
 
+def subsample(x: torch.Tensor, factor: int, clone: bool=True) -> torch.Tensor:
+    if factor > 1:
+        to_return = x[..., ::factor, ::factor]
+        if clone:
+            return to_return.clone()
+        return to_return
+    if clone:
+        return x.clone()
+    return x
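+
+# Example with hypothetical shapes: subsample(torch.zeros(1, 4096, 64), 8)
+# keeps every 8th element along the last two dims, returning a cloned
+# (1, 512, 8) tensor, so residual comparisons touch ~1/64th of the values.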
 
-def block_forward_factory(func, block):
-    def block_forward_wrapper(*args, **kwargs):
-        transformer_options: dict[str] = kwargs.get("transformer_options", None)
-        #logging.info(f"Sortblock: inside block {transformer_options['block']}")
-        sb_holder: SortblockHolder = transformer_options["sortblock"]
-        cache: BlockCache = block.__block_cache
-        if sb_holder.initial_step:
-            cache.stream_count = transformer_options['block'][2]
-        if sb_holder.is_past_end_timestep():
-            return func(*args, **kwargs)
-        # do sortblock stuff
-        keys = list(kwargs.keys())
-        x = cache.get_next_x_prev(kwargs)
-        timestep: float = sb_holder.curr_t
-        # prepare next_x_prev
-        next_x_prev = cache.get_next_x_prev(kwargs, clone=True)
-        input_change = None
-        do_sortblock = sb_holder.should_do_sortblock()
-        if do_sortblock:
-            # TODO: checkmetadata
-            if cache.has_x_prev_subsampled():
-                input_change = (cache.subsample(x, clone=False) - cache.x_prev_subsampled).flatten().abs().mean()
-                if cache.has_output_prev_norm() and cache.has_relative_transformation_rate():
-                    approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
-                    cache.cumulative_change_rate += approx_output_change_rate
-                    if cache.cumulative_change_rate < sb_holder.reuse_threshold:
-                        # accumulate error + skip block
-                        # cache.want_to_skip = True
-                        # if cache.allowed_to_skip:
-                        #     return cache.apply_cache_diff(x)
-                        pass
-                    else:
-                        # reset error; NOT skipping block and recalculating
-                        cache.cumulative_change_rate = 0.0
-                        cache.want_to_skip = False
-        # output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
-        output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
-        # if more than one stream from block, only use first one
+class BlockCache:
+    def __init__(self, subsample_factor: int=8, verbose: bool=False):
+        self.subsample_factor = subsample_factor
+        self.verbose = verbose
+        self.stream_count = 1
+        self.recompute = False
+        self.block_index = 0
+        # cached values
+        self.previous_residual_subsampled: torch.Tensor = None
+        self.current_residual_subsampled: torch.Tensor = None
+        self.cosine_similarity: float = None
+        self.previous_taylor_factors: dict[int, torch.Tensor] = {}
+        self.current_taylor_factors: dict[int, torch.Tensor] = {}
+
+    def mark_recompute(self):
+        self.recompute = True
+
+    def mark_predict(self):
+        self.recompute = False
+
+    def cache_previous_residual(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
         if isinstance(output_raw, tuple):
-            output = output_raw[0]
-        else:
-            output = output_raw
-        if cache.has_output_prev_norm():
-            output_change = (cache.subsample(output, clone=False) - cache.output_prev_subsampled).flatten().abs().mean()
-            # if verbose in future
-            output_change_rate = output_change / cache.output_prev_norm
-            cache.output_change_rates.append(output_change_rate.item())
-            if cache.has_relative_transformation_rate():
-                approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
-                cache.approx_output_change_rates.append(approx_output_change_rate.item())
-            if input_change is not None:
-                cache.relative_transformation_rate = output_change / input_change
-        # TODO: allow cache_diff to be offloaded
-        cache.update_cache_diff(output_raw, next_x_prev)
-        cache.x_prev_subsampled = cache.subsample(next_x_prev)
-        cache.output_prev_subsampled = cache.subsample(output)
-        cache.output_prev_norm = output.flatten().abs().mean()
-        return output_raw
-    return block_forward_wrapper
+            output_raw = output_raw[0]
+        if isinstance(orig_inputs, tuple):
+            orig_inputs = orig_inputs[0]
+        del self.previous_residual_subsampled
+        self.previous_residual_subsampled = subsample(output_raw - orig_inputs, self.subsample_factor, clone=True)
 
+    def cache_current_residual(self, sb_holder: SortblockHolder):
+        del self.current_residual_subsampled
+        self.current_residual_subsampled = subsample(self.use_taylor_formula(sb_holder)[0], self.subsample_factor, clone=True)
+
+    def get_orig_inputs(self, d_args: tuple, d_kwargs: dict, clone: bool=True) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        if self.stream_count == 1:
+            if clone:
+                return d_args[0].clone()
+            return d_args[0]
+        keys = list(d_kwargs.keys())[:self.stream_count]
+        orig_inputs = []
+        for key in keys:
+            if clone:
+                orig_inputs.append(d_kwargs[key].clone())
+            else:
+                orig_inputs.append(d_kwargs[key])
+        return tuple(orig_inputs)
+
+    def apply_linear_prediction(self, sb_holder: SortblockHolder, orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        drop_tuple = False
+        if not isinstance(orig_inputs, tuple):
+            orig_inputs = (orig_inputs,)
+            drop_tuple = True
+        taylor_results = self.use_taylor_formula(sb_holder)
+        for output, taylor_result in zip(orig_inputs, taylor_results):
+            if output.shape != taylor_result.shape:
+                logging.warning(f"Sortblock: block {self.block_index} input shape {output.shape} does not match predicted residual shape {taylor_result.shape}")
+            output += taylor_result
+        if drop_tuple:
+            orig_inputs = orig_inputs[0]
+        return orig_inputs
+
+    def calculate_cosine_similarity(self) -> None:
+        self.cosine_similarity = torch.nn.functional.cosine_similarity(self.previous_residual_subsampled, self.current_residual_subsampled, dim=-1).mean().item()
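+
+    # NOTE: similarity is computed over the last dim of the subsampled
+    # residuals and averaged; values near 1.0 mean the block's input->output
+    # residual barely changed, so linear prediction is likely safe for it.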
+
+    def derivative_approximation(self, sb_holder: SortblockHolder, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
+        activation_distance = sb_holder.activated_steps[-1] - sb_holder.activated_steps[-2]
+        # make tuple if not already tuple, so that this works with both single and double blocks
+        if not isinstance(output_raw, tuple):
+            output_raw = (output_raw,)
+        if not isinstance(orig_inputs, tuple):
+            orig_inputs = (orig_inputs,)
+
+        for i, (output, x) in enumerate(zip(output_raw, orig_inputs)):
+            feature = output.clone() - x
+            has_previous_taylor_factor = self.previous_taylor_factors.get(i, None) is not None
+            # NOTE: unclear why the offset of 2; it matches the original implementation and may be worth revisiting
+            if has_previous_taylor_factor and sb_holder.step_count > (sb_holder.first_enhance - 2):
+                self.current_taylor_factors[i] = (
+                    feature - self.previous_taylor_factors[i]
+                ) / activation_distance
+
+            self.previous_taylor_factors[i] = feature
+
+    def use_taylor_formula(self, sb_holder: SortblockHolder) -> tuple[torch.Tensor, ...]:
+        step_distance = sb_holder.step_count - sb_holder.activated_steps[-1]
+
+        output_predicted = []
+
+        for key in self.previous_taylor_factors.keys():
+            previous_tf = self.previous_taylor_factors[key]
+            # the first-order factor may not exist yet on very early steps
+            current_tf = self.current_taylor_factors.get(key, None)
+            predicted = taylor_formula(previous_tf, 0, step_distance)
+            if current_tf is not None:
+                predicted += taylor_formula(current_tf, 1, step_distance)
+            output_predicted.append(predicted)
+
+        return tuple(output_predicted)
+
+    def reset(self):
+        self.recompute = False
+        self.current_residual_subsampled = None
+        self.previous_residual_subsampled = None
+        self.cosine_similarity = None
+        self.previous_taylor_factors = {}
+        self.current_taylor_factors = {}
+
+
+def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
+    return (
+        (1 / math.factorial(i))
+        * taylor_factor
+        * (step_distance ** i)
+    )
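+
+# The 'Eq. 4' referenced in the comments reduces to a first-order Taylor
+# expansion of a block's residual: with d = step_count - activated_steps[-1],
+#   taylor_formula(tf, 0, d) = tf        (1/0! * tf * d^0)
+#   taylor_formula(tf, 1, d) = tf * d    (1/1! * tf * d^1)
+# so use_taylor_formula returns previous_tf + current_tf * d, i.e. the last
+# stored residual extrapolated along its finite-difference slope.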
 
 class SortblockHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
-        self.reuse_threshold = reuse_threshold
+    def __init__(self, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
+        self.predict_ratio = predict_ratio
         self.start_percent = start_percent
         self.end_percent = end_percent
         self.subsample_factor = subsample_factor
         self.verbose = verbose
+
+        # NOTE: number represents steps
+        self.policy_refresh_interval = policy_refresh_interval
+        self.active_policy_refresh_interval = 1
+        self.first_enhance = 3 # NOTE: this value is 2 higher than the one actually used in code (subtracted by 2 in derivative_approximation)
         # timestep values
         self.start_t = 0.0
         self.end_t = 0.0
         self.curr_t = 0.0
         # control values
-        self.past_timestep = False
         self.do_sortblock = False
         self.initial_step = True
+        self.step_count = 0
+        self.activated_steps: list[int] = [0]
+        self.step_modulus = 0
         # cache values
-        self.blocks = []
+        self.all_blocks = []
+        self.blocks_per_type = {}
 
-    def add_block(self, block):
-        self.blocks.append(block)
+    def add_to_all_blocks(self, block):
+        self.all_blocks.append(block)
+
+    def add_to_blocks_per_type(self, block, block_type: str):
+        self.blocks_per_type.setdefault(block_type, []).append(block)
 
     def prepare_timesteps(self, model_sampling):
         self.start_t = model_sampling.percent_to_sigma(self.start_percent)
         self.end_t = model_sampling.percent_to_sigma(self.end_percent)
         return self
 
-    def update_is_past_end_timestep(self, timestep: float) -> bool:
-        self.past_timestep = not (timestep[0] > self.end_t).item()
-        return self.past_timestep
-
-    def is_past_end_timestep(self) -> bool:
-        return self.past_timestep
-
     def update_should_do_sortblock(self, timestep: float) -> bool:
-        self.do_sortblock = (timestep[0] <= self.start_t).item()
+        self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item()
         self.curr_t = timestep
+        if self.do_sortblock:
+            self.active_policy_refresh_interval = self.policy_refresh_interval
+        else:
+            self.active_policy_refresh_interval = 1
+        self.update_step_modulus()
         return self.do_sortblock
 
+    def update_step_modulus(self):
+        self.step_modulus = int(self.step_count % self.active_policy_refresh_interval)
+
     def should_do_sortblock(self) -> bool:
         return self.do_sortblock
 
-    def print_block_info(self, index: int):
-        block = self.blocks[index]
-        cache = getattr(block, "__block_cache")
-        logging.info(f"Sortblock: block {index} output_change_rates: {cache.output_change_rates}")
-        logging.info(f"Sortblock: block {index} approx_output_change_rates: {cache.approx_output_change_rates}")
-
     def reset(self):
         self.initial_step = True
         self.curr_t = 0.0
-        logging.info(f"Sortblock: resetting {len(self.blocks)} blocks")
-        for block in self.blocks:
+        logging.info(f"Sortblock: resetting {len(self.all_blocks)} blocks")
+        for block in self.all_blocks:
             clean_block(block)
-        self.blocks = []
+        self.all_blocks = []
+        self.blocks_per_type = {}
+        self.step_count = 0
+        self.activated_steps = [0]
+        self.step_modulus = 0
         return self
 
     def clone(self):
-        return SortblockHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.verbose)
+        return SortblockHolder(predict_ratio=self.predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
+                               start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
+                               verbose=self.verbose)
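+
+# Example, assuming policy_refresh_interval=5 and sigmas inside the active
+# window: step_count 10..14 gives step_modulus 0,1,2,3,4 -> recompute-all at
+# step 10, predict-and-rank at step 11, then the selected mix for steps 12-14.
+# Outside the window active_policy_refresh_interval falls back to 1, so
+# step_modulus stays 0 and every block is recomputed on every step.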
 
-class BlockCache:
-    def __init__(self, subsample_factor: int=8, stream_count: int=1, verbose: bool=False):
-        self.subsample_factor = subsample_factor
-        self.stream_count = stream_count
-        self.verbose = verbose
-        # control values
-        self.relative_transformation_rate: float = None
-        self.cumulative_change_rate = 0.0
-        self.initial_step = True
-        # cache values
-        self.x_prev_subsampled: torch.Tensor = None
-        self.output_prev_subsampled: torch.Tensor = None
-        self.output_prev_norm: torch.Tensor = None
-        self.cache_diff: list[torch.Tensor] = [None for _ in range(stream_count)]
-        self.output_change_rates = []
-        self.approx_output_change_rates = []
-        self.total_steps_skipped = 0
-        self.state_metadata = None
-        self.want_to_skip = False
-        self.allowed_to_skip = False
-
-    def has_cache_diff(self) -> bool:
-        return self.cache_diff[0] is not None
-
-    def has_x_prev_subsampled(self) -> bool:
-        return self.x_prev_subsampled is not None
-
-    def has_output_prev_subsampled(self) -> bool:
-        return self.output_prev_subsampled is not None
-
-    def has_output_prev_norm(self) -> bool:
-        return self.output_prev_norm is not None
-
-    def has_relative_transformation_rate(self) -> bool:
-        return self.relative_transformation_rate is not None
-
-    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
-        keys = list(d_kwargs.keys())
-        if self.stream_count == 1:
-            if clone:
-                return d_kwargs[keys[0]].clone()
-            return d_kwargs[keys[0]]
-        return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])
-
-    def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
-        # subsample only the first compoenent
-        if isinstance(x, tuple):
-            return self.subsample(x[0], clone)
-        if self.subsample_factor > 1:
-            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
-            if clone:
-                return to_return.clone()
-            return to_return
-        if clone:
-            return x.clone()
-        return x
-
-    def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
-        self.total_steps_skipped += 1
-        if not isinstance(x, tuple):
-            x = (x, )
-        to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
-        if len(to_return) == 1:
-            return to_return[0]
-        return to_return
-
-    def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
-        if not isinstance(output_raw, tuple):
-            output_raw = (output_raw, )
-        if not isinstance(x, tuple):
-            x = (x, )
-        self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
-
-    def check_metadata(self, x: torch.Tensor) -> bool:
-        # TODO: make sure shapes are correct
-        metadata = (x.device, x.dtype, x.shape)
-        if self.state_metadata is None:
-            self.state_metadata = metadata
-            return True
-        if metadata == self.state_metadata:
-            return True
-        logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
-        self.reset()
-        return False
-
-    def reset(self):
-        self.relative_transformation_rate = 0.0
-        self.cumulative_change_rate = 0.0
-        self.initial_step = True
-        self.cache_diff = [None for _ in range(self.stream_count)]
-        self.output_change_rates = []
-        self.approx_output_change_rates = []
-        self.total_steps_skipped = 0
-        self.state_metadata = None
-        self.want_to_skip = False
-        self.allowed_to_skip = False
-
 class SortblockNode(io.ComfyNode):
     @classmethod
     def define_schema(cls) -> io.Schema:
@@ -324,7 +361,8 @@ class SortblockNode(io.ComfyNode):
             is_experimental=True,
             inputs=[
                 io.Model.Input("model", tooltip="The model to add Sortblock to."),
-                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached blocks."),
+                io.Float.Input("predict_ratio", min=0.0, default=0.8, max=1.0, step=0.01, tooltip="The fraction of blocks whose outputs are predicted instead of recomputed."),
+                io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval, in steps, at which to refresh the recompute/predict policy."),
                 io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
                 io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
                 io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
@@ -335,14 +373,13 @@ class SortblockNode(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
+    def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        # TODO: check for specific flavors of supported models
         model = model.clone()
-        model.model_options["transformer_options"]["sortblock"] = SortblockHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, verbose=verbose)
+        model.model_options["transformer_options"]["sortblock"] = SortblockHolder(predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
+        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
-        # model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
-        # model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
         return io.NodeOutput(model)
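+
+# Usage sketch (hypothetical graph wiring): place this node between the model
+# loader and the sampler, e.g. Load Checkpoint -> Sortblock -> KSampler. With
+# the defaults (predict_ratio=0.8, policy_refresh_interval=5), each interval
+# recomputes everything once, predicts everything once to rank the blocks,
+# then recomputes only the ~20% most-changed blocks for the remaining steps.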