mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 03:25:22 +00:00
Support predict_ratio changing with timesteps
This commit is contained in:
@@ -13,10 +13,26 @@ if TYPE_CHECKING:
|
||||
|
||||
def prepare_noise_wrapper(executor, *args, **kwargs):
|
||||
try:
|
||||
transformer_options: dict[str] = args[2]["transformer_options"]
|
||||
sb_holder: SortblockHolder = transformer_options["sortblock"]
|
||||
if sb_holder.initial_step:
|
||||
sample_sigmas = transformer_options["sample_sigmas"]
|
||||
relevant_sigmas = []
|
||||
# find start and end steps, then use to interpolate between start and end predict ratios
|
||||
for i,sigma in enumerate(sample_sigmas):
|
||||
if sb_holder.check_if_within_timesteps(sigma):
|
||||
relevant_sigmas.append((i, sigma))
|
||||
start_index = relevant_sigmas[0][0]
|
||||
end_index = relevant_sigmas[-1][0]
|
||||
sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
|
||||
|
||||
return executor(*args, **kwargs)
|
||||
finally:
|
||||
sb_holder: SortblockHolder = executor.class_obj.model_options["transformer_options"]["sortblock"]
|
||||
transformer_options: dict[str] = args[2]["transformer_options"]
|
||||
sb_holder: SortblockHolder = transformer_options["sortblock"]
|
||||
sb_holder.step_count += 1
|
||||
if sb_holder.should_do_sortblock():
|
||||
sb_holder.active_steps += 1
|
||||
|
||||
|
||||
def outer_sample_wrapper(executor, *args, **kwargs):
|
||||
@@ -29,7 +45,7 @@ def outer_sample_wrapper(executor, *args, **kwargs):
|
||||
sb_holder = guider.model_options["transformer_options"]["sortblock"]
|
||||
guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
||||
sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"]
|
||||
logging.info(f"Sortblock: enabled - threshold: {sb_holder.predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
|
||||
logging.info(f"Sortblock: enabled - threshold: {sb_holder.start_predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
|
||||
return executor(*args, **kwargs)
|
||||
finally:
|
||||
sb_holder = guider.model_options["transformer_options"]["sortblock"]
|
||||
@@ -79,11 +95,12 @@ def model_forward_wrapper(executor, *args, **kwargs):
|
||||
|
||||
# when 1: Select DiT blocks(4)
|
||||
if sb_holder.step_modulus == 1:
|
||||
logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction")
|
||||
predict_ratio = 1.0 - sb_holder.predict_ratio
|
||||
predict_ratio = sb_holder.predict_ratios[sb_holder.active_steps-1]
|
||||
logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
|
||||
reuse_ratio = 1.0 - predict_ratio
|
||||
for block_type, blocks in sb_holder.blocks_per_type.items():
|
||||
sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity)
|
||||
threshold_index = int(len(sorted_blocks) * predict_ratio)
|
||||
threshold_index = int(len(sorted_blocks) * reuse_ratio)
|
||||
# blocks with lower similarity are marked for recomputation
|
||||
for block in sorted_blocks[:threshold_index]:
|
||||
cache: BlockCache = block.__block_cache
|
||||
@@ -280,8 +297,10 @@ def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
|
||||
)
|
||||
|
||||
class SortblockHolder:
|
||||
def __init__(self, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
|
||||
self.predict_ratio = predict_ratio
|
||||
def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int,
|
||||
start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
|
||||
self.start_predict_ratio = start_predict_ratio
|
||||
self.end_predict_ratio = end_predict_ratio
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
self.subsample_factor = subsample_factor
|
||||
@@ -300,6 +319,10 @@ class SortblockHolder:
|
||||
self.step_count = 0
|
||||
self.activated_steps: list[int] = [0]
|
||||
self.step_modulus = 0
|
||||
self.do_sortblock = False
|
||||
self.active_steps = 0
|
||||
self.predict_ratios = []
|
||||
|
||||
# cache values
|
||||
self.all_blocks = []
|
||||
self.blocks_per_type = {}
|
||||
@@ -315,6 +338,9 @@ class SortblockHolder:
|
||||
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
||||
return self
|
||||
|
||||
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
|
||||
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
|
||||
|
||||
def update_should_do_sortblock(self, timestep: float) -> bool:
|
||||
self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item()
|
||||
self.curr_t = timestep
|
||||
@@ -342,10 +368,13 @@ class SortblockHolder:
|
||||
self.step_count = 0
|
||||
self.activated_steps = [0]
|
||||
self.step_modulus = 0
|
||||
self.active_steps = 0
|
||||
self.predict_ratios = []
|
||||
self.do_sortblock = False
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return SortblockHolder(predict_ratio=self.predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
|
||||
return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
|
||||
start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
|
||||
verbose=self.verbose)
|
||||
|
||||
@@ -376,7 +405,42 @@ class SortblockNode(io.ComfyNode):
|
||||
def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
# TODO: check for specific flavors of supported models
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
|
||||
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio=predict_ratio, end_predict_ratio=predict_ratio, policy_refresh_interval=policy_refresh_interval,
|
||||
start_percent=start_percent, end_percent=end_percent, subsample_factor=8, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class SortblockScaledNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SortblockScaled",
|
||||
display_name="SortblockScaled",
|
||||
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
|
||||
category="advanced/debug/model",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add Sortblock to."),
|
||||
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
|
||||
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
|
||||
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
|
||||
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
|
||||
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
|
||||
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with Sortblock."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
# TODO: check for specific flavors of supported models
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
|
||||
@@ -387,6 +451,7 @@ class SortblockExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SortblockNode,
|
||||
SortblockScaledNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
|
Reference in New Issue
Block a user