From a40c5ae341516eda98264e5101a74684347587fd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Sep 2025 15:23:28 -0700 Subject: [PATCH] Support predict_ratio changing with timesteps --- comfy_extras/nodes_sortblock.py | 83 +++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_sortblock.py b/comfy_extras/nodes_sortblock.py index 69b9ec477..b6135ae18 100644 --- a/comfy_extras/nodes_sortblock.py +++ b/comfy_extras/nodes_sortblock.py @@ -13,10 +13,26 @@ if TYPE_CHECKING: def prepare_noise_wrapper(executor, *args, **kwargs): try: + transformer_options: dict[str] = args[2]["transformer_options"] + sb_holder: SortblockHolder = transformer_options["sortblock"] + if sb_holder.initial_step: + sample_sigmas = transformer_options["sample_sigmas"] + relevant_sigmas = [] + # find start and end steps, then use to interpolate between start and end predict ratios + for i,sigma in enumerate(sample_sigmas): + if sb_holder.check_if_within_timesteps(sigma): + relevant_sigmas.append((i, sigma)) + start_index = relevant_sigmas[0][0] + end_index = relevant_sigmas[-1][0] + sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1) + return executor(*args, **kwargs) finally: - sb_holder: SortblockHolder = executor.class_obj.model_options["transformer_options"]["sortblock"] + transformer_options: dict[str] = args[2]["transformer_options"] + sb_holder: SortblockHolder = transformer_options["sortblock"] sb_holder.step_count += 1 + if sb_holder.should_do_sortblock(): + sb_holder.active_steps += 1 def outer_sample_wrapper(executor, *args, **kwargs): @@ -29,7 +45,7 @@ def outer_sample_wrapper(executor, *args, **kwargs): sb_holder = guider.model_options["transformer_options"]["sortblock"] guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling) sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"] - logging.info(f"Sortblock: enabled - threshold: {sb_holder.predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}") + logging.info(f"Sortblock: enabled - threshold: {sb_holder.start_predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}") return executor(*args, **kwargs) finally: sb_holder = guider.model_options["transformer_options"]["sortblock"] @@ -79,11 +95,12 @@ def model_forward_wrapper(executor, *args, **kwargs): # when 1: Select DiT blocks(4) if sb_holder.step_modulus == 1: - logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction") - predict_ratio = 1.0 - sb_holder.predict_ratio + predict_ratio = sb_holder.predict_ratios[sb_holder.active_steps-1] + logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}") + reuse_ratio = 1.0 - predict_ratio for block_type, blocks in sb_holder.blocks_per_type.items(): sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity) - threshold_index = int(len(sorted_blocks) * predict_ratio) + threshold_index = int(len(sorted_blocks) * reuse_ratio) # blocks with lower similarity are marked for recomputation for block in sorted_blocks[:threshold_index]: cache: BlockCache = block.__block_cache @@ -280,8 +297,10 @@ def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int): ) class SortblockHolder: - def __init__(self, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False): - self.predict_ratio = predict_ratio + def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, + start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False): + self.start_predict_ratio = start_predict_ratio + self.end_predict_ratio = end_predict_ratio self.start_percent = start_percent self.end_percent = end_percent self.subsample_factor = subsample_factor @@ -300,6 +319,10 @@ class SortblockHolder: self.step_count = 0 self.activated_steps: list[int] = [0] self.step_modulus = 0 + self.do_sortblock = False + self.active_steps = 0 + self.predict_ratios = [] + # cache values self.all_blocks = [] self.blocks_per_type = {} @@ -315,6 +338,9 @@ class SortblockHolder: self.end_t = model_sampling.percent_to_sigma(self.end_percent) return self + def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool: + return (timestep <= self.start_t).item() and (timestep > self.end_t).item() + def update_should_do_sortblock(self, timestep: float) -> bool: self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item() self.curr_t = timestep @@ -342,10 +368,13 @@ class SortblockHolder: self.step_count = 0 self.activated_steps = [0] self.step_modulus = 0 + self.active_steps = 0 + self.predict_ratios = [] + self.do_sortblock = False return self def clone(self): - return SortblockHolder(predict_ratio=self.predict_ratio, policy_refresh_interval=self.policy_refresh_interval, + return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval, start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor, verbose=self.verbose) @@ -376,7 +405,42 @@ class SortblockNode(io.ComfyNode): def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: # TODO: check for specific flavors of supported models model = model.clone() - model.model_options["transformer_options"]["sortblock"] = SortblockHolder(predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose) + model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio=predict_ratio, end_predict_ratio=predict_ratio, policy_refresh_interval=policy_refresh_interval, + start_percent=start_percent, end_percent=end_percent, subsample_factor=8, verbose=verbose) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper) + return io.NodeOutput(model) + + +class SortblockScaledNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SortblockScaled", + display_name="SortblockScaled", + description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add Sortblock to."), + io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."), + io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."), + io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with Sortblock."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + # TODO: check for specific flavors of supported models + model = model.clone() + model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper) @@ -387,6 +451,7 @@ class SortblockExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ SortblockNode, + SortblockScaledNode, ] def comfy_entrypoint():