mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
some exploration of sortblock as more things from paper/source code need to be added
This commit is contained in:
@@ -24,6 +24,11 @@ def outer_sample_wrapper(executor, *args, **kwargs):
|
||||
finally:
|
||||
sb_holder = guider.model_options["transformer_options"]["sortblock"]
|
||||
sb_holder.print_block_info(0)
|
||||
|
||||
# import plotly.express as px
|
||||
# fig = px.line(x=list(range(len(sb_holder.blocks))), y=[getattr(block, "__block_cache").cumulative_change_rate for block in sb_holder.blocks])
|
||||
|
||||
|
||||
sb_holder.reset()
|
||||
guider.model_options = orig_model_options
|
||||
|
||||
@@ -56,7 +61,7 @@ def model_forward_wrapper(executor, *args, **kwargs):
|
||||
total_double_block = len(executor.class_obj.double_blocks)
|
||||
total_single_block = len(executor.class_obj.single_blocks)
|
||||
perform_sortblock(sb_holder.blocks[:total_double_block])
|
||||
#perform_sortblock(sb_holder.blocks[total_double_block:])
|
||||
perform_sortblock(sb_holder.blocks[total_double_block:])
|
||||
if sb_holder.initial_step:
|
||||
sb_holder.initial_step = False
|
||||
|
||||
@@ -105,7 +110,7 @@ def block_forward_factory(func, block):
|
||||
x = cache.get_next_x_prev(kwargs)
|
||||
timestep: float = sb_holder.curr_t
|
||||
# prepare next_x_prev
|
||||
next_x_prev = cache.get_next_x_prev(kwargs)
|
||||
next_x_prev = cache.get_next_x_prev(kwargs, clone=True)
|
||||
input_change = None
|
||||
do_sortblock = sb_holder.should_do_sortblock()
|
||||
if do_sortblock:
|
||||
@@ -117,9 +122,10 @@ def block_forward_factory(func, block):
|
||||
cache.cumulative_change_rate += approx_output_change_rate
|
||||
if cache.cumulative_change_rate < sb_holder.reuse_threshold:
|
||||
# accumulate error + skip block
|
||||
cache.want_to_skip = True
|
||||
if cache.allowed_to_skip:
|
||||
return cache.apply_cache_diff(x)
|
||||
# cache.want_to_skip = True
|
||||
# if cache.allowed_to_skip:
|
||||
# return cache.apply_cache_diff(x)
|
||||
pass
|
||||
else:
|
||||
# reset error; NOT skipping block and recalculating
|
||||
cache.cumulative_change_rate = 0.0
|
||||
@@ -246,11 +252,13 @@ class BlockCache:
|
||||
def has_relative_transformation_rate(self) -> bool:
|
||||
return self.relative_transformation_rate is not None
|
||||
|
||||
def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
keys = list(d_kwargs.keys())
|
||||
if self.stream_count == 1:
|
||||
if clone:
|
||||
return d_kwargs[keys[0]].clone()
|
||||
return d_kwargs[keys[0]]
|
||||
return tuple([d_kwargs[keys[i]] for i in range(self.stream_count)])
|
||||
return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])
|
||||
|
||||
def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
|
||||
# subsample only the first compoenent
|
||||
|
Reference in New Issue
Block a user