From fc247150fec502b1834390516b556a87003f1d79 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Fri, 22 Aug 2025 19:41:08 -0700
Subject: [PATCH] Implement EasyCache and Invent LazyCache (#9496)

* Attempting a universal implementation of EasyCache, starting with flux as a test; I screwed up the math a bit, but when I set it just right it works.

* Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object

* Use sigmas from transformer_options instead of timesteps to be compatible with a greater number of models, make end_percent work

* Make log statement when not skipping useful, preparing for per-cond caching

* Added DIFFUSION_MODEL wrapper around forward function for wan model

* Add subsampling for heuristic inputs

* Add subsampling to output_prev (output_prev_subsampled now)

* Properly consider conds in EasyCache logic

* Created SuperEasyCache to test what happens if caching and reuse are moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test

* Change max reuse_threshold to 3.0

* Mark EasyCache/SuperEasyCache as experimental (beta)

* Make Lumina2 compatible with EasyCache

* Add EasyCache support for Qwen Image

* Fix missing comma, curse you Cursor

* Add EasyCache support to AceStep

* Add EasyCache support to Chroma

* Added EasyCache support to Cosmos Predict t2i

* Make EasyCache not crash with Cosmos Predict ImageToVideo latents, but does not work well at all

* Add EasyCache support to hidream

* Added EasyCache support to hunyuan video

* Added EasyCache support to hunyuan3d

* Added EasyCache support to LTXV (not very good, but does not crash)

* Implemented EasyCache for aura_flow

* Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes

* Extra logging when verbose is true for EasyCache
---
 comfy/ldm/ace/model.py           |  24 +-
 comfy/ldm/aura/mmdit.py          |   8 +
 comfy/ldm/chroma/model.py        |   8 +
 comfy/ldm/cosmos/model.py        |  38 +++
 comfy/ldm/cosmos/predict2.py     |  17 +-
 comfy/ldm/flux/model.py          |   8 +
 comfy/ldm/hidream/model.py       |  19 +-
 comfy/ldm/hunyuan3d/model.py     |   8 +
 comfy/ldm/hunyuan_video/model.py |   8 +
 comfy/ldm/lightricks/model.py    |   8 +
 comfy/ldm/lumina/model.py        |  10 +-
 comfy/ldm/qwen_image/model.py    |  10 +-
 comfy/ldm/wan/model.py           |   8 +
 comfy/patcher_extension.py       |   1 +
 comfy/samplers.py                |   9 +-
 comfy_extras/nodes_easycache.py  | 459 +++++++++++++++++++++++++++++++
 nodes.py                         |   3 +-
 17 files changed, 639 insertions(+), 7 deletions(-)
 create mode 100644 comfy_extras/nodes_easycache.py

diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py
index 12c524701..41d85eeb5 100644
--- a/comfy/ldm/ace/model.py
+++ b/comfy/ldm/ace/model.py
@@ -19,6 +19,7 @@ import torch
 from torch import nn
 
 import comfy.model_management
+import comfy.patcher_extension
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from .attention import LinearTransformerBlock, t2i_modulate
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
         output = self.final_layer(hidden_states, embedded_timestep, output_length)
         return output
 
-    def forward(
+    def forward(self,
+        x,
+        timestep,
+        attention_mask=None,
+        context: Optional[torch.Tensor] = None,
+        text_attention_mask: Optional[torch.LongTensor] = None,
+        speaker_embeds: Optional[torch.FloatTensor] = None,
+        lyric_token_idx: Optional[torch.LongTensor] = None,
+        lyric_mask: Optional[torch.LongTensor] = None,
+        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+        controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        
lyrics_strength=1.0, + **kwargs + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states, + controlnet_scale, lyrics_strength, **kwargs) + + def _forward( self, x, timestep, diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 1258ae11f..d7f32b5e8 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.ops +import comfy.patcher_extension import comfy.ldm.common_dit def modulate(x, shift, scale): @@ -436,6 +437,13 @@ class MMDiT(nn.Module): return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) def forward(self, x, timestep, context, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, transformer_options={}, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) # patchify x, add PE b, c, h, w = x.shape diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 06021d4f2..5cff44dc8 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch from torch import Tensor, nn from einops import rearrange, repeat +import comfy.patcher_extension import comfy.ldm.common_dit from comfy.ldm.flux.layers import ( @@ -253,6 +254,13 @@ class Chroma(nn.Module): return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 4836e0b69..53698b758 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -27,6 +27,8 @@ from torchvision import transforms from enum import Enum import logging +import comfy.patcher_extension + from .blocks import ( FinalLayer, GeneralDITTransformerBlock, @@ -435,6 +437,42 @@ class GeneralDIT(nn.Module): latent_condition_sigma: Optional[torch.Tensor] = None, condition_video_augment_sigma: Optional[torch.Tensor] = None, **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, + timesteps, + context, + attention_mask, + fps, + image_size, + padding_mask, + scalar_feature, + data_type, + latent_condition, + latent_condition_sigma, + condition_video_augment_sigma, + **kwargs) + 
+ def _forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, ): """ Args: diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 316117f77..fcc83ba76 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -11,6 +11,7 @@ import math from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis from torchvision import transforms +import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention def apply_rotary_pos_emb( @@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module): ) return x_B_C_Tt_Hp_Wp - def forward( + def forward(self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, fps, padding_mask, **kwargs) + + def _forward( self, x: torch.Tensor, timesteps: torch.Tensor, diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index c4de82795..0a77fa097 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -6,6 +6,7 @@ import torch from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit +import comfy.patcher_extension from .layers import ( DoubleStreamBlock, @@ -214,6 +215,13 @@ class Flux(nn.Module): return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): bs, c, h_orig, w_orig = x.shape patch_size = self.patch_size diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index 0305747bf..ae49cf945 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.patcher_extension import comfy.ldm.common_dit @@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module): raise NotImplementedError return x, x_masks, img_sizes - def forward( + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + image_cond=None, + control = None, + 
transformer_options = {}, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options) + + def _forward( self, x: torch.Tensor, t: torch.Tensor, diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 4e18358f0..0fa5e78c1 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import ( SingleStreamBlock, timestep_embedding, ) +import comfy.patcher_extension class Hunyuan3Dv2(nn.Module): @@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module): self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): x = x.movedim(-1, -2) timestep = 1.0 - timestep txt = context diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fbd8d4196..da1011596 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -1,6 +1,7 @@ #Based on Flux code because of weird hunyuan video code license. import torch +import comfy.patcher_extension import comfy.ldm.flux.layers import comfy.ldm.modules.diffusionmodules.mmdit from comfy.ldm.modules.attention import optimized_attention @@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module): return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape img_ids = self.img_ids(x) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index ad9a7daea..aa2ea62b1 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,5 +1,6 @@ import torch from torch import nn +import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit from einops import rearrange @@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + 
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) + + def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) orig_shape = list(x.shape) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f8dc4d7db..e08ed817d 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +import comfy.patcher_extension def modulate(x, scale): @@ -590,8 +591,15 @@ class NextDiT(nn.Module): return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis - # def forward(self, x, t, cap_feats, cap_mask): def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + # def forward(self, x, t, cap_feats, cap_mask): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index af00ff119..57a458210 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit +import comfy.patcher_extension class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -355,7 +356,14 @@ class QwenImageTransformer2DModel(nn.Module): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape - def forward( + def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + + def _forward( self, x, timesteps, diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 0726b8e1b..1885d9730 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -11,6 +11,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.ldm.common_dit import comfy.model_management +import comfy.patcher_extension def sinusoidal_embedding_1d(dim, position): @@ -573,6 +574,13 @@ class WanModel(torch.nn.Module): return x def forward(self, x, timestep, context, clip_fea=None, 
time_dim_concat=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 965958f4c..46cc7b2a8 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -50,6 +50,7 @@ class WrappersMP: OUTER_SAMPLE = "outer_sample" PREPARE_SAMPLING = "prepare_sampling" SAMPLER_SAMPLE = "sampler_sample" + PREDICT_NOISE = "predict_noise" CALC_COND_BATCH = "calc_cond_batch" APPLY_MODEL = "apply_model" DIFFUSION_MODEL = "diffusion_model" diff --git a/comfy/samplers.py b/comfy/samplers.py index d5390d64e..ec7e0b350 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -953,7 +953,14 @@ class CFGGuider: self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): - return self.predict_noise(*args, **kwargs) + return self.outer_predict_noise(*args, **kwargs) + + def outer_predict_noise(self, x, timestep, model_options={}, seed=None): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self.predict_noise, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True) + ).execute(x, timestep, model_options, seed) def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py new file mode 100644 index 000000000..e2b2efcd9 --- /dev/null +++ b/comfy_extras/nodes_easycache.py @@ -0,0 +1,459 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +from comfy_api.latest import io, ComfyExtension +import comfy.patcher_extension +import logging +import torch +import comfy.model_patcher +if TYPE_CHECKING: + from uuid import UUID + + +def easycache_forward_wrapper(executor, *args, **kwargs): + # get values from args + x: torch.Tensor = args[0] + transformer_options: dict[str] = args[-1] + if not isinstance(transformer_options, dict): + transformer_options = kwargs.get("transformer_options") + if not transformer_options: + transformer_options = args[-2] + easycache: EasyCacheHolder = transformer_options["easycache"] + sigmas = transformer_options["sigmas"] + uuids = transformer_options["uuids"] + if sigmas is not None and easycache.is_past_end_timestep(sigmas): + return executor(*args, **kwargs) + # prepare next x_prev + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(sigmas) + if do_easycache: + # if first cond marked this step for skipping, skip it and use appropriate cached values + if easycache.skip_current_step: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. 
Present uuids: {uuids}") + return easycache.apply_cache_diff(x, uuids) + if easycache.initial_step: + easycache.first_cond_uuid = uuids[0] + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + easycache.initial_step = False + if has_first_cond_uuid: + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x, uuids) + else: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + + output: torch.Tensor = executor(*args, **kwargs) + if has_first_cond_uuid and easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev, uuids) + if has_first_cond_uuid: + easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) + easycache.output_prev_subsampled = easycache.subsample(output, uuids) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def lazycache_predict_noise_wrapper(executor, *args, **kwargs): + # get values from args + x: torch.Tensor = args[0] + timestep: float = args[1] + model_options: dict[str] = args[2] + easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] + if easycache.is_past_end_timestep(timestep): + return executor(*args, **kwargs) + # prepare next x_prev + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(timestep) + if do_easycache: + if easycache.has_x_prev_subsampled(): + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + 
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
+                easycache.cumulative_change_rate += approx_output_change_rate
+                if easycache.cumulative_change_rate < easycache.reuse_threshold:
+                    if easycache.verbose:
+                        logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
+                    # other conds should also skip this step, and instead use their cached values
+                    easycache.skip_current_step = True
+                    return easycache.apply_cache_diff(x)
+                else:
+                    if easycache.verbose:
+                        logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
+                    easycache.cumulative_change_rate = 0.0
+    output: torch.Tensor = executor(*args, **kwargs)
+    if easycache.has_output_prev_norm():
+        output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
+        if easycache.verbose:
+            output_change_rate = output_change / easycache.output_prev_norm
+            easycache.output_change_rates.append(output_change_rate.item())
+        if easycache.has_relative_transformation_rate():
+            approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
+            easycache.approx_output_change_rates.append(approx_output_change_rate.item())
+            if easycache.verbose:
+                logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
+        if input_change is not None:
+            easycache.relative_transformation_rate = output_change / input_change
+        if easycache.verbose:
+            logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
+    # TODO: allow cache_diff to be offloaded
+    easycache.update_cache_diff(output, next_x_prev)
+    easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
+    easycache.output_prev_subsampled = easycache.subsample(output)
+    easycache.output_prev_norm = output.flatten().abs().mean()
+    if easycache.verbose:
+        logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
+    return output
+
+def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
+    model_options = args[-1]
+    easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
+    easycache.skip_current_step = False
+    # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
+    return executor(*args, **kwargs)
+
+def easycache_sample_wrapper(executor, *args, **kwargs):
+    """
+    This OUTER_SAMPLE wrapper makes sure easycache is prepped for the current run, and that all memory usage is cleared at the end.
+    
+ """ + try: + guider = executor.class_obj + orig_model_options = guider.model_options + guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options) + # clone and prepare timesteps + guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling) + easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache'] + logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}") + return executor(*args, **kwargs) + finally: + easycache = guider.model_options['transformer_options']['easycache'] + output_change_rates = easycache.output_change_rates + approx_output_change_rates = easycache.approx_output_change_rates + if easycache.verbose: + logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") + logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") + total_steps = len(args[3])-1 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).") + easycache.reset() + guider.model_options = orig_model_options + + +class EasyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + self.name = "EasyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + # cache values + self.first_cond_uuid = None + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + # how to deal with mismatched dims + self.allow_mismatch = True + self.cut_from_start = True + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor: + batch_offset = x.shape[0] // len(uuids) + uuid_idx = uuids.index(self.first_cond_uuid) + if 
self.subsample_factor > 1: + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...] + if clone: + return to_return.clone() + return to_return + + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): + if self.first_cond_uuid in uuids: + self.total_steps_skipped += 1 + batch_offset = x.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_this_dim = True + for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape): + if skip_this_dim: + skip_this_dim = False + continue + if dim_u != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_u, None)) + else: + slicing.append(slice(None, dim_u)) + else: + slicing.append(slice(None)) + slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing + x = x[slicing] + x += self.uuid_cache_diffs[uuid].to(x.device) + return x + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): + # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if output.shape[1:] != x.shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_dim = True + for dim_o, dim_x in zip(output.shape, x.shape): + if not skip_dim and dim_o != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_o, None)) + else: + slicing.append(slice(None, dim_o)) + else: + slicing.append(slice(None)) + skip_dim = False + x = x[slicing] + diff = output - x + batch_offset = diff.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] 
+ + def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: + return self.first_cond_uuid in uuids + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + self.output_change_rates = [] + self.first_cond_uuid = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + del self.uuid_cache_diffs + self.uuid_cache_diffs = {} + self.total_steps_skipped = 0 + return self + + def clone(self): + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + + +class EasyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EasyCache", + display_name="EasyCache", + description="Native EasyCache implementation.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add EasyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with EasyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) + return io.NodeOutput(model) + + +class LazyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + self.name = "LazyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + # cache values + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.cache_diff: torch.Tensor = None + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + + def has_cache_diff(self) -> bool: + return self.cache_diff is not None + + def is_past_end_timestep(self, timestep: float) -> bool: + 
return not (timestep[0] > self.end_t).item()
+
+    def should_do_easycache(self, timestep: float) -> bool:
+        return (timestep[0] <= self.start_t).item()
+
+    def has_x_prev_subsampled(self) -> bool:
+        return self.x_prev_subsampled is not None
+
+    def has_output_prev_subsampled(self) -> bool:
+        return self.output_prev_subsampled is not None
+
+    def has_output_prev_norm(self) -> bool:
+        return self.output_prev_norm is not None
+
+    def has_relative_transformation_rate(self) -> bool:
+        return self.relative_transformation_rate is not None
+
+    def prepare_timesteps(self, model_sampling):
+        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
+        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
+        return self
+
+    def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
+        if self.subsample_factor > 1:
+            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
+            if clone:
+                return to_return.clone()
+            return to_return
+        if clone:
+            return x.clone()
+        return x
+
+    def apply_cache_diff(self, x: torch.Tensor):
+        self.total_steps_skipped += 1
+        return x + self.cache_diff.to(x.device)
+
+    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
+        self.cache_diff = output - x
+
+    def reset(self):
+        self.relative_transformation_rate = 0.0
+        self.cumulative_change_rate = 0.0
+        self.initial_step = True
+        self.output_change_rates = []
+        self.approx_output_change_rates = []
+        del self.cache_diff
+        self.cache_diff = None
+        self.total_steps_skipped = 0
+        return self
+
+    def clone(self):
+        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+
+class LazyCacheNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="LazyCache",
+            display_name="LazyCache",
+            description="A homebrew variant of EasyCache that is even 'easier' to implement. 
 Overall it works worse than EasyCache, but it is better in some rare cases and is universally compatible with everything in ComfyUI.",
+            category="advanced/debug/model",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model", tooltip="The model to add LazyCache to."),
+                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
+                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
+                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
+                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
+            ],
+            outputs=[
+                io.Model.Output(tooltip="The model with LazyCache."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
+        model = model.clone()
+        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
+        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
+        return io.NodeOutput(model)
+
+
+class EasyCacheExtension(ComfyExtension):
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            EasyCacheNode,
+            LazyCacheNode,
+        ]
+
+def comfy_entrypoint():
+    return EasyCacheExtension()
diff --git a/nodes.py b/nodes.py
index 9681750d3..723ce3384 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2322,7 +2322,8 @@ async def init_builtin_extra_nodes():
         "nodes_tcfg.py",
         "nodes_context_windows.py",
         "nodes_qwen.py",
-        "nodes_model_patch.py"
+        "nodes_model_patch.py",
+        "nodes_easycache.py",
     ]
 
     import_failed = []
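
For readers skimming the diff: stripped of ComfyUI's wrapper plumbing, the accumulate-and-skip heuristic implemented in easycache_forward_wrapper reduces to the sketch below. This is a minimal, single-cond illustration with hypothetical names (EasyCacheSketch, model_step); the per-cond UUID bookkeeping, start/end_percent gating, and dimension-mismatch handling in nodes_easycache.py are deliberately omitted.

import torch

class EasyCacheSketch:
    def __init__(self, reuse_threshold: float = 0.2, subsample_factor: int = 8):
        self.reuse_threshold = reuse_threshold
        self.subsample_factor = subsample_factor
        self.relative_rate = None      # |output change| / |input change| from last full step
        self.cumulative_change = 0.0   # accumulated approximate output change rate
        self.x_prev_sub = None         # subsampled input from last full step
        self.out_prev_sub = None       # subsampled output from last full step
        self.out_prev_norm = None      # mean |output| from last full step
        self.cache_diff = None         # cached (output - x), replayed on skipped steps

    def _subsample(self, t: torch.Tensor) -> torch.Tensor:
        # cheap proxy for the full tensor: stride over the last two (spatial) dims
        return t[..., ::self.subsample_factor, ::self.subsample_factor]

    def step(self, model_step, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        sub = self._subsample(x)
        input_change = None
        if self.x_prev_sub is not None:
            input_change = (sub - self.x_prev_sub).flatten().abs().mean()
            if self.relative_rate is not None and self.out_prev_norm is not None:
                # estimate how much the output would change without running the model
                self.cumulative_change += (self.relative_rate * input_change) / self.out_prev_norm
                if self.cumulative_change < self.reuse_threshold:
                    # skip the forward pass, replay the cached output-input diff
                    return x + self.cache_diff
                self.cumulative_change = 0.0
        out = model_step(x, sigma)  # full forward pass
        if self.out_prev_norm is not None and input_change is not None:
            out_change = (self._subsample(out) - self.out_prev_sub).flatten().abs().mean()
            self.relative_rate = out_change / input_change
        self.cache_diff = out - x
        self.x_prev_sub = sub.clone()
        self.out_prev_sub = self._subsample(out).clone()
        self.out_prev_norm = out.flatten().abs().mean()
        return out

A sampler loop would call cache.step(model_forward, x, sigma) wherever it would otherwise invoke the raw diffusion forward; while the accumulated estimate stays under reuse_threshold, the model call is skipped and the cached diff is replayed instead.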
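
In workflow terms the nodes are drop-in model patches, but the same wiring can be reproduced from a script. The snippet below mirrors what EasyCacheNode.execute() above does; it assumes `model` is an already-loaded comfy.model_patcher.ModelPatcher and is a sketch of the registration pattern, not a separate API.

import comfy.patcher_extension
from comfy_extras.nodes_easycache import (
    EasyCacheHolder,
    easycache_sample_wrapper,
    easycache_calc_cond_batch_wrapper,
    easycache_forward_wrapper,
)

patched = model.clone()
patched.model_options["transformer_options"]["easycache"] = EasyCacheHolder(
    reuse_threshold=0.2,   # higher -> more skipped steps, lower fidelity
    start_percent=0.15,    # leave the early, high-noise steps uncached
    end_percent=0.95,      # run the final steps at full quality
    subsample_factor=8,
    offload_cache_diff=False,
    verbose=False,
)
# OUTER_SAMPLE preps/clears state per run, CALC_COND_BATCH resets the per-step
# skip flag, DIFFUSION_MODEL applies the actual caching heuristic.
patched.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
patched.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
patched.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)

LazyCacheNode differs only in the holder class and in registering lazycache_predict_noise_wrapper on the PREDICT_NOISE hook (post-CFG, one cache for the combined prediction) instead of wrapping the diffusion model per-cond.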