Mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit: some exploration of sortblock; more things from the paper/source code still need to be added.
@@ -24,6 +24,11 @@ def outer_sample_wrapper(executor, *args, **kwargs):
     finally:
         sb_holder = guider.model_options["transformer_options"]["sortblock"]
         sb_holder.print_block_info(0)
+
+        # import plotly.express as px
+        # fig = px.line(x=list(range(len(sb_holder.blocks))), y=[getattr(block, "__block_cache").cumulative_change_rate for block in sb_holder.blocks])
+
+
         sb_holder.reset()
         guider.model_options = orig_model_options
 
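The commented-out lines added above sketch a debug plot of each block's cumulative change rate. A runnable version of that idea might look like the sketch below; it assumes, as the diff does, that each block carries its BlockCache in a `__block_cache` attribute exposing `cumulative_change_rate`. Illustration only, not part of the commit.

import plotly.express as px

def plot_block_change_rates(sb_holder):
    # one point per transformer block: how much accumulated output change
    # that block saw over the sampling run
    rates = [getattr(block, "__block_cache").cumulative_change_rate
             for block in sb_holder.blocks]
    fig = px.line(x=list(range(len(rates))), y=rates,
                  labels={"x": "block index", "y": "cumulative change rate"})
    fig.show()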
@@ -56,7 +61,7 @@ def model_forward_wrapper(executor, *args, **kwargs):
     total_double_block = len(executor.class_obj.double_blocks)
     total_single_block = len(executor.class_obj.single_blocks)
     perform_sortblock(sb_holder.blocks[:total_double_block])
-    #perform_sortblock(sb_holder.blocks[total_double_block:])
+    perform_sortblock(sb_holder.blocks[total_double_block:])
     if sb_holder.initial_step:
        sb_holder.initial_step = False
 
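The one-line change above enables perform_sortblock for the single blocks in addition to the double blocks. The slicing relies on sb_holder.blocks holding all double blocks first and all single blocks after; a minimal sketch of that assumed layout:

def partition_blocks(blocks: list, total_double_block: int) -> tuple[list, list]:
    # double blocks occupy the front of the list, single blocks the rest
    return blocks[:total_double_block], blocks[total_double_block:]

double_blocks, single_blocks = partition_blocks(sb_holder.blocks, total_double_block)  # hypothetical usage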
@@ -105,7 +110,7 @@ def block_forward_factory(func, block):
         x = cache.get_next_x_prev(kwargs)
         timestep: float = sb_holder.curr_t
         # prepare next_x_prev
-        next_x_prev = cache.get_next_x_prev(kwargs)
+        next_x_prev = cache.get_next_x_prev(kwargs, clone=True)
         input_change = None
         do_sortblock = sb_holder.should_do_sortblock()
         if do_sortblock:
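The clone=True added above most plausibly guards against tensor aliasing: without a copy, the cached next_x_prev is the very tensor the block may update in place, so the cache would drift along with it. A small torch demonstration of that failure mode (illustrative, not from the commit):

import torch

x = torch.ones(4)
cached_alias = x          # what get_next_x_prev(kwargs) used to hand back
cached_copy = x.clone()   # what get_next_x_prev(kwargs, clone=True) returns
x.add_(1.0)               # in-place update during the block forward
print(cached_alias)       # tensor([2., 2., 2., 2.]) -- cache silently corrupted
print(cached_copy)        # tensor([1., 1., 1., 1.]) -- stable snapshot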
@@ -117,9 +122,10 @@ def block_forward_factory(func, block):
             cache.cumulative_change_rate += approx_output_change_rate
             if cache.cumulative_change_rate < sb_holder.reuse_threshold:
                 # accumulate error + skip block
-                cache.want_to_skip = True
-                if cache.allowed_to_skip:
-                    return cache.apply_cache_diff(x)
+                # cache.want_to_skip = True
+                # if cache.allowed_to_skip:
+                #     return cache.apply_cache_diff(x)
+                pass
             else:
                 # reset error; NOT skipping block and recalculating
                 cache.cumulative_change_rate = 0.0
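This hunk keeps the per-block error accounting but stubs out the actual skip with pass, so the commit only gathers statistics for now. The decision rule visible in the diff, written out as a standalone sketch (names match the diff; behavior is inferred from it):

def should_skip_block(cache, approx_output_change_rate: float, reuse_threshold: float) -> bool:
    # accumulate the block's approximate output change since its last recompute
    cache.cumulative_change_rate += approx_output_change_rate
    if cache.cumulative_change_rate < reuse_threshold:
        return True   # drift is still small; reusing the cached diff is acceptable
    cache.cumulative_change_rate = 0.0
    return False      # too much accumulated change; recompute this block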
@@ -246,11 +252,13 @@ class BlockCache:
     def has_relative_transformation_rate(self) -> bool:
         return self.relative_transformation_rate is not None
 
-    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         keys = list(d_kwargs.keys())
         if self.stream_count == 1:
+            if clone:
+                return d_kwargs[keys[0]].clone()
             return d_kwargs[keys[0]]
-        return tuple([d_kwargs[keys[i]] for i in range(self.stream_count)])
+        return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])
 
     def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
         # subsample only the first component
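A self-contained sketch of how the extended get_next_x_prev behaves with the new clone flag, using a minimal stand-in for BlockCache (the real class carries more state; _CacheSketch and the img/txt stream names are illustrative):

import torch
from typing import Union

class _CacheSketch:
    # minimal stand-in mirroring the diff'd method, for illustration only
    def __init__(self, stream_count: int):
        self.stream_count = stream_count

    def get_next_x_prev(self, d_kwargs: dict[str, torch.Tensor], clone: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        keys = list(d_kwargs.keys())
        if self.stream_count == 1:
            if clone:
                return d_kwargs[keys[0]].clone()
            return d_kwargs[keys[0]]
        return tuple([d_kwargs[keys[i]].clone() if clone else d_kwargs[keys[i]] for i in range(self.stream_count)])

cache = _CacheSketch(stream_count=2)  # e.g. a double block with img/txt streams
d_kwargs = {"img": torch.randn(1, 16, 64), "txt": torch.randn(1, 77, 64)}
img, txt = cache.get_next_x_prev(d_kwargs, clone=True)
assert img.data_ptr() != d_kwargs["img"].data_ptr()  # clone owns its own storage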