diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 13520c4ef..6346c0c3a 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -128,7 +128,7 @@ class Flux(nn.Module):
 
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block"] = ("double_block", i)
+            transformer_options["block"] = ("double_block", i, 2)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -170,7 +170,7 @@ class Flux(nn.Module):
         img = torch.cat((txt, img), 1)
 
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block"] = ("single_block", i)
+            transformer_options["block"] = ("single_block", i, 1)
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
diff --git a/comfy_extras/nodes_sortblock.py b/comfy_extras/nodes_sortblock.py
index 628b709c6..58a8f1cb7 100644
--- a/comfy_extras/nodes_sortblock.py
+++ b/comfy_extras/nodes_sortblock.py
@@ -10,7 +10,6 @@ import comfy.model_patcher
 if TYPE_CHECKING:
     from uuid import UUID
 
-
 def outer_sample_wrapper(executor, *args, **kwargs):
     try:
         logging.info("Sortblock: inside outer_sample!")
@@ -24,6 +23,7 @@ def outer_sample_wrapper(executor, *args, **kwargs):
         return executor(*args, **kwargs)
     finally:
         sb_holder = guider.model_options["transformer_options"]["sortblock"]
+        sb_holder.print_block_info(0)
         sb_holder.reset()
         guider.model_options = orig_model_options
 
@@ -43,14 +43,38 @@ def model_forward_wrapper(executor, *args, **kwargs):
     transformer_options["total_double_block"] = len(executor.class_obj.double_blocks)
     transformer_options["total_single_block"] = len(executor.class_obj.single_blocks)
     # save the original forwards on the blocks
-    logging.info(f"Sortblock: saving preparing {transformer_options['total_double_block']} double blocks and {transformer_options['total_single_block']} single blocks")
+    logging.info(f"Sortblock: preparing {transformer_options['total_double_block']} double blocks and {transformer_options['total_single_block']} single blocks")
     for block in executor.class_obj.double_blocks:
         prepare_block(block, sb_holder)
     for block in executor.class_obj.single_blocks:
         prepare_block(block, sb_holder)
-    sb_holder.initial_step = False
+    try:
+        return executor(*args, **kwargs)
+    finally:
+        sb_holder: SortblockHolder = transformer_options["sortblock"]
+        # do double blocks
+        total_double_block = len(executor.class_obj.double_blocks)
+        total_single_block = len(executor.class_obj.single_blocks)
+        perform_sortblock(sb_holder.blocks[:total_double_block])
+        #perform_sortblock(sb_holder.blocks[total_double_block:])
+        if sb_holder.initial_step:
+            sb_holder.initial_step = False
+
+def perform_sortblock(blocks: list):
+    candidate_blocks = []
+    for block in blocks:
+        cache: BlockCache = getattr(block, "__block_cache")
+        cache.allowed_to_skip = False
+        if cache.want_to_skip:
+            candidate_blocks.append(block)
+    if len(candidate_blocks) > 0:
+        percentage_to_skip = 1.0
+        candidate_blocks.sort(key=lambda x: getattr(x, "__block_cache").cumulative_change_rate)
+        blocks_to_skip = int(len(candidate_blocks) * percentage_to_skip)
+        for block in candidate_blocks[:blocks_to_skip]:
+            cache: BlockCache = getattr(block, "__block_cache")
+            cache.allowed_to_skip = True
 
-    return executor(*args, **kwargs)
 
 
 def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
@@ -69,14 +93,16 @@ def clean_block(block):
 def block_forward_factory(func, block):
     def block_forward_wrapper(*args, **kwargs):
         transformer_options: dict[str] = kwargs.get("transformer_options", None)
-        logging.info(f"Sortblock: inside block {transformer_options['block']}")
+        #logging.info(f"Sortblock: inside block {transformer_options['block']}")
         sb_holder: SortblockHolder = transformer_options["sortblock"]
+        cache: BlockCache = block.__block_cache
+        if sb_holder.initial_step:
+            cache.stream_count = transformer_options['block'][2]
         if sb_holder.is_past_end_timestep():
            return func(*args, **kwargs)
         # do sortblock stuff
-        cache: BlockCache = block.__block_cache
         keys = list(kwargs.keys())
-        x: torch.Tensor = kwargs[keys[0]]
+        x = cache.get_next_x_prev(kwargs)
         timestep: float = sb_holder.curr_t
         # prepare next_x_prev
         next_x_prev = cache.get_next_x_prev(kwargs)
@@ -91,10 +117,13 @@ def block_forward_factory(func, block):
             cache.cumulative_change_rate += approx_output_change_rate
             if cache.cumulative_change_rate < sb_holder.reuse_threshold:
                 # accumulate error + skip block
-                pass
+                cache.want_to_skip = True
+                if cache.allowed_to_skip:
+                    return cache.apply_cache_diff(x)
             else:
                 # reset error; NOT skipping block and recalculating
                 cache.cumulative_change_rate = 0.0
+                cache.want_to_skip = False
         # output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
         output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
         # if more than one stream from block, only use first one
@@ -162,6 +191,12 @@ class SortblockHolder:
     def should_do_sortblock(self) -> bool:
         return self.do_sortblock
 
+    def print_block_info(self, index: int):
+        block = self.blocks[index]
+        cache = getattr(block, "__block_cache")
+        logging.info(f"Sortblock: block {index} output_change_rates: {cache.output_change_rates}")
+        logging.info(f"Sortblock: block {index} approx_output_change_rates: {cache.approx_output_change_rates}")
+
     def reset(self):
         self.initial_step = True
         self.curr_t = 0.0
@@ -193,6 +228,8 @@ class BlockCache:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.want_to_skip = False
+        self.allowed_to_skip = False
 
     def has_cache_diff(self) -> bool:
         return self.cache_diff[0] is not None
@@ -228,12 +265,21 @@ class BlockCache:
             return x.clone()
         return x
 
-    def apply_cache_diff(self, x: torch.Tensor):
+    def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
         self.total_steps_skipped += 1
-        return x + self.cache_diff.to(x.device)
+        if not isinstance(x, tuple):
+            x = (x, )
+        to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
+        if len(to_return) == 1:
+            return to_return[0]
+        return to_return
 
-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
-        self.cache_diff = output - x
+    def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
+        if not isinstance(output_raw, tuple):
+            output_raw = (output_raw, )
+        if not isinstance(x, tuple):
+            x = (x, )
+        self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
 
     def check_metadata(self, x: torch.Tensor) -> bool:
         # TODO: make sure shapes are correct
@@ -256,6 +302,8 @@ class BlockCache:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.want_to_skip = False
+        self.allowed_to_skip = False
 
 class SortblockNode(io.ComfyNode):
     @classmethod
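
For reference, the skip-selection idea this patch introduces can be illustrated outside of ComfyUI. The following is a minimal standalone sketch, assuming scalar activations and a placeholder change-rate metric; ToyCache, select_skippable, and step_block are hypothetical names for illustration, not code from this PR. Each block accumulates an approximate output change rate per step; blocks whose cumulative rate stays under the reuse threshold become skip candidates, and the candidates with the smallest accumulated change are allowed to reuse their cached residual (output - x) instead of recomputing.

# Hypothetical standalone sketch of the Sortblock skip-selection logic (not this PR's code).
from dataclasses import dataclass


@dataclass
class ToyCache:
    cumulative_change_rate: float = 0.0
    cache_diff: float = 0.0        # stand-in for the cached (output - x) residual
    want_to_skip: bool = False
    allowed_to_skip: bool = False


def select_skippable(caches, percentage_to_skip=1.0):
    # Mirrors perform_sortblock: clear permissions, collect candidates that want
    # to skip, sort by accumulated change, and allow the calmest fraction to skip.
    candidates = []
    for cache in caches:
        cache.allowed_to_skip = False
        if cache.want_to_skip:
            candidates.append(cache)
    candidates.sort(key=lambda c: c.cumulative_change_rate)
    for cache in candidates[:int(len(candidates) * percentage_to_skip)]:
        cache.allowed_to_skip = True


def step_block(cache, x, compute, reuse_threshold):
    # Mirrors block_forward_wrapper: accumulate an approximate change rate and
    # either reuse the cached residual (skip) or recompute and refresh the cache.
    approx_change_rate = abs(cache.cache_diff) / (abs(x) + 1e-8)  # placeholder metric
    cache.cumulative_change_rate += approx_change_rate
    if cache.cumulative_change_rate < reuse_threshold:
        cache.want_to_skip = True
        if cache.allowed_to_skip:
            return x + cache.cache_diff  # skip the block, apply the cached diff
    else:
        cache.cumulative_change_rate = 0.0
        cache.want_to_skip = False
    out = compute(x)
    cache.cache_diff = out - x           # update_cache_diff equivalent
    return out

With percentage_to_skip = 1.0, as hard-coded in perform_sortblock above, every block that wants to skip is permitted to; lowering it would restrict skipping to the blocks whose outputs have drifted the least between steps.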