More progress on Sortblock
@@ -128,7 +128,7 @@ class Flux(nn.Module):
 
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
-            transformer_options["block"] = ("double_block", i)
+            transformer_options["block"] = ("double_block", i, 2)
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -170,7 +170,7 @@ class Flux(nn.Module):
         img = torch.cat((txt, img), 1)
 
         for i, block in enumerate(self.single_blocks):
-            transformer_options["block"] = ("single_block", i)
+            transformer_options["block"] = ("single_block", i, 1)
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
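
The only change in these two hunks is a third element appended to the "block" tuple: the number of tensor streams the block produces (2 for Flux double blocks, which carry separate img and txt streams; 1 for single blocks). Later in this commit, block_forward_wrapper reads it as transformer_options['block'][2] and stores it as cache.stream_count. A minimal sketch of the convention (the dict literal here is illustrative, not committed code):

transformer_options = {"block": ("double_block", 0, 2)}  # (block type, block index, stream count)

block_type, block_index, stream_count = transformer_options["block"]
assert stream_count == 2  # double blocks return (img, txt); single blocks return one fused tensor
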
@@ -10,7 +10,6 @@ import comfy.model_patcher
 if TYPE_CHECKING:
     from uuid import UUID
 
-
 def outer_sample_wrapper(executor, *args, **kwargs):
     try:
         logging.info("Sortblock: inside outer_sample!")
@@ -24,6 +23,7 @@ def outer_sample_wrapper(executor, *args, **kwargs):
         return executor(*args, **kwargs)
     finally:
         sb_holder = guider.model_options["transformer_options"]["sortblock"]
+        sb_holder.print_block_info(0)
         sb_holder.reset()
         guider.model_options = orig_model_options
 
@@ -43,14 +43,38 @@ def model_forward_wrapper(executor, *args, **kwargs):
     transformer_options["total_double_block"] = len(executor.class_obj.double_blocks)
     transformer_options["total_single_block"] = len(executor.class_obj.single_blocks)
     # save the original forwards on the blocks
-    logging.info(f"Sortblock: saving preparing {transformer_options['total_double_block']} double blocks and {transformer_options['total_single_block']} single blocks")
+    logging.info(f"Sortblock: preparing {transformer_options['total_double_block']} double blocks and {transformer_options['total_single_block']} single blocks")
     for block in executor.class_obj.double_blocks:
         prepare_block(block, sb_holder)
     for block in executor.class_obj.single_blocks:
         prepare_block(block, sb_holder)
-    sb_holder.initial_step = False
-
-    return executor(*args, **kwargs)
+    try:
+        return executor(*args, **kwargs)
+    finally:
+        sb_holder: SortblockHolder = transformer_options["sortblock"]
+        # do double blocks
+        total_double_block = len(executor.class_obj.double_blocks)
+        total_single_block = len(executor.class_obj.single_blocks)
+        perform_sortblock(sb_holder.blocks[:total_double_block])
+        #perform_sortblock(sb_holder.blocks[total_double_block:])
+        if sb_holder.initial_step:
+            sb_holder.initial_step = False
+
+
+def perform_sortblock(blocks: list):
+    candidate_blocks = []
+    for block in blocks:
+        cache: BlockCache = getattr(block, "__block_cache")
+        cache.allowed_to_skip = False
+        if cache.want_to_skip:
+            candidate_blocks.append(block)
+    if len(candidate_blocks) > 0:
+        percentage_to_skip = 1.0
+        candidate_blocks.sort(key=lambda x: getattr(x, "__block_cache").cumulative_change_rate)
+        blocks_to_skip = int(len(candidate_blocks) * percentage_to_skip)
+        for block in candidate_blocks[:blocks_to_skip]:
+            cache: BlockCache = getattr(block, "__block_cache")
+            cache.allowed_to_skip = True
 
 
 def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
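
The perform_sortblock helper added above implements a two-stage gate: each block may request a skip (want_to_skip, set from its cumulative change rate), and perform_sortblock then grants skips to the candidates with the smallest cumulative_change_rate, up to a fraction of the candidates (currently all of them, since percentage_to_skip is hardcoded to 1.0, which makes the sort a no-op for now). A self-contained sketch of that gate; the _Block/_Cache stubs and rate values are hypothetical, and _perform_sortblock is a condensed restatement of the committed function:

class _Cache:  # stand-in for BlockCache with only the fields the gate touches
    def __init__(self, rate: float, wants: bool):
        self.cumulative_change_rate = rate
        self.want_to_skip = wants
        self.allowed_to_skip = False

class _Block:
    def __init__(self, rate: float, wants: bool):
        setattr(self, "__block_cache", _Cache(rate, wants))  # literal attribute name, as in the diff

def _perform_sortblock(blocks: list):
    candidates = []
    for b in blocks:
        cache = getattr(b, "__block_cache")
        cache.allowed_to_skip = False          # every call starts with no skips granted
        if cache.want_to_skip:
            candidates.append(b)
    if candidates:
        percentage_to_skip = 1.0
        candidates.sort(key=lambda b: getattr(b, "__block_cache").cumulative_change_rate)
        for b in candidates[:int(len(candidates) * percentage_to_skip)]:
            getattr(b, "__block_cache").allowed_to_skip = True

blocks = [_Block(0.05, True), _Block(0.20, False), _Block(0.01, True)]
_perform_sortblock(blocks)
print([getattr(b, "__block_cache").allowed_to_skip for b in blocks])  # [True, False, True]
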
@@ -69,14 +93,16 @@ def clean_block(block):
 def block_forward_factory(func, block):
     def block_forward_wrapper(*args, **kwargs):
         transformer_options: dict[str] = kwargs.get("transformer_options", None)
-        logging.info(f"Sortblock: inside block {transformer_options['block']}")
+        #logging.info(f"Sortblock: inside block {transformer_options['block']}")
         sb_holder: SortblockHolder = transformer_options["sortblock"]
+        cache: BlockCache = block.__block_cache
+        if sb_holder.initial_step:
+            cache.stream_count = transformer_options['block'][2]
         if sb_holder.is_past_end_timestep():
             return func(*args, **kwargs)
         # do sortblock stuff
-        cache: BlockCache = block.__block_cache
         keys = list(kwargs.keys())
-        x: torch.Tensor = kwargs[keys[0]]
+        x = cache.get_next_x_prev(kwargs)
         timestep: float = sb_holder.curr_t
         # prepare next_x_prev
         next_x_prev = cache.get_next_x_prev(kwargs)
@@ -91,10 +117,13 @@ def block_forward_factory(func, block):
         cache.cumulative_change_rate += approx_output_change_rate
         if cache.cumulative_change_rate < sb_holder.reuse_threshold:
             # accumulate error + skip block
-            pass
+            cache.want_to_skip = True
+            if cache.allowed_to_skip:
+                return cache.apply_cache_diff(x)
         else:
             # reset error; NOT skipping block and recalculating
             cache.cumulative_change_rate = 0.0
+            cache.want_to_skip = False
         # output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
         output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
         # if more than one stream from block, only use first one
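
This hunk wires up the request side of the gate: a block whose cumulative_change_rate stays under reuse_threshold sets want_to_skip, and if perform_sortblock granted the skip on a previous call, the block returns its cached diff instead of running. A worked toy trace of the accumulator (the threshold and per-step rates here are hypothetical):

reuse_threshold = 0.3
cumulative = 0.0
for rate in (0.1, 0.1, 0.5):       # per-step approx_output_change_rate
    cumulative += rate
    if cumulative < reuse_threshold:
        want_to_skip = True        # small drift: request a skip, keep accumulating
    else:
        cumulative = 0.0           # too much drift: recompute the block, reset error
        want_to_skip = False
    print(rate, cumulative, want_to_skip)
# 0.1 0.1 True
# 0.1 0.2 True
# 0.5 0.0 False
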
@@ -162,6 +191,12 @@ class SortblockHolder:
     def should_do_sortblock(self) -> bool:
         return self.do_sortblock
 
+    def print_block_info(self, index: int):
+        block = self.blocks[index]
+        cache = getattr(block, "__block_cache")
+        logging.info(f"Sortblock: block {index} output_change_rates: {cache.output_change_rates}")
+        logging.info(f"Sortblock: block {index} approx_output_change_rates: {cache.approx_output_change_rates}")
+
     def reset(self):
         self.initial_step = True
         self.curr_t = 0.0
@@ -193,6 +228,8 @@ class BlockCache:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.want_to_skip = False
+        self.allowed_to_skip = False
 
     def has_cache_diff(self) -> bool:
         return self.cache_diff[0] is not None
@@ -228,12 +265,21 @@ class BlockCache:
             return x.clone()
         return x
 
-    def apply_cache_diff(self, x: torch.Tensor):
+    def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
         self.total_steps_skipped += 1
-        return x + self.cache_diff.to(x.device)
+        if not isinstance(x, tuple):
+            x = (x, )
+        to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
+        if len(to_return) == 1:
+            return to_return[0]
+        return to_return
 
-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
-        self.cache_diff = output - x
+    def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
+        if not isinstance(output_raw, tuple):
+            output_raw = (output_raw, )
+        if not isinstance(x, tuple):
+            x = (x, )
+        self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
 
     def check_metadata(self, x: torch.Tensor) -> bool:
         # TODO: make sure shapes are correct
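
The rewritten apply_cache_diff/update_cache_diff generalize the cached residual from a single tensor to one residual per stream, so double blocks (which return an (img, txt) tuple) can be skipped too. A condensed, self-contained sketch of the round trip; the _BlockCache stub, shapes, and values are illustrative:

import torch

class _BlockCache:  # only the pieces involved in the diff round trip
    def __init__(self, stream_count: int):
        self.stream_count = stream_count
        self.cache_diff = (None,) * stream_count
        self.total_steps_skipped = 0

    def update_cache_diff(self, output_raw, x):
        if not isinstance(output_raw, tuple):
            output_raw = (output_raw,)
        if not isinstance(x, tuple):
            x = (x,)
        self.cache_diff = tuple(output_raw[i] - x[i] for i in range(self.stream_count))

    def apply_cache_diff(self, x):
        self.total_steps_skipped += 1
        if not isinstance(x, tuple):
            x = (x,)
        out = tuple(x[i] + self.cache_diff[i] for i in range(self.stream_count))
        return out[0] if len(out) == 1 else out

cache = _BlockCache(stream_count=2)
img, txt = torch.randn(1, 4), torch.randn(1, 4)
cache.update_cache_diff((img + 1.0, txt + 2.0), (img, txt))  # store per-stream residuals
new_img, new_txt = cache.apply_cache_diff((img, txt))        # skip: reuse residuals instead of running the block
assert torch.allclose(new_img, img + 1.0) and torch.allclose(new_txt, txt + 2.0)
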
@@ -256,6 +302,8 @@ class BlockCache:
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.want_to_skip = False
+        self.allowed_to_skip = False
 
 class SortblockNode(io.ComfyNode):
     @classmethod