from __future__ import annotations

from typing import TYPE_CHECKING, Union

from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher

if TYPE_CHECKING:
    from uuid import UUID
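
# Overview (editor's summary, inferred from the code below): Sortblock is a
# block-level caching scheme for DiT-style models. Each policy interval of
# `policy_refresh_interval` steps cycles through four phases, matching the
# numbered comments in model_forward_wrapper and block_forward_factory:
#   (1) Initialization     - modulus 0: recompute every block, cache residuals.
#   (2) Linear Prediction  - modulus 1: predict every block's residual with a
#                            first-order Taylor step (taylor_formula).
#   (3) Identify Changes   - modulus 1: compare predicted vs. cached residuals
#                            via cosine similarity.
#   (4) Select DiT blocks  - modulus 1: sort blocks by similarity; low-similarity
#                            blocks are recomputed for the rest of the interval,
#                            high-similarity blocks keep reusing predictions.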


def prepare_noise_wrapper(executor, *args, **kwargs):
    try:
        transformer_options: dict = args[2]["transformer_options"]
        sb_holder: SortblockHolder = transformer_options["sortblock"]
        if sb_holder.initial_step:
            sample_sigmas = transformer_options["sample_sigmas"]
            relevant_sigmas = []
            # find start and end steps, then use them to interpolate between start and end predict ratios
            for i, sigma in enumerate(sample_sigmas):
                if sb_holder.check_if_within_timesteps(sigma):
                    relevant_sigmas.append((i, sigma))
            if relevant_sigmas:  # guard against a start/end window that matches no sigmas
                start_index = relevant_sigmas[0][0]
                end_index = relevant_sigmas[-1][0]
                sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
                sb_holder.predict_start_index = start_index
        return executor(*args, **kwargs)
    finally:
        transformer_options: dict = args[2]["transformer_options"]
        sb_holder: SortblockHolder = transformer_options["sortblock"]
        sb_holder.step_count += 1
        if sb_holder.should_do_sortblock():
            sb_holder.active_steps += 1
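
# A worked example of the ratio schedule above (hypothetical numbers): with
# start_predict_ratio=0.2, end_predict_ratio=0.9 and a window covering steps
# 3..17 (15 steps), torch.linspace(0.2, 0.9, 15) assigns step 3 a predict
# ratio of 0.2, step 10 roughly 0.55, and step 17 the full 0.9, so more
# blocks are predicted (rather than recomputed) late in sampling.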


def outer_sample_wrapper(executor, *args, **kwargs):
    try:
        logging.info("Sortblock: inside outer_sample!")
        guider = executor.class_obj
        orig_model_options = guider.model_options
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # clone the holder and prepare timesteps, so per-sample state never leaks into the patched model
        sb_holder = guider.model_options["transformer_options"]["sortblock"]
        guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
        sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"]
        logging.info(f"Sortblock: enabled - threshold: {sb_holder.start_predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
        return executor(*args, **kwargs)
    finally:
        sb_holder = guider.model_options["transformer_options"]["sortblock"]
        logging.info(f"Sortblock: final step count: {sb_holder.step_count}")
        sb_holder.reset()
        guider.model_options = orig_model_options


def model_forward_wrapper(executor, *args, **kwargs):
    # TODO: make work with batches of conds
    transformer_options: dict = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    sigmas = transformer_options["sigmas"]
    sb_holder: SortblockHolder = transformer_options["sortblock"]
    sb_holder.update_should_do_sortblock(sigmas)

    # if initial step, prepare everything for Sortblock
    if sb_holder.initial_step:
        logging.info(f"Sortblock: inside model {executor.class_obj.__class__.__name__}")
        # TODO: generalize for other models
        # NOTE: these won't stick around past this step; should store on sb_holder instead
        logging.info(f"Sortblock: preparing {len(getattr(executor.class_obj, 'double_blocks', []))} double blocks "
                     f"and {len(getattr(executor.class_obj, 'single_blocks', []))} single blocks")
        if hasattr(executor.class_obj, "double_blocks"):
            for block in executor.class_obj.double_blocks:
                prepare_block(block, sb_holder)
        if hasattr(executor.class_obj, "single_blocks"):
            for block in executor.class_obj.single_blocks:
                prepare_block(block, sb_holder)
        if hasattr(executor.class_obj, "blocks"):
            for block in executor.class_obj.blocks:
                prepare_block(block, sb_holder)

    # when 0: Initialization(1)
    if sb_holder.step_modulus == 0:
        logging.info(f"Sortblock: for step {sb_holder.step_count}, all blocks are marked for recomputation")
        # all features are computed; input-output changes for all DiT blocks are stored for relative step 'k'
        sb_holder.activated_steps.append(sb_holder.step_count)
        for block in sb_holder.all_blocks:
            cache: BlockCache = block.__block_cache
            cache.mark_recompute()

    # all block operations are performed in the forward pass of the model
    to_return = executor(*args, **kwargs)

    # when 1: Select DiT blocks(4)
    if sb_holder.step_modulus == 1:
        predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
        predict_ratio = sb_holder.predict_ratios[predict_index]
        logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
        reuse_ratio = 1.0 - predict_ratio
        for block_type, blocks in sb_holder.blocks_per_type.items():
            sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity)
            threshold_index = int(len(sorted_blocks) * reuse_ratio)
            # blocks with lower similarity are marked for recomputation
            for block in sorted_blocks[:threshold_index]:
                cache: BlockCache = block.__block_cache
                cache.mark_recompute()
            # blocks with higher similarity are marked for prediction
            for block in sorted_blocks[threshold_index:]:
                cache: BlockCache = block.__block_cache
                cache.mark_predict()
            logging.info(f"Sortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for recomputation and {len(sorted_blocks[threshold_index:])} blocks for prediction")

    if sb_holder.initial_step:
        sb_holder.initial_step = False

    return to_return
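
# Interval lifecycle (editor's sketch of the control flow above): within the
# Sortblock window, step_modulus cycles through 0..policy_refresh_interval-1.
# At modulus 0 every block is recomputed and its residual cached; at modulus 1
# every block is predicted, predictions are scored against the cached
# residuals, and the per-type sort above decides which blocks recompute for
# the remaining moduli and which keep reusing predictions.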


def block_forward_factory(func, block):
    def block_forward_wrapper(*args, **kwargs):
        transformer_options: dict = kwargs.get("transformer_options")
        sb_holder: SortblockHolder = transformer_options["sortblock"]
        cache: BlockCache = block.__block_cache
        # make sure block type, index, and stream count are properly set for this block
        if sb_holder.initial_step:
            sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
            cache.block_index = transformer_options['block'][1]
            cache.stream_count = transformer_options['block'][2]
        # do sortblock stuff
        if cache.recompute and sb_holder.step_modulus != 1:
            # clone relevant inputs
            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=True)
            # get block outputs
            # NOTE: output_raw is expected to have cache.stream_count elements if count is greater than 1 (double block, etc.)
            output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
            # perform derivative approximation
            cache.derivative_approximation(sb_holder, output_raw, orig_inputs)
            # if step_modulus is 0, input-output changes for DiT block are stored
            if sb_holder.step_modulus == 0:
                cache.cache_previous_residual(output_raw, orig_inputs)
        else:
            # if not recomputing, predict features for the current timestep
            orig_inputs = cache.get_orig_inputs(args, kwargs, clone=False)
            # when 1: Linear Prediction(2)
            # if step_modulus is 1, store block residuals as 'current' after applying taylor_formula
            if sb_holder.step_modulus == 1:
                cache.cache_current_residual(sb_holder)
            # based on features computed in the last timestep, all features for the current timestep are predicted using Eq. 4;
            # input-output changes for all DiT blocks are stored for relative step 'k+1'
            output_raw = cache.apply_linear_prediction(sb_holder, orig_inputs)

        # when 1: Identify Changes(3)
        if sb_holder.step_modulus == 1:
            # compare the predicted residual against the cached one to measure how much this block changed
            cache.calculate_cosine_similarity()

        return output_raw
    return block_forward_wrapper


def perform_sortblock(blocks: list):
    # placeholder; not implemented yet
    ...


def prepare_block(block, sb_holder: SortblockHolder, stream_count: int = 1):
    sb_holder.add_to_all_blocks(block)
    block.__original_forward = block.forward
    block.forward = block_forward_factory(block.__original_forward, block)
    block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)


def clean_block(block):
    block.forward = block.__original_forward
    del block.__original_forward
    del block.__block_cache
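
# A minimal sketch of the monkey-patch round trip above, assuming `block` is
# any module with a .forward and `holder` is a prepared SortblockHolder:
#
#   prepare_block(block, holder)   # forward now routes through block_forward_wrapper
#   ...                            # sampling runs; BlockCache decides recompute vs. predict
#   clean_block(block)             # original forward restored, cache attributes removed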


def subsample(x: torch.Tensor, factor: int, clone: bool = True) -> torch.Tensor:
    if factor > 1:
        to_return = x[..., ::factor, ::factor]
        if clone:
            return to_return.clone()
        return to_return
    if clone:
        return x.clone()
    return x
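
# Shape example (illustrative): subsampling strides over the last two dims, so
# a (1, 4096, 3072) token tensor with factor=8 keeps every 8th token and every
# 8th channel, giving (1, 512, 384) - a 64x smaller proxy for the similarity math.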


class BlockCache:
    def __init__(self, subsample_factor: int = 8, verbose: bool = False):
        self.subsample_factor = subsample_factor
        self.verbose = verbose
        self.stream_count = 1
        self.recompute = False
        self.block_index = 0
        # cached values
        self.previous_residual_subsampled: torch.Tensor = None
        self.current_residual_subsampled: torch.Tensor = None
        self.cosine_similarity: float = None
        self.previous_taylor_factors: dict[int, torch.Tensor] = {}
        self.current_taylor_factors: dict[int, torch.Tensor] = {}

    def mark_recompute(self):
        self.recompute = True

    def mark_predict(self):
        self.recompute = False

    def cache_previous_residual(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
        if isinstance(output_raw, tuple):
            output_raw = output_raw[0]
        if isinstance(orig_inputs, tuple):
            orig_inputs = orig_inputs[0]
        del self.previous_residual_subsampled
        self.previous_residual_subsampled = subsample(output_raw - orig_inputs, self.subsample_factor, clone=True)

    def cache_current_residual(self, sb_holder: SortblockHolder):
        del self.current_residual_subsampled
        self.current_residual_subsampled = subsample(self.use_taylor_formula(sb_holder)[0], self.subsample_factor, clone=True)

    def get_orig_inputs(self, d_args: tuple, d_kwargs: dict, clone: bool = True) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if self.stream_count == 1:
            if clone:
                return d_args[0].clone()
            return d_args[0]
        keys = list(d_kwargs.keys())[:self.stream_count]
        orig_inputs = []
        for key in keys:
            if clone:
                orig_inputs.append(d_kwargs[key].clone())
            else:
                orig_inputs.append(d_kwargs[key])
        return tuple(orig_inputs)
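
    # Stream convention (inferred from the wrappers above): single-stream blocks
    # receive their hidden states as the first positional arg; multi-stream blocks
    # (e.g. double blocks with separate image/text streams) are assumed to pass
    # each stream as one of the first `stream_count` keyword arguments, in order.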

    def apply_linear_prediction(self, sb_holder: SortblockHolder, orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        drop_tuple = False
        if not isinstance(orig_inputs, tuple):
            orig_inputs = (orig_inputs,)
            drop_tuple = True
        taylor_results = self.use_taylor_formula(sb_holder)
        for output, taylor_result in zip(orig_inputs, taylor_results):
            # predicted output = current input + Taylor-extrapolated residual (added in-place)
            output += taylor_result
        if drop_tuple:
            orig_inputs = orig_inputs[0]
        return orig_inputs

    def calculate_cosine_similarity(self) -> None:
        self.cosine_similarity = torch.nn.functional.cosine_similarity(self.previous_residual_subsampled, self.current_residual_subsampled, dim=-1).mean().item()

    def derivative_approximation(self, sb_holder: SortblockHolder, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
        activation_distance = sb_holder.activated_steps[-1] - sb_holder.activated_steps[-2]
        # make tuples if not already tuples, so this works with both single and double blocks
        if not isinstance(output_raw, tuple):
            output_raw = (output_raw,)
        if not isinstance(orig_inputs, tuple):
            orig_inputs = (orig_inputs,)

        for i, (output, x) in enumerate(zip(output_raw, orig_inputs)):
            feature = output.clone() - x
            has_previous_taylor_factor = self.previous_taylor_factors.get(i, None) is not None
            # NOTE: not sure why - 2, but that's what's in the original implementation. Maybe consider changing values?
            if has_previous_taylor_factor and sb_holder.step_count > (sb_holder.first_enhance - 2):
                self.current_taylor_factors[i] = (
                    feature - self.previous_taylor_factors[i]
                ) / activation_distance

            self.previous_taylor_factors[i] = feature
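
    # In effect this is a finite difference: with residuals f_k cached at the
    # last two recompute steps a distance d apart, current_taylor_factors[i]
    # approximates df/dstep as (f_k - f_{k-d}) / d, which use_taylor_formula
    # then applies as the first-order term of the prediction.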

    def use_taylor_formula(self, sb_holder: SortblockHolder) -> tuple[torch.Tensor, ...]:
        step_distance = sb_holder.step_count - sb_holder.activated_steps[-1]

        output_predicted = []

        for key in self.previous_taylor_factors.keys():
            previous_tf = self.previous_taylor_factors[key]
            # zeroth-order term: the cached residual itself
            predicted = taylor_formula(previous_tf, 0, step_distance)
            # first-order term; the derivative estimate may not exist yet on the earliest steps
            current_tf = self.current_taylor_factors.get(key, None)
            if current_tf is not None:
                predicted += taylor_formula(current_tf, 1, step_distance)
            output_predicted.append(predicted)

        return tuple(output_predicted)

    def reset(self):
        self.recompute = False
        self.current_residual_subsampled = None
        self.previous_residual_subsampled = None
        self.cosine_similarity = None
        self.previous_taylor_factors = {}
        self.current_taylor_factors = {}


def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
    # i-th term of a Taylor expansion: (1 / i!) * f^(i) * d^i
    return (
        (1 / math.factorial(i))
        * taylor_factor
        * (step_distance ** i)
    )
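
# Worked example (hypothetical values): with a cached residual f and derivative
# estimate f_prime for a block, predicting 2 steps past the last recompute gives
#   taylor_formula(f, 0, 2) + taylor_formula(f_prime, 1, 2) == f + 2 * f_prime,
# i.e. a straight-line extrapolation of the block's residual.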


class SortblockHolder:
    def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int,
                 start_percent: float, end_percent: float, subsample_factor: int = 8, verbose: bool = False):
        self.start_predict_ratio = start_predict_ratio
        self.end_predict_ratio = end_predict_ratio
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.verbose = verbose

        # NOTE: number represents steps
        self.policy_refresh_interval = policy_refresh_interval
        self.active_policy_refresh_interval = 1
        self.first_enhance = 3  # NOTE: this value is 2 higher than the one actually used in code (subtracted by 2 in derivative_approximation)
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        self.curr_t = 0.0
        # control values
        self.initial_step = True
        self.step_count = 0
        self.activated_steps: list[int] = [0]
        self.step_modulus = 0
        self.do_sortblock = False
        self.active_steps = 0
        self.predict_ratios = []
        self.predict_start_index = 0

        # cache values
        self.all_blocks = []
        self.blocks_per_type = {}

    def add_to_all_blocks(self, block):
        self.all_blocks.append(block)

    def add_to_blocks_per_type(self, block, block_type: str):
        self.blocks_per_type.setdefault(block_type, []).append(block)

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
        return (timestep <= self.start_t).item() and (timestep > self.end_t).item()

    def update_should_do_sortblock(self, timestep: torch.Tensor) -> bool:
        self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item()
        self.curr_t = timestep
        if self.do_sortblock:
            self.active_policy_refresh_interval = self.policy_refresh_interval
        else:
            self.active_policy_refresh_interval = 1
        self.update_step_modulus()
        return self.do_sortblock

    def update_step_modulus(self):
        self.step_modulus = int(self.step_count % self.active_policy_refresh_interval)

    def should_do_sortblock(self) -> bool:
        return self.do_sortblock
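
    # Modulus example (illustrative): with policy_refresh_interval=5,
    # step_modulus is step_count % 5, so it cycles 0,1,2,3,4 as step_count
    # grows inside the sigma window; outside the window the active interval
    # drops to 1, forcing modulus 0 (full recomputation) on every step.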

    def reset(self):
        self.initial_step = True
        self.curr_t = 0.0
        logging.info(f"Sortblock: resetting {len(self.all_blocks)} blocks")
        for block in self.all_blocks:
            clean_block(block)
        self.all_blocks = []
        self.blocks_per_type = {}
        self.step_count = 0
        self.activated_steps = [0]
        self.step_modulus = 0
        self.active_steps = 0
        self.predict_ratios = []
        self.do_sortblock = False
        self.predict_start_index = 0
        return self

    def clone(self):
        return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
                               start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
                               verbose=self.verbose)


class SortblockNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="Sortblock",
            display_name="Sortblock",
            description="A homebrew version of EasyCache - an even 'easier' version of EasyCache to implement. Overall it works worse than EasyCache, but it does better in some rare cases and is universally compatible with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add Sortblock to."),
                io.Float.Input("predict_ratio", min=0.0, default=0.8, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
                io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with Sortblock."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        # TODO: check for specific flavors of supported models
        model = model.clone()
        model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio=predict_ratio, end_predict_ratio=predict_ratio, policy_refresh_interval=policy_refresh_interval,
                                                                                  start_percent=start_percent, end_percent=end_percent, subsample_factor=8, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
        return io.NodeOutput(model)
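
# Wiring sketch (hypothetical workflow): Load Checkpoint -> Sortblock -> KSampler.
# The node only clones the model and registers wrappers on model_options, so it
# should compose with other model patches that use the same wrapper mechanism.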


class SortblockScaledNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="SortblockScaled",
            display_name="SortblockScaled",
            description="A homebrew version of EasyCache - an even 'easier' version of EasyCache to implement. Overall it works worse than EasyCache, but it does better in some rare cases and is universally compatible with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add Sortblock to."),
                io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict at the start of the Sortblock window."),
                io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict at the end of the Sortblock window."),
                io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with Sortblock."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        # TODO: check for specific flavors of supported models
        model = model.clone()
        model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
        return io.NodeOutput(model)


class SortblockExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            SortblockNode,
            SortblockScaledNode,
        ]


def comfy_entrypoint():
    return SortblockExtension()