Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-10 11:35:40 +00:00)
Implement EasyCache and Invent LazyCache (#9496)
* Attempting a universal implementation of EasyCache, starting with flux as test; I screwed up the math a bit, but when I set it just right it works.
* Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object
* Use sigmas from transformer_options instead of timesteps to be compatible with a greater number of models, make end_percent work
* Make log statement when not skipping useful, preparing for per-cond caching
* Added DIFFUSION_MODEL wrapper around forward function for wan model
* Add subsampling for heuristic inputs
* Add subsampling to output_prev (output_prev_subsampled now)
* Properly consider conds in EasyCache logic
* Created SuperEasyCache to test what happens if caching and reuse is moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test
* Change max reuse_threshold to 3.0
* Mark EasyCache/SuperEasyCache as experimental (beta)
* Make Lumina2 compatible with EasyCache
* Add EasyCache support for Qwen Image
* Fix missing comma, curse you Cursor
* Add EasyCache support to AceStep
* Add EasyCache support to Chroma
* Added EasyCache support to Cosmos Predict t2i
* Make EasyCache not crash with Cosmos Predict ImageToVideo latents, but it does not work well at all
* Add EasyCache support to hidream
* Added EasyCache support to hunyuan video
* Added EasyCache support to hunyuan3d
* Added EasyCache support to LTXV (not very good, but does not crash)
* Implemented EasyCache for aura_flow
* Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes
* Extra logging when verbose is true for EasyCache
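For orientation, here is a minimal standalone sketch (not part of the commit; the shipped logic lives in comfy_extras/nodes_easycache.py below) of the reuse rule the wrappers implement: EasyCache measures how much the model input changed since the last fully computed step, scales that by a running input-to-output transformation rate to approximate how much the output would change, accumulates the estimate, and skips the model call (reusing the cached output-minus-input diff) while the accumulated estimate stays under reuse_threshold.

import torch

def should_skip_step(x: torch.Tensor, x_prev, output_prev_norm, rel_rate, cumulative: float, threshold: float):
    # Returns (skip, new_cumulative). rel_rate is output_change / input_change from the last computed step.
    if x_prev is None or output_prev_norm is None or rel_rate is None:
        return False, 0.0  # not enough history yet: run the model
    input_change = (x - x_prev).flatten().abs().mean()
    approx_output_change_rate = (rel_rate * input_change) / output_prev_norm
    cumulative = cumulative + approx_output_change_rate.item()
    if cumulative < threshold:
        return True, cumulative   # reuse: output is approximated as x + cached (prev_output - prev_x)
    return False, 0.0             # too much drift accumulated: run the model, reset the accumulator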
@@ -19,6 +19,7 @@ import torch
 from torch import nn

 import comfy.model_management
+import comfy.patcher_extension

 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from .attention import LinearTransformerBlock, t2i_modulate
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
         output = self.final_layer(hidden_states, embedded_timestep, output_length)
         return output

-    def forward(
+    def forward(self,
+        x,
+        timestep,
+        attention_mask=None,
+        context: Optional[torch.Tensor] = None,
+        text_attention_mask: Optional[torch.LongTensor] = None,
+        speaker_embeds: Optional[torch.FloatTensor] = None,
+        lyric_token_idx: Optional[torch.LongTensor] = None,
+        lyric_mask: Optional[torch.LongTensor] = None,
+        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+        controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        lyrics_strength=1.0,
+        **kwargs
+    ):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+        ).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
+                  controlnet_scale, lyrics_strength, **kwargs)
+
+    def _forward(
         self,
         x,
         timestep,
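The same two-level pattern repeats for every model below: the public forward() now only builds a WrapperExecutor and delegates to _forward(), so extensions can hook the whole diffusion-model call. As a sketch (not part of the commit; the wrapper name is a placeholder), a DIFFUSION_MODEL wrapper registered through a ModelPatcher looks like this:

import comfy.patcher_extension

def my_diffusion_model_wrapper(executor, *args, **kwargs):
    # args/kwargs are exactly what the model's forward() received;
    # executor(...) runs the next wrapper in the chain, or _forward() itself.
    output = executor(*args, **kwargs)
    return output

# Registration, mirroring what the EasyCache node does on a cloned model:
# model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "my_wrapper", my_diffusion_model_wrapper)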
@@ -9,6 +9,7 @@ import torch.nn.functional as F

 from comfy.ldm.modules.attention import optimized_attention
 import comfy.ops
+import comfy.patcher_extension
 import comfy.ldm.common_dit

 def modulate(x, shift, scale):
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
         return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

     def forward(self, x, timestep, context, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
         # patchify x, add PE
         b, c, h, w = x.shape
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
+import comfy.patcher_extension
 import comfy.ldm.common_dit

 from comfy.ldm.flux.layers import (
@@ -253,6 +254,13 @@ class Chroma(nn.Module):
         return img

     def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

@@ -27,6 +27,8 @@ from torchvision import transforms
 from enum import Enum
 import logging

+import comfy.patcher_extension
+
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
         latent_condition_sigma: Optional[torch.Tensor] = None,
         condition_video_augment_sigma: Optional[torch.Tensor] = None,
         **kwargs,
+    ):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+        ).execute(x,
+                  timesteps,
+                  context,
+                  attention_mask,
+                  fps,
+                  image_size,
+                  padding_mask,
+                  scalar_feature,
+                  data_type,
+                  latent_condition,
+                  latent_condition_sigma,
+                  condition_video_augment_sigma,
+                  **kwargs)
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        # crossattn_emb: torch.Tensor,
+        # crossattn_mask: Optional[torch.Tensor] = None,
+        fps: Optional[torch.Tensor] = None,
+        image_size: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        scalar_feature: Optional[torch.Tensor] = None,
+        data_type: Optional[DataType] = DataType.VIDEO,
+        latent_condition: Optional[torch.Tensor] = None,
+        latent_condition_sigma: Optional[torch.Tensor] = None,
+        condition_video_augment_sigma: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         """
         Args:
@@ -11,6 +11,7 @@ import math
 from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
 from torchvision import transforms

+import comfy.patcher_extension
 from comfy.ldm.modules.attention import optimized_attention

 def apply_rotary_pos_emb(
@@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
         )
         return x_B_C_Tt_Hp_Wp

-    def forward(
+    def forward(self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: torch.Tensor,
+        fps: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+        ).execute(x, timesteps, context, fps, padding_mask, **kwargs)
+
+    def _forward(
         self,
         x: torch.Tensor,
         timesteps: torch.Tensor,
@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
+import comfy.patcher_extension

 from .layers import (
     DoubleStreamBlock,
@@ -214,6 +215,13 @@ class Flux(nn.Module):
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
         bs, c, h_orig, w_orig = x.shape
         patch_size = self.patch_size

@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer

 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
+import comfy.patcher_extension
 import comfy.ldm.common_dit


@@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
             raise NotImplementedError
         return x, x_masks, img_sizes

-    def forward(
+    def forward(self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        encoder_hidden_states_llama3=None,
+        image_cond=None,
+        control = None,
+        transformer_options = {},
+    ):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
+
+    def _forward(
         self,
         x: torch.Tensor,
         t: torch.Tensor,
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
     SingleStreamBlock,
     timestep_embedding,
 )
+import comfy.patcher_extension


 class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
         self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)

     def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, guidance, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
         x = x.movedim(-1, -2)
         timestep = 1.0 - timestep
         txt = context
@@ -1,6 +1,7 @@
 #Based on Flux code because of weird hunyuan video code license.

 import torch
+import comfy.patcher_extension
 import comfy.ldm.flux.layers
 import comfy.ldm.modules.diffusionmodules.mmdit
 from comfy.ldm.modules.attention import optimized_attention
@@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module):
         return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

     def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         img_ids = self.img_ids(x)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import comfy.patcher_extension
 import comfy.ldm.modules.attention
 import comfy.ldm.common_dit
 from einops import rearrange
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
         self.patchifier = SymmetricPatchifier(1)

     def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
+
+    def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})

         orig_shape = list(x.shape)
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+import comfy.patcher_extension


 def modulate(x, scale):
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):

         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

-    # def forward(self, x, t, cap_feats, cap_mask):
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+        ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
+
+    # def forward(self, x, t, cap_feats, cap_mask):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 import comfy.ldm.common_dit
+import comfy.patcher_extension

 class GELU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -355,7 +356,14 @@ class QwenImageTransformer2DModel(nn.Module):
         img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
         return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape

-    def forward(
+    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+
+    def _forward(
         self,
         x,
         timesteps,
@@ -11,6 +11,7 @@ from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
 import comfy.model_management
+import comfy.patcher_extension


 def sinusoidal_embedding_1d(dim, position):
@@ -573,6 +574,13 @@ class WanModel(torch.nn.Module):
         return x

     def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

@@ -50,6 +50,7 @@ class WrappersMP:
     OUTER_SAMPLE = "outer_sample"
     PREPARE_SAMPLING = "prepare_sampling"
     SAMPLER_SAMPLE = "sampler_sample"
+    PREDICT_NOISE = "predict_noise"
     CALC_COND_BATCH = "calc_cond_batch"
     APPLY_MODEL = "apply_model"
     DIFFUSION_MODEL = "diffusion_model"
@@ -953,7 +953,14 @@ class CFGGuider:
             self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])

     def __call__(self, *args, **kwargs):
-        return self.predict_noise(*args, **kwargs)
+        return self.outer_predict_noise(*args, **kwargs)
+
+    def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self.predict_noise,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
+        ).execute(x, timestep, model_options, seed)

     def predict_noise(self, x, timestep, model_options={}, seed=None):
         return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
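PREDICT_NOISE wraps one level further out than DIFFUSION_MODEL: outer_predict_noise() is called once per sampling step for all conds together, which is what lets LazyCache cache and reuse outside the per-cond scope. A sketch (not part of the commit; the wrapper name is a placeholder) of a wrapper at this level, assuming the positional layout used by execute() above:

def my_predict_noise_wrapper(executor, *args, **kwargs):
    x, timestep = args[0], args[1]  # followed by model_options and seed, per execute() above
    # one call per step, covering positive and negative conds at once
    return executor(*args, **kwargs)

# model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "my_wrapper", my_predict_noise_wrapper)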
comfy_extras/nodes_easycache.py (new file, 459 lines)
@@ -0,0 +1,459 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import comfy.model_patcher
if TYPE_CHECKING:
    from uuid import UUID


def easycache_forward_wrapper(executor, *args, **kwargs):
    # get values from args
    x: torch.Tensor = args[0]
    transformer_options: dict[str] = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    easycache: EasyCacheHolder = transformer_options["easycache"]
    sigmas = transformer_options["sigmas"]
    uuids = transformer_options["uuids"]
    if sigmas is not None and easycache.is_past_end_timestep(sigmas):
        return executor(*args, **kwargs)
    # prepare next x_prev
    has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(sigmas)
    if do_easycache:
        # if first cond marked this step for skipping, skip it and use appropriate cached values
        if easycache.skip_current_step:
            if easycache.verbose:
                logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
            return easycache.apply_cache_diff(x, uuids)
        if easycache.initial_step:
            easycache.first_cond_uuid = uuids[0]
            has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
            easycache.initial_step = False
        if has_first_cond_uuid:
            if easycache.has_x_prev_subsampled():
                input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
                if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                    approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                    easycache.cumulative_change_rate += approx_output_change_rate
                    if easycache.cumulative_change_rate < easycache.reuse_threshold:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        # other conds should also skip this step, and instead use their cached values
                        easycache.skip_current_step = True
                        return easycache.apply_cache_diff(x, uuids)
                    else:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        easycache.cumulative_change_rate = 0.0

    output: torch.Tensor = executor(*args, **kwargs)
    if has_first_cond_uuid and easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                if easycache.verbose:
                    logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev, uuids)
    if has_first_cond_uuid:
        easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
        easycache.output_prev_subsampled = easycache.subsample(output, uuids)
        easycache.output_prev_norm = output.flatten().abs().mean()
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
    # get values from args
    x: torch.Tensor = args[0]
    timestep: float = args[1]
    model_options: dict[str] = args[2]
    easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
    if easycache.is_past_end_timestep(timestep):
        return executor(*args, **kwargs)
    # prepare next x_prev
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(timestep)
    if do_easycache:
        if easycache.has_x_prev_subsampled():
            if easycache.has_x_prev_subsampled():
                input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
                if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                    approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                    easycache.cumulative_change_rate += approx_output_change_rate
                    if easycache.cumulative_change_rate < easycache.reuse_threshold:
                        if easycache.verbose:
                            logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        # other conds should also skip this step, and instead use their cached values
                        easycache.skip_current_step = True
                        return easycache.apply_cache_diff(x)
                    else:
                        if easycache.verbose:
                            logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        easycache.cumulative_change_rate = 0.0
    output: torch.Tensor = executor(*args, **kwargs)
    if easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                if easycache.verbose:
                    logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev)
    easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
    easycache.output_prev_subsampled = easycache.subsample(output)
    easycache.output_prev_norm = output.flatten().abs().mean()
    if easycache.verbose:
        logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
    model_options = args[-1]
    easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
    easycache.skip_current_step = False
    # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
    return executor(*args, **kwargs)


def easycache_sample_wrapper(executor, *args, **kwargs):
    """
    This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
    """
    try:
        guider = executor.class_obj
        orig_model_options = guider.model_options
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # clone and prepare timesteps
        guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
        easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
        logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
        return executor(*args, **kwargs)
    finally:
        easycache = guider.model_options['transformer_options']['easycache']
        output_change_rates = easycache.output_change_rates
        approx_output_change_rates = easycache.approx_output_change_rates
        if easycache.verbose:
            logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
            logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
        total_steps = len(args[3])-1
        logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
        easycache.reset()
        guider.model_options = orig_model_options


class EasyCacheHolder:
    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
        self.name = "EasyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        # cache values
        self.first_cond_uuid = None
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0
        # how to deal with mismatched dims
        self.allow_mismatch = True
        self.cut_from_start = True

    def is_past_end_timestep(self, timestep: float) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: float) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
        batch_offset = x.shape[0] // len(uuids)
        uuid_idx = uuids.index(self.first_cond_uuid)
        if self.subsample_factor > 1:
            to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
        if clone:
            return to_return.clone()
        return to_return

    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
        if self.first_cond_uuid in uuids:
            self.total_steps_skipped += 1
        batch_offset = x.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
                if not self.allow_mismatch:
                    raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                slicing = []
                skip_this_dim = True
                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
                    if skip_this_dim:
                        skip_this_dim = False
                        continue
                    if dim_u != dim_x:
                        if self.cut_from_start:
                            slicing.append(slice(dim_x-dim_u, None))
                        else:
                            slicing.append(slice(None, dim_u))
                    else:
                        slicing.append(slice(None))
                slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
                x = x[slicing]
            x += self.uuid_cache_diffs[uuid].to(x.device)
        return x

    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
        # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
        if output.shape[1:] != x.shape[1:]:
            if not self.allow_mismatch:
                raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
            slicing = []
            skip_dim = True
            for dim_o, dim_x in zip(output.shape, x.shape):
                if not skip_dim and dim_o != dim_x:
                    if self.cut_from_start:
                        slicing.append(slice(dim_x-dim_o, None))
                    else:
                        slicing.append(slice(None, dim_o))
                else:
                    slicing.append(slice(None))
                skip_dim = False
            x = x[slicing]
        diff = output - x
        batch_offset = diff.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]

    def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
        return self.first_cond_uuid in uuids

    def reset(self):
        self.relative_transformation_rate = 0.0
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        self.output_change_rates = []
        self.first_cond_uuid = None
        del self.x_prev_subsampled
        self.x_prev_subsampled = None
        del self.output_prev_subsampled
        self.output_prev_subsampled = None
        del self.output_prev_norm
        self.output_prev_norm = None
        del self.uuid_cache_diffs
        self.uuid_cache_diffs = {}
        self.total_steps_skipped = 0
        return self

    def clone(self):
        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)


class EasyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EasyCache",
            display_name="EasyCache",
            description="Native EasyCache implementation.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add EasyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with EasyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
        return io.NodeOutput(model)


class LazyCacheHolder:
    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
        self.name = "LazyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        # cache values
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.cache_diff: torch.Tensor = None
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0

    def has_cache_diff(self) -> bool:
        return self.cache_diff is not None

    def is_past_end_timestep(self, timestep: float) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: float) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
        if self.subsample_factor > 1:
            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        if clone:
            return x.clone()
        return x

    def apply_cache_diff(self, x: torch.Tensor):
        self.total_steps_skipped += 1
        return x + self.cache_diff.to(x.device)

    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
        self.cache_diff = output - x

    def reset(self):
        self.relative_transformation_rate = 0.0
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.output_change_rates = []
        self.approx_output_change_rates = []
        del self.cache_diff
        self.cache_diff = None
        self.total_steps_skipped = 0
        return self

    def clone(self):
        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)


class LazyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LazyCache",
            display_name="LazyCache",
            description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add LazyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with LazyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
        return io.NodeOutput(model)


class EasyCacheExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EasyCacheNode,
            LazyCacheNode,
        ]


def comfy_entrypoint():
    return EasyCacheExtension()
nodes.py (3 changed lines)
@@ -2322,7 +2322,8 @@ async def init_builtin_extra_nodes():
         "nodes_tcfg.py",
         "nodes_context_windows.py",
         "nodes_qwen.py",
-        "nodes_model_patch.py"
+        "nodes_model_patch.py",
+        "nodes_easycache.py",
     ]

     import_failed = []