From d4a8752c8c3d96e4b97cea1c794fb63dd6771e89 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Mon, 1 Sep 2025 09:39:40 -0700
Subject: [PATCH] some exploration of sortblock as more things from
 paper/source code need to be added

---
 comfy_extras/nodes_sortblock.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/comfy_extras/nodes_sortblock.py b/comfy_extras/nodes_sortblock.py
index 58a8f1cb7..bbcd13da1 100644
--- a/comfy_extras/nodes_sortblock.py
+++ b/comfy_extras/nodes_sortblock.py
@@ -24,6 +24,11 @@ def outer_sample_wrapper(executor, *args, **kwargs):
     finally:
         sb_holder = guider.model_options["transformer_options"]["sortblock"]
         sb_holder.print_block_info(0)
+
+        # import plotly.express as px
+        # fig = px.line(x=list(range(len(sb_holder.blocks))), y=[getattr(block, "__block_cache").cumulative_change_rate for block in sb_holder.blocks])
+
+
         sb_holder.reset()
         guider.model_options = orig_model_options

@@ -56,7 +61,7 @@ def model_forward_wrapper(executor, *args, **kwargs):
     total_double_block = len(executor.class_obj.double_blocks)
     total_single_block = len(executor.class_obj.single_blocks)
     perform_sortblock(sb_holder.blocks[:total_double_block])
-    #perform_sortblock(sb_holder.blocks[total_double_block:])
+    perform_sortblock(sb_holder.blocks[total_double_block:])
     if sb_holder.initial_step:
         sb_holder.initial_step = False

@@ -105,7 +110,7 @@ def block_forward_factory(func, block):
         x = cache.get_next_x_prev(kwargs)
         timestep: float = sb_holder.curr_t
         # prepare next_x_prev
-        next_x_prev = cache.get_next_x_prev(kwargs)
+        next_x_prev = cache.get_next_x_prev(kwargs, clone=True)
         input_change = None
         do_sortblock = sb_holder.should_do_sortblock()
         if do_sortblock:
@@ -117,9 +122,10 @@ def block_forward_factory(func, block):
             cache.cumulative_change_rate += approx_output_change_rate
             if cache.cumulative_change_rate < sb_holder.reuse_threshold:
                 # accumulate error + skip block
-                cache.want_to_skip = True
-                if cache.allowed_to_skip:
-                    return cache.apply_cache_diff(x)
+                # cache.want_to_skip = True
+                # if cache.allowed_to_skip:
+                #     return cache.apply_cache_diff(x)
+                pass
             else:
                 # reset error; NOT skipping block and recalculating
                 cache.cumulative_change_rate = 0.0
@@ -246,11 +252,13 @@ class BlockCache:
     def has_relative_transformation_rate(self) -> bool:
         return self.relative_transformation_rate is not None

-    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         keys = list(d_kwargs.keys())
         if self.stream_count == 1:
+            if clone:
+                return d_kwargs[keys[0]].clone()
             return d_kwargs[keys[0]]
-        return tuple([d_kwargs[keys[i]] for i in range(self.stream_count)])
+        return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])

     def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
         # subsample only the first component
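
Note (not part of the patch): below is a minimal sketch of the accumulate-and-skip logic that the hunk at @@ -117,9 +122,10 @@ comments out, to document how `cumulative_change_rate`, `reuse_threshold`, and `apply_cache_diff` work together. `BlockCacheSketch` and `forward_with_reuse` are hypothetical names, the residual-caching body is an assumption rather than the actual `BlockCache` implementation, and the `want_to_skip`/`allowed_to_skip` gating is dropped for brevity:

import torch

class BlockCacheSketch:
    # Hypothetical stand-in for the patch's BlockCache; only the fields used below.
    def __init__(self):
        self.cumulative_change_rate = 0.0
        self.cache_diff = None  # residual (output - input) saved by the last real run

    def apply_cache_diff(self, x: torch.Tensor) -> torch.Tensor:
        # Replay the stored residual instead of running the block.
        return x + self.cache_diff

def forward_with_reuse(block, cache, x, approx_output_change_rate, reuse_threshold):
    # Accumulate the estimated relative change since the last real evaluation.
    cache.cumulative_change_rate += approx_output_change_rate
    if cache.cumulative_change_rate < reuse_threshold and cache.cache_diff is not None:
        # Accumulated change is still below the threshold: skip the block.
        return cache.apply_cache_diff(x)
    # Too much accumulated change: reset the accumulator and recompute for real.
    cache.cumulative_change_rate = 0.0
    out = block(x)
    cache.cache_diff = out - x
    return out

# Usage: the first call always runs the block (no cached diff yet); the second
# reuses the residual because the accumulated change stays under the threshold.
cache = BlockCacheSketch()
block = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)
y_full = forward_with_reuse(block, cache, x, 0.0, reuse_threshold=0.1)
y_skip = forward_with_reuse(block, cache, x, 0.05, reuse_threshold=0.1)

The `clone=True` the patch adds when preparing `next_x_prev` plays the same role as caching a detached residual here: presumably the stored reference input must not alias tensors that are later modified in place, or the next change-rate estimate would compare a block's input against an already-updated copy of itself.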