Doing some experimentation

Jedrzej Kosinski
2025-09-02 22:19:12 -07:00
parent a40c5ae341
commit 295b49c165
3 changed files with 509 additions and 1 deletion

View File

@@ -0,0 +1,503 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Any
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def easysortblock_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
    timestep: torch.Tensor = args[1]
    model_options: dict[str, Any] = args[2]
easycache: EasySortblockHolder = model_options["transformer_options"]["easycache"]
# initialize predict_ratios
if easycache.initial_step:
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
relevant_sigmas = []
for i,sigma in enumerate(sample_sigmas):
if easycache.check_if_within_timesteps(sigma):
relevant_sigmas.append((i, sigma))
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
easycache.predict_ratios = torch.linspace(easycache.start_predict_ratio, easycache.end_predict_ratio, end_index - start_index + 1)
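        # e.g. start_predict_ratio=0.2, end_predict_ratio=0.9 over 8 in-range steps -> tensor([0.2, 0.3, ..., 0.9])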
easycache.predict_start_index = start_index
easycache.skip_current_step = False
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
        if easycache.has_x_prev_subsampled():
            input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
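                # first-order estimate: assume the output change scales with the input change by the last measured ratio, normalized by the previous output magnitude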
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step
easycache.skip_current_step = True
easycache.steps_skipped.append(easycache.step_count)
else:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - output_change_rate: {output_change_rate}")
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
easycache.output_prev_subsampled = easycache.subsample(output)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
# increment step count
easycache.step_count += 1
easycache.initial_step = False
return output
def easysortblock_outer_sample_wrapper(executor, *args, **kwargs):
"""
    This OUTER_SAMPLE wrapper makes sure EasySortblock is prepped for the current run, and that all memory usage is cleared at the end.
"""
try:
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
easycache: EasySortblockHolder = guider.model_options['transformer_options']['easycache']
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
return executor(*args, **kwargs)
finally:
easycache = guider.model_options['transformer_options']['easycache']
output_change_rates = easycache.output_change_rates
approx_output_change_rates = easycache.approx_output_change_rates
if easycache.verbose:
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
        total_steps = len(args[3])-1  # args[3] is expected to be the sigmas tensor passed into outer_sample
        logging.info(f"{easycache.name} - skipped {len(easycache.steps_skipped)}/{total_steps} steps")
logging.info(f"{easycache.name} - skipped steps: {easycache.steps_skipped}")
easycache.reset()
guider.model_options = orig_model_options
def model_forward_wrapper(executor, *args, **kwargs):
# TODO: make work with batches of conds
    transformer_options: dict[str, Any] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
# if initial step, prepare everything for Sortblock
    if sb_holder.initial_step:
        logging.info(f"EasySortblock: inside model {executor.class_obj.__class__.__name__}")
        # TODO: generalize for other models
        # these won't stick around past this step; should store on sb_holder instead
        if hasattr(executor.class_obj, "double_blocks"):
            logging.info(f"EasySortblock: preparing {len(executor.class_obj.double_blocks)} double blocks")
            for block in executor.class_obj.double_blocks:
                prepare_block(block, sb_holder)
        if hasattr(executor.class_obj, "single_blocks"):
            logging.info(f"EasySortblock: preparing {len(executor.class_obj.single_blocks)} single blocks")
            for block in executor.class_obj.single_blocks:
                prepare_block(block, sb_holder)
        if hasattr(executor.class_obj, "blocks"):
            logging.info(f"EasySortblock: preparing {len(executor.class_obj.blocks)} blocks")
            for block in executor.class_obj.blocks:
                prepare_block(block, sb_holder)
if sb_holder.skip_current_step:
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"EasySortblock: skipping step {sb_holder.step_count}, predict_ratio: {predict_ratio}")
# reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
for block in blocks:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = False
            # sort ascending by (consecutive skipped steps, previous change rate), so the most stable blocks come first
            sorted_blocks = sorted(blocks, key=lambda x: (x.__block_cache.consecutive_skipped_steps, x.__block_cache.prev_change_rate))
            threshold_index = int(len(sorted_blocks) * predict_ratio)
            # the threshold_index most stable blocks are allowed to skip (predicted); the rest are recomputed
            # e.g. 20 blocks at predict_ratio 0.6 -> the 12 most stable blocks may skip, the other 8 recompute
for block in sorted_blocks[:threshold_index]:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = True
logging.info(f"EasySortblock: skip block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
not_skipped = [block for block in blocks if not block.__block_cache.allowed_to_skip]
for block in not_skipped:
logging.info(f"EasySortblock: reco block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
logging.info(f"EasySortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for prediction and {len(sorted_blocks[threshold_index:])} blocks for recomputation")
    return executor(*args, **kwargs)
def block_forward_factory(func, block):
def block_forward_wrapper(*args, **kwargs):
        transformer_options: dict[str, Any] = kwargs.get("transformer_options")
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
cache: BlockCache = block.__block_cache
# make sure stream count is properly set for this block
if sb_holder.initial_step:
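            # transformer_options['block'] is assumed to be a (block_type, block_index, stream_count) triple set by the model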
sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
cache.block_index = transformer_options['block'][1]
cache.stream_count = transformer_options['block'][2]
if sb_holder.is_past_end_timestep(sigmas):
return func(*args, **kwargs)
# do sortblock stuff
x = cache.get_next_x_prev(args, kwargs)
# prepare next_x_prev
next_x_prev = cache.get_next_x_prev(args, kwargs, clone=True)
input_change = None
do_sortblock = sb_holder.should_do_easycache(sigmas)
if do_sortblock:
# TODO: checkmetadata
if cache.has_x_prev_subsampled():
input_change = (cache.subsample(x, clone=False) - cache.x_prev_subsampled).flatten().abs().mean()
if cache.has_output_prev_norm() and cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.cumulative_change_rate += approx_output_change_rate
if cache.allowed_to_skip:
# if cache.cumulative_change_rate < sb_holder.reuse_threshold:
# accumulate error + skip block
# cache.want_to_skip = True
# if cache.allowed_to_skip:
cache.consecutive_skipped_steps += 1
cache.prev_change_rate = approx_output_change_rate
return cache.apply_cache_diff(x, sb_holder)
else:
# reset error; NOT skipping block and recalculating
cache.cumulative_change_rate = 0.0
cache.prev_change_rate = approx_output_change_rate
cache.want_to_skip = False
cache.consecutive_skipped_steps = 0
        # output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
# if more than one stream from block, only use first one
if isinstance(output_raw, tuple):
output = output_raw[0]
else:
output = output_raw
if cache.has_output_prev_norm():
output_change = (cache.subsample(output, clone=False) - cache.output_prev_subsampled).flatten().abs().mean()
            # TODO: gate these diagnostics behind a verbose flag in the future
output_change_rate = output_change / cache.output_prev_norm
cache.output_change_rates.append(output_change_rate.item())
if cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.approx_output_change_rates.append(approx_output_change_rate.item())
if input_change is not None:
cache.relative_transformation_rate = output_change / input_change
# TODO: allow cache_diff to be offloaded
cache.update_cache_diff(output_raw, next_x_prev)
cache.x_prev_subsampled = cache.subsample(next_x_prev)
cache.output_prev_subsampled = cache.subsample(output)
cache.output_prev_norm = output.flatten().abs().mean()
return output_raw
return block_forward_wrapper
def prepare_block(block, sb_holder: EasySortblockHolder, stream_count: int=1):
sb_holder.add_to_all_blocks(block)
block.__original_forward = block.forward
block.forward = block_forward_factory(block.__original_forward, block)
block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
def clean_block(block):
block.forward = block.__original_forward
del block.__original_forward
del block.__block_cache
class BlockCache:
def __init__(self, subsample_factor: int=8, verbose: bool=False):
self.subsample_factor = subsample_factor
self.verbose = verbose
self.stream_count = 1
self.block_index = 0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
# cached values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.cache_diff: list[torch.Tensor] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped: list[int] = []
self.consecutive_skipped_steps = 0
# self.state_metadata = None
self.want_to_skip = False
self.allowed_to_skip = False
    def has_cache_diff(self) -> bool:
        return len(self.cache_diff) > 0 and self.cache_diff[0] is not None
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
    def get_next_x_prev(self, d_args: tuple[torch.Tensor, ...], d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if self.stream_count == 1:
if clone:
return d_args[0].clone()
return d_args[0]
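        # multi-stream case: assume the first stream_count kwargs hold the per-stream tensors (e.g. img and txt on double blocks)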
keys = list(d_kwargs.keys())[:self.stream_count]
orig_inputs = []
for key in keys:
if clone:
orig_inputs.append(d_kwargs[key].clone())
else:
orig_inputs.append(d_kwargs[key])
return tuple(orig_inputs)
def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
        # subsample only the first component
if isinstance(x, tuple):
return self.subsample(x[0], clone)
if self.subsample_factor > 1:
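            # strided slice over the last two dims: e.g. factor 8 keeps every 8th element of each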
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], sb_holder: EasySortblockHolder):
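        # skipped step: approximate this block's output as the current input plus the cached (output - input) residual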
self.steps_skipped.append(sb_holder.step_count)
if not isinstance(x, tuple):
x = (x, )
to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
if len(to_return) == 1:
return to_return[0]
return to_return
def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
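        # cache the per-stream residual (output - input) so skipped steps can re-apply it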
if not isinstance(output_raw, tuple):
output_raw = (output_raw, )
if not isinstance(x, tuple):
x = (x, )
self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
def reset(self):
        self.relative_transformation_rate = None
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.cache_diff = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped = []
self.consecutive_skipped_steps = 0
self.want_to_skip = False
self.allowed_to_skip = False
return self
class EasySortblockHolder:
def __init__(self, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, max_skipped_steps: int,
start_percent: float, end_percent: float, subsample_factor: int, verbose: bool=False):
self.name = "EasySortblock"
self.reuse_threshold = reuse_threshold
self.start_predict_ratio = start_predict_ratio
self.end_predict_ratio = end_predict_ratio
self.max_skipped_steps = max_skipped_steps
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
# cache values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.steps_skipped: list[int] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
def add_to_all_blocks(self, block):
self.all_blocks.append(block)
def add_to_blocks_per_type(self, block, block_type: str):
self.blocks_per_type.setdefault(block_type, []).append(block)
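    # sigmas decrease over the course of sampling, so start_t (from start_percent) is the larger sigma and end_t the smaller;
    # "past end" therefore means the current sigma has fallen to or below end_t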
    def is_past_end_timestep(self, timestep: torch.Tensor) -> bool:
        return not (timestep[0] > self.end_t).item()
    def should_do_easycache(self, timestep: torch.Tensor) -> bool:
        return (timestep[0] <= self.start_t).item()
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
logging.info(f"EasySortblock: resetting {len(self.all_blocks)} blocks")
for block in self.all_blocks:
clean_block(block)
        self.relative_transformation_rate = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.steps_skipped = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
return self
def clone(self):
return EasySortblockHolder(self.reuse_threshold, self.start_predict_ratio, self.end_predict_ratio, self.max_skipped_steps,
self.start_percent, self.end_percent, self.subsample_factor, self.verbose)
class EasySortblockScaledNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EasySortblockScaled",
display_name="EasySortblockScaled",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
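        # NOTE: policy_refresh_interval is passed as the holder's max_skipped_steps argument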
model.model_options["transformer_options"]["easycache"] = EasySortblockHolder(reuse_threshold, start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", easysortblock_predict_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", easysortblock_outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class EasySortblockExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
# EasySortblockNode,
EasySortblockScaledNode,
]
def comfy_entrypoint():
return EasySortblockExtension()
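
The wrappers above follow comfy.patcher_extension's executor convention: each wrapper receives an executor plus the original arguments and can either call through or short-circuit with a cached/reused value. A minimal self-contained sketch of that calling pattern (toy names, not the actual ComfyUI implementation):

from typing import Callable

class ToyExecutor:
    # chains wrappers around a function, mirroring how WrappersMP wrappers are invoked
    def __init__(self, fn: Callable, wrappers: list):
        self.fn = fn
        self.wrappers = wrappers

    def __call__(self, *args, **kwargs):
        if not self.wrappers:
            return self.fn(*args, **kwargs)
        head, rest = self.wrappers[0], self.wrappers[1:]
        return head(ToyExecutor(self.fn, rest), *args, **kwargs)

def skipping_wrapper(executor, x, skip=False):
    if skip:
        return x  # short-circuit, as on a skipped/cached step
    return executor(x)  # call through to the wrapped function

run = ToyExecutor(lambda x: x * 2, [skipping_wrapper])
assert run(3) == 6             # wrapped function executes
assert run(3, skip=True) == 3  # wrapper reuses the input instead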

View File

@@ -25,6 +25,7 @@ def prepare_noise_wrapper(executor, *args, **kwargs):
             start_index = relevant_sigmas[0][0]
             end_index = relevant_sigmas[-1][0]
             sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
+            sb_holder.predict_start_index = start_index
         return executor(*args, **kwargs)
     finally:
@@ -95,7 +96,8 @@ def model_forward_wrapper(executor, *args, **kwargs):
         # when 1: Select DiT blocks(4)
         if sb_holder.step_modulus == 1:
-            predict_ratio = sb_holder.predict_ratios[sb_holder.active_steps-1]
+            predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
+            predict_ratio = sb_holder.predict_ratios[predict_index]
             logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
             reuse_ratio = 1.0 - predict_ratio
             for block_type, blocks in sb_holder.blocks_per_type.items():
@@ -322,6 +324,7 @@ class SortblockHolder:
         self.do_sortblock = False
         self.active_steps = 0
         self.predict_ratios = []
+        self.predict_start_index = 0
         # cache values
         self.all_blocks = []
@@ -371,6 +374,7 @@ class SortblockHolder:
         self.active_steps = 0
         self.predict_ratios = []
         self.do_sortblock = False
+        self.predict_start_index = 0
         return self
     def clone(self):

View File

@@ -2326,6 +2326,7 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py", "nodes_easycache.py",
"nodes_audio_encoder.py", "nodes_audio_encoder.py",
"nodes_sortblock.py", "nodes_sortblock.py",
"nodes_easysortblock.py",
] ]
import_failed = [] import_failed = []