some exploration of sortblock as more things from paper/source code need to be added

This commit is contained in:
Jedrzej Kosinski
2025-09-01 09:39:40 -07:00
parent cf26d3d58e
commit d4a8752c8c

View File

@@ -24,6 +24,11 @@ def outer_sample_wrapper(executor, *args, **kwargs):
finally:
sb_holder = guider.model_options["transformer_options"]["sortblock"]
sb_holder.print_block_info(0)
# import plotly.express as px
# fig = px.line(x=list(range(len(sb_holder.blocks))), y=[getattr(block, "__block_cache").cumulative_change_rate for block in sb_holder.blocks])
sb_holder.reset()
guider.model_options = orig_model_options
@@ -56,7 +61,7 @@ def model_forward_wrapper(executor, *args, **kwargs):
total_double_block = len(executor.class_obj.double_blocks)
total_single_block = len(executor.class_obj.single_blocks)
perform_sortblock(sb_holder.blocks[:total_double_block])
#perform_sortblock(sb_holder.blocks[total_double_block:])
perform_sortblock(sb_holder.blocks[total_double_block:])
if sb_holder.initial_step:
sb_holder.initial_step = False
@@ -105,7 +110,7 @@ def block_forward_factory(func, block):
x = cache.get_next_x_prev(kwargs)
timestep: float = sb_holder.curr_t
# prepare next_x_prev
next_x_prev = cache.get_next_x_prev(kwargs)
next_x_prev = cache.get_next_x_prev(kwargs, clone=True)
input_change = None
do_sortblock = sb_holder.should_do_sortblock()
if do_sortblock:
@@ -117,9 +122,10 @@ def block_forward_factory(func, block):
cache.cumulative_change_rate += approx_output_change_rate
if cache.cumulative_change_rate < sb_holder.reuse_threshold:
# accumulate error + skip block
cache.want_to_skip = True
if cache.allowed_to_skip:
return cache.apply_cache_diff(x)
# cache.want_to_skip = True
# if cache.allowed_to_skip:
# return cache.apply_cache_diff(x)
pass
else:
# reset error; NOT skipping block and recalculating
cache.cumulative_change_rate = 0.0
@@ -246,11 +252,13 @@ class BlockCache:
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
keys = list(d_kwargs.keys())
if self.stream_count == 1:
if clone:
return d_kwargs[keys[0]].clone()
return d_kwargs[keys[0]]
return tuple([d_kwargs[keys[i]] for i in range(self.stream_count)])
return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])
def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
# subsample only the first compoenent