ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)
Commit 272e8d42c1: Merge branch 'master' into worksplit-multigpu
comfy/cli_args.py
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
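Note (added for illustration, not part of the diff): argparse normalizes the dashes in these flag names, so --fp8_e8m0fnu-unet becomes args.fp8_e8m0fnu_unet, which is how the model-management change further down reads it, and the mutually exclusive group keeps the fp8 choices from being combined. A minimal, self-contained sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")

args = parser.parse_args(["--fp8_e8m0fnu-unet"])
assert args.fp8_e8m0fnu_unet  # dashes in the flag become underscores in the attribute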
comfy/comfy_types/node_typing.py
@@ -115,6 +115,11 @@ class InputTypeOptions(TypedDict):
    """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
    tooltip: NotRequired[str]
    """Tooltip for the input (or widget), shown on pointer hover"""
+    socketless: NotRequired[bool]
+    """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
+    Available from frontend v1.17.5
+    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
+    """
    # class InputTypeNumber(InputTypeOptions):
    #     default: float | int
    min: NotRequired[float]
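Illustration (not part of the diff): a hypothetical custom node using the new socketless option; the node name and input name are invented.

class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # The widget is shown in the UI, but no input socket is created for it.
                "steps": ("INT", {"default": 20, "min": 1, "max": 100, "socketless": True}),
            }
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "run"

    def run(self, steps):
        return (steps,)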
comfy/controlnet.py
@@ -779,6 +779,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
    return control


def load_controlnet(ckpt_path, model=None, model_options={}):
+    model_options = model_options.copy()
    if "global_average_pooling" not in model_options:
        filename = os.path.splitext(ckpt_path)[0]
        if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
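The added model_options.copy() keeps load_controlnet from mutating the dict the caller passed in (or the shared mutable default). A standalone sketch of the aliasing it avoids (names invented, not ComfyUI code):

def configure(opts={}):
    opts = opts.copy()  # without this line, the caller's dict is modified in place
    opts.setdefault("global_average_pooling", True)
    return opts

shared = {}
configure(shared)
assert "global_average_pooling" not in shared  # caller's dict left untouched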
comfy/ldm/wan/model.py
@@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module):
        return x


+class VaceWanAttentionBlock(WanAttentionBlock):
+    def __init__(
+        self,
+        cross_attn_type,
+        dim,
+        ffn_dim,
+        num_heads,
+        window_size=(-1, -1),
+        qk_norm=True,
+        cross_attn_norm=False,
+        eps=1e-6,
+        block_id=0,
+        operation_settings={}
+    ):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+        self.block_id = block_id
+        if block_id == 0:
+            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, c, x, **kwargs):
+        if self.block_id == 0:
+            c = self.before_proj(c) + x
+        c = super().forward(c, **kwargs)
+        c_skip = self.after_proj(c)
+        return c_skip, c
+
+
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -395,6 +423,7 @@ class WanModel(torch.nn.Module):
        clip_fea=None,
        freqs=None,
        transformer_options={},
+        **kwargs,
    ):
        r"""
        Forward pass through the diffusion model
@@ -457,7 +486,7 @@ class WanModel(torch.nn.Module):
        x = self.unpatchify(x, grid_sizes)
        return x

-    def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
        patch_size = self.patch_size
@@ -471,7 +500,7 @@ class WanModel(torch.nn.Module):
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

    def unpatchify(self, x, grid_sizes):
        r"""
@@ -496,3 +525,115 @@ class WanModel(torch.nn.Module):
        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
        return u
+
+
+class VaceWanModel(WanModel):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and image-to-video.
+    """
+
+    def __init__(self,
+                 model_type='vace',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 vace_layers=None,
+                 vace_in_dim=None,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        # Vace
+        if vace_layers is not None:
+            self.vace_layers = vace_layers
+            self.vace_in_dim = vace_in_dim
+            # vace blocks
+            self.vace_blocks = nn.ModuleList([
+                VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
+                for i in range(self.vace_layers)
+            ])
+
+            self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
+            # vace patch embeddings
+            self.vace_patch_embedding = operations.Conv3d(
+                self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
+            )
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        vace_context,
+        vace_strength=1.0,
+        clip_fea=None,
+        freqs=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+        # context
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
+        c = c.flatten(2).transpose(1, 2)
+
+        # arguments
+        x_orig = x
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+            ii = self.vace_layers_mapping.get(i, None)
+            if ii is not None:
+                c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                x += c_skip * vace_strength
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
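A worked example of the vace_layers_mapping built above (sizes assumed: 32 transformer blocks, 8 VACE blocks). Every num_layers // vace_layers-th block gets a VACE companion; in forward_orig the companion runs on the control stream c and its c_skip output is added to x, scaled by vace_strength.

num_layers, vace_layers = 32, 8  # assumed sizes for illustration
mapping = {i: n for n, i in enumerate(range(0, num_layers, num_layers // vace_layers))}
print(mapping)  # {0: 0, 4: 1, 8: 2, 12: 3, 16: 4, 20: 5, 24: 6, 28: 7}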
comfy/lora.py (321 lines changed)
@@ -20,6 +20,7 @@ from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
+import comfy.weight_adapter as weight_adapter
import logging
import torch
@@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

-        reshape_name = "{}.reshape_weight".format(x)
-        reshape = None
-        if reshape_name in lora.keys():
-            try:
-                reshape = lora[reshape_name].tolist()
-                loaded_keys.add(reshape_name)
-            except:
-                pass
-
-        regular_lora = "{}.lora_up.weight".format(x)
-        diffusers_lora = "{}_lora.up.weight".format(x)
-        diffusers2_lora = "{}.lora_B.weight".format(x)
-        diffusers3_lora = "{}.lora.up.weight".format(x)
-        mochi_lora = "{}.lora_B".format(x)
-        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
-        A_name = None
-
-        if regular_lora in lora.keys():
-            A_name = regular_lora
-            B_name = "{}.lora_down.weight".format(x)
-            mid_name = "{}.lora_mid.weight".format(x)
-        elif diffusers_lora in lora.keys():
-            A_name = diffusers_lora
-            B_name = "{}_lora.down.weight".format(x)
-            mid_name = None
-        elif diffusers2_lora in lora.keys():
-            A_name = diffusers2_lora
-            B_name = "{}.lora_A.weight".format(x)
-            mid_name = None
-        elif diffusers3_lora in lora.keys():
-            A_name = diffusers3_lora
-            B_name = "{}.lora.down.weight".format(x)
-            mid_name = None
-        elif mochi_lora in lora.keys():
-            A_name = mochi_lora
-            B_name = "{}.lora_A".format(x)
-            mid_name = None
-        elif transformers_lora in lora.keys():
-            A_name = transformers_lora
-            B_name ="{}.lora_linear_layer.down.weight".format(x)
-            mid_name = None
-
-        if A_name is not None:
-            mid = None
-            if mid_name is not None and mid_name in lora.keys():
-                mid = lora[mid_name]
-                loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
-            loaded_keys.add(A_name)
-            loaded_keys.add(B_name)
-
-
-        ######## loha
-        hada_w1_a_name = "{}.hada_w1_a".format(x)
-        hada_w1_b_name = "{}.hada_w1_b".format(x)
-        hada_w2_a_name = "{}.hada_w2_a".format(x)
-        hada_w2_b_name = "{}.hada_w2_b".format(x)
-        hada_t1_name = "{}.hada_t1".format(x)
-        hada_t2_name = "{}.hada_t2".format(x)
-        if hada_w1_a_name in lora.keys():
-            hada_t1 = None
-            hada_t2 = None
-            if hada_t1_name in lora.keys():
-                hada_t1 = lora[hada_t1_name]
-                hada_t2 = lora[hada_t2_name]
-                loaded_keys.add(hada_t1_name)
-                loaded_keys.add(hada_t2_name)
-
-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
-            loaded_keys.add(hada_w1_a_name)
-            loaded_keys.add(hada_w1_b_name)
-            loaded_keys.add(hada_w2_a_name)
-            loaded_keys.add(hada_w2_b_name)
-
-
-        ######## lokr
-        lokr_w1_name = "{}.lokr_w1".format(x)
-        lokr_w2_name = "{}.lokr_w2".format(x)
-        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
-        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
-        lokr_t2_name = "{}.lokr_t2".format(x)
-        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
-        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
-
-        lokr_w1 = None
-        if lokr_w1_name in lora.keys():
-            lokr_w1 = lora[lokr_w1_name]
-            loaded_keys.add(lokr_w1_name)
-
-        lokr_w2 = None
-        if lokr_w2_name in lora.keys():
-            lokr_w2 = lora[lokr_w2_name]
-            loaded_keys.add(lokr_w2_name)
-
-        lokr_w1_a = None
-        if lokr_w1_a_name in lora.keys():
-            lokr_w1_a = lora[lokr_w1_a_name]
-            loaded_keys.add(lokr_w1_a_name)
-
-        lokr_w1_b = None
-        if lokr_w1_b_name in lora.keys():
-            lokr_w1_b = lora[lokr_w1_b_name]
-            loaded_keys.add(lokr_w1_b_name)
-
-        lokr_w2_a = None
-        if lokr_w2_a_name in lora.keys():
-            lokr_w2_a = lora[lokr_w2_a_name]
-            loaded_keys.add(lokr_w2_a_name)
-
-        lokr_w2_b = None
-        if lokr_w2_b_name in lora.keys():
-            lokr_w2_b = lora[lokr_w2_b_name]
-            loaded_keys.add(lokr_w2_b_name)
-
-        lokr_t2 = None
-        if lokr_t2_name in lora.keys():
-            lokr_t2 = lora[lokr_t2_name]
-            loaded_keys.add(lokr_t2_name)
-
-        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
-
-        #glora
-        a1_name = "{}.a1.weight".format(x)
-        a2_name = "{}.a2.weight".format(x)
-        b1_name = "{}.b1.weight".format(x)
-        b2_name = "{}.b2.weight".format(x)
-        if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
-            loaded_keys.add(a1_name)
-            loaded_keys.add(a2_name)
-            loaded_keys.add(b1_name)
-            loaded_keys.add(b2_name)
+        for adapter_cls in weight_adapter.adapters:
+            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
+            if adapter is not None:
+                patch_dict[to_load[x]] = adapter
+                loaded_keys.update(adapter.loaded_keys)
+                continue

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
@@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}):
    return key_map


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
-    lora_diff *= alpha
-    weight_calc = weight + function(lora_diff).type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
-
-    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
-    if strength != 1.0:
-        weight_calc -= weight
-        weight += strength * (weight_calc)
-    else:
-        weight[:] = weight_calc
-    return weight
-
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.
@@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
        if isinstance(v, list):
            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

+        if isinstance(v, weight_adapter.WeightAdapterBase):
+            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
+            if output is None:
+                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
+            else:
+                weight = output
+                if old_weight is not None:
+                    weight = old_weight
+            continue
+
        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
@@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                          comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
-        elif patch_type == "lora": #lora/locon
-            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
-            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
-            dora_scale = v[4]
-            reshape = v[5]
-
-            if reshape is not None:
-                weight = pad_tensor_to_shape(weight, reshape)
-
-            if v[2] is not None:
-                alpha = v[2] / mat2.shape[0]
-            else:
-                alpha = 1.0
-
-            if v[3] is not None:
-                #locon mid weights, hopefully the math is fine because I didn't properly test it
-                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
-                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-            try:
-                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "lokr":
-            w1 = v[0]
-            w2 = v[1]
-            w1_a = v[3]
-            w1_b = v[4]
-            w2_a = v[5]
-            w2_b = v[6]
-            t2 = v[7]
-            dora_scale = v[8]
-            dim = None
-
-            if w1 is None:
-                dim = w1_b.shape[0]
-                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
-            else:
-                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
-            if w2 is None:
-                dim = w2_b.shape[0]
-                if t2 is None:
-                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
-                else:
-                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
-            else:
-                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
-            if len(w2.shape) == 4:
-                w1 = w1.unsqueeze(2).unsqueeze(2)
-            if v[2] is not None and dim is not None:
-                alpha = v[2] / dim
-            else:
-                alpha = 1.0
-
-            try:
-                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "loha":
-            w1a = v[0]
-            w1b = v[1]
-            if v[2] is not None:
-                alpha = v[2] / w1b.shape[0]
-            else:
-                alpha = 1.0
-
-            w2a = v[3]
-            w2b = v[4]
-            dora_scale = v[7]
-            if v[5] is not None: #cp decomposition
-                t1 = v[5]
-                t2 = v[6]
-                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
-                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
-            else:
-                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
-            try:
-                lora_diff = (m1 * m2).reshape(weight.shape)
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
-        elif patch_type == "glora":
-            dora_scale = v[5]
-
-            old_glora = False
-            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
-                rank = v[0].shape[0]
-                old_glora = True
-
-            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
-                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
-                    pass
-                else:
-                    old_glora = False
-                    rank = v[1].shape[0]
-
-            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
-            if v[4] is not None:
-                alpha = v[4] / rank
-            else:
-                alpha = 1.0
-
-            try:
-                if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
-                else:
-                    if weight.dim() > 2:
-                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
-                    else:
-                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
-                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
-
-                if dora_scale is not None:
-                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
-                else:
-                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-            except Exception as e:
-                logging.error("ERROR {} {} {}".format(patch_type, key, e))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))
comfy/model_base.py
@@ -1043,6 +1043,37 @@ class WAN21(BaseModel):
            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
        return out


+class WAN21_Vace(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        noise = kwargs.get("noise", None)
+        noise_shape = list(noise.shape)
+        vace_frames = kwargs.get("vace_frames", None)
+        if vace_frames is None:
+            noise_shape[1] = 32
+            vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+
+        for i in range(0, vace_frames.shape[1], 16):
+            vace_frames = vace_frames.clone()
+            vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
+
+        mask = kwargs.get("vace_mask", None)
+        if mask is None:
+            noise_shape[1] = 64
+            mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
+
+        out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
+
+        vace_strength = kwargs.get("vace_strength", 1.0)
+        out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
+        return out
+
+
class Hunyuan3Dv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
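Shape sketch for the default path above (no vace_frames or vace_mask supplied; latent sizes are assumed for illustration). The 32-channel reference-frame latent and the 64-channel mask concatenate into a 96-channel vace_context, which appears to match the vace_in_dim detected from the released VACE checkpoints.

import torch

noise = torch.zeros(1, 16, 21, 60, 104)  # example Wan latent (B, C, T, H, W); sizes assumed
shape = list(noise.shape)

shape[1] = 32
vace_frames = torch.zeros(shape)          # default reference frames
shape[1] = 64
mask = torch.ones(shape)                  # default mask

vace_context = torch.cat([vace_frames, mask], dim=1)
print(vace_context.shape[1])              # 96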
comfy/model_detection.py
@@ -317,10 +317,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["cross_attn_norm"] = True
        dit_config["eps"] = 1e-6
        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
-        if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
-            dit_config["model_type"] = "i2v"
+        if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "vace"
+            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        else:
-            dit_config["model_type"] = "t2v"
+            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+                dit_config["model_type"] = "i2v"
+            else:
+                dit_config["model_type"] = "t2v"
        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
        if flf_weight is not None:
            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
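Rough illustration of the detection rule above, with invented keys and shapes: the presence of vace_patch_embedding.weight marks a VACE checkpoint, its second dimension gives vace_in_dim, and counting the numbered vace_blocks prefixes (essentially what count_blocks does) gives vace_layers.

state_dict_keys = [
    "vace_patch_embedding.weight",        # e.g. shape [dim, 96, 1, 2, 2] -> vace_in_dim = 96
    "vace_blocks.0.before_proj.weight",
    "vace_blocks.1.after_proj.weight",
]
vace_layers = len({k.split(".")[1] for k in state_dict_keys if k.startswith("vace_blocks.")})
print(vace_layers)                        # 2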
comfy/model_management.py
@@ -753,6 +753,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
+    if args.fp8_e8m0fnu_unet:
+        return torch.float8_e8m0fnu

    fp8_dtype = None
    if weight_dtype in FLOAT8_TYPES:
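Note (illustration only): torch.float8_e8m0fnu exists only in fairly recent PyTorch builds, so older installs will not have this attribute. A hedged availability check, not ComfyUI code:

import torch

e8m0 = getattr(torch, "float8_e8m0fnu", None)
if e8m0 is None:
    print("this PyTorch build has no float8_e8m0fnu; pick another unet dtype")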
comfy/sd.py (34 lines changed)
@@ -703,6 +703,7 @@ class CLIPType(Enum):
    COSMOS = 11
    LUMINA2 = 12
    WAN = 13
+    HIDREAM = 14


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -791,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            elif clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
            else:
                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -811,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
                clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
                tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
+                                                                            clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
            else: #CLIPType.MOCHI
                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -827,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
            tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+        elif te_model == TEModel.LLAMA3_8:
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
+                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else:
+            # clip_l
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+            elif clip_type == CLIPType.HIDREAM:
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
            else:
                clip_target.clip = sd1_clip.SD1ClipModel
                clip_target.tokenizer = sd1_clip.SD1Tokenizer
@@ -848,6 +864,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.HUNYUAN_VIDEO:
            clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
+        elif clip_type == CLIPType.HIDREAM:
+            # Detect
+            hidream_dualclip_classes = []
+            for hidream_te in clip_data:
+                te_model = detect_te_model(hidream_te)
+                hidream_dualclip_classes.append(te_model)
+
+            clip_l = TEModel.CLIP_L in hidream_dualclip_classes
+            clip_g = TEModel.CLIP_G in hidream_dualclip_classes
+            t5 = TEModel.T5_XXL in hidream_dualclip_classes
+            llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
+
+            # Initialize t5xxl_detect and llama_detect kwargs if needed
+            t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
+            llama_kwargs = llama_detect(clip_data) if llama else {}
+
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
+            clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
comfy/supported_models.py
@@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V):
        out = model_base.WAN21(self, image_to_video=False, device=device)
        return out

+class WAN21_Vace(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "vace",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
+        return out
+
class Hunyuan3Dv2(supported_models_base.BASE):
    unet_config = {
        "image_model": "hunyuan3d2",
@@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE):
        return None # TODO


-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]

models += [SVD_img2vid]
comfy/text_encoders/hidream.py
@@ -109,14 +109,18 @@ class HiDreamTEModel(torch.nn.Module):
        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]
+        else:
+            t5_out = None

        if self.llama is not None:
            ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
            ll_out, ll_pooled = ll_output[:2]
            ll_out = ll_out[:, 1:]
+        else:
+            ll_out = None

        if t5_out is None:
-            t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device())
+            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
comfy/weight_adapter/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter


adapters: list[type[WeightAdapterBase]] = [
    LoRAAdapter,
    LoHaAdapter,
    LoKrAdapter,
    GLoRAAdapter,
    OFTAdapter,
    BOFTAdapter,
]
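Hypothetical extension example (not in the diff): because load_lora simply iterates this adapters list, a third-party weight format could be supported by appending another WeightAdapterBase subclass. All class and key names below are invented.

import torch
from comfy.weight_adapter import adapters
from comfy.weight_adapter.base import WeightAdapterBase


class IdentityAdapter(WeightAdapterBase):
    name = "identity"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(cls, x, lora, alpha, dora_scale, loaded_keys=None):
        key = "{}.identity_diff".format(x)  # invented key name
        if key not in lora:
            return None
        return cls({key}, (lora[key], alpha, dora_scale))

    def calculate_weight(self, weight, key, strength, strength_model, offset, function,
                         intermediate_dtype=torch.float32, original_weight=None):
        diff, alpha, dora_scale = self.weights
        # apply the stored diff directly, scaled by the patch strength
        return weight + function((strength * diff.to(weight)).type(weight.dtype))


adapters.append(IdentityAdapter)  # load_lora() would now try this format too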
comfy/weight_adapter/base.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from typing import Optional

import torch
import torch.nn as nn

import comfy.model_management


class WeightAdapterBase:
    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    @classmethod
    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        raise NotImplementedError


class WeightAdapterTrainBase(nn.Module):
    def __init__(self):
        super().__init__()

    # [TODO] Collaborate with LoRA training PR #7032


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight


def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        If the new shape is smaller than the original tensor in any dimension,
        the original tensor will be truncated in that dimension.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor
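Usage sketch for pad_tensor_to_shape defined above (illustration only):

import torch
from comfy.weight_adapter.base import pad_tensor_to_shape

w = torch.ones(2, 3)
padded = pad_tensor_to_shape(w, [4, 5])
print(padded.shape)          # torch.Size([4, 5])
print(padded[:2, :3].sum())  # tensor(6.) -- original values kept, the rest is zeros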
comfy/weight_adapter/boft.py (new file, 115 lines)
@@ -0,0 +1,115 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class BOFTAdapter(WeightAdapterBase):
    name = "boft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["BOFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.boft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        if blocks is not None:
            weights = (blocks, rescale, alpha, dora_scale)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        boft_m, block_num, boft_b, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0: # alpha in boft/bboft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(original_weight)

            inp = org = original_weight

            r_b = boft_b//2
            for i in range(boft_m):
                bi = r[i]
                g = 2
                k = 2**i * r_b
                if strength != 1:
                    bi = bi * strength + (1-strength) * I
                inp = (
                    inp.unflatten(-1, (-1, g, k))
                    .transpose(-2, -1)
                    .flatten(-3)
                    .unflatten(-1, (-1, boft_b))
                )
                inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
                inp = (
                    inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
                )

            if rescale is not None:
                inp = inp * rescale

            lora_diff = inp - org
            lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
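The r built above is a Cayley transform: a skew-symmetric q is mapped to an orthogonal matrix, which is what lets BOFT rotate weight blocks without changing their norms. A standalone check of that property (illustration only, not ComfyUI code):

import torch

b = 4
blocks = torch.randn(b, b)
q = blocks - blocks.T                          # skew-symmetric, q == -q.T
I = torch.eye(b)
r = (I + q) @ torch.linalg.inv(I - q)          # Cayley transform
print(torch.allclose(r @ r.T, I, atol=1e-4))   # True: r is orthogonal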
comfy/weight_adapter/glora.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class GLoRAAdapter(WeightAdapterBase):
    name = "glora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["GLoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        dora_scale = v[5]

        old_glora = False
        if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
            rank = v[0].shape[0]
            old_glora = True

        if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
            if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                pass
            else:
                old_glora = False
                rank = v[1].shape[0]

        a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
        a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
        b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
        b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

        if v[4] is not None:
            alpha = v[4] / rank
        else:
            alpha = 1.0

        try:
            if old_glora:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
            else:
                if weight.dim() > 2:
                    lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
100  comfy/weight_adapter/loha.py  Normal file
@ -0,0 +1,100 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class LoHaAdapter(WeightAdapterBase):
    name = "loha"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoHaAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1a = v[0]
        w1b = v[1]
        if v[2] is not None:
            alpha = v[2] / w1b.shape[0]
        else:
            alpha = 1.0

        w2a = v[3]
        w2b = v[4]
        dora_scale = v[7]
        if v[5] is not None: #cp decomposition
            t1 = v[5]
            t2 = v[6]
            m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                              comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

            m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                              comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
        else:
            m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
            m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

        try:
            lora_diff = (m1 * m2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
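The LoHa delta above is the element-wise (Hadamard) product of two low-rank reconstructions. A minimal standalone sketch with toy shapes, not part of the patch, assuming the non-CP branch:

# Toy illustration of the LoHa delta: (w1a @ w1b) * (w2a @ w2b).
import torch

out_dim, in_dim, rank = 8, 6, 2
w1a = torch.randn(out_dim, rank)
w1b = torch.randn(rank, in_dim)
w2a = torch.randn(out_dim, rank)
w2b = torch.randn(rank, in_dim)

m1 = torch.mm(w1a, w1b)
m2 = torch.mm(w2a, w2b)
lora_diff = m1 * m2          # element-wise product of the two low-rank matrices

strength = 1.0
alpha = 4.0 / rank           # stored alpha divided by the rank (w1b.shape[0]), as in the adapter
weight = torch.randn(out_dim, in_dim)
patched = weight + (strength * alpha) * lora_diff
print(patched.shape)         # torch.Size([8, 6])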
133  comfy/weight_adapter/lokr.py  Normal file
@ -0,0 +1,133 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class LoKrAdapter(WeightAdapterBase):
    name = "lokr"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoKrAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1 = v[0]
        w2 = v[1]
        w1_a = v[3]
        w1_b = v[4]
        w2_a = v[5]
        w2_b = v[6]
        t2 = v[7]
        dora_scale = v[8]
        dim = None

        if w1 is None:
            dim = w1_b.shape[0]
            w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
        else:
            w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

        if w2 is None:
            dim = w2_b.shape[0]
            if t2 is None:
                w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
            else:
                w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
        else:
            w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        if v[2] is not None and dim is not None:
            alpha = v[2] / dim
        else:
            alpha = 1.0

        try:
            lora_diff = torch.kron(w1, w2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
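The LoKr delta is a Kronecker product of two much smaller factors, which is why the stored tensors can be tiny relative to the patched weight. A minimal standalone sketch with toy shapes, not part of the patch:

# Toy illustration of the LoKr delta: torch.kron(w1, w2) has shape (4*2, 3*5) = (8, 15).
import torch

w1 = torch.randn(4, 3)
w2 = torch.randn(2, 5)
lora_diff = torch.kron(w1, w2)

weight = torch.randn(8, 15)
strength, alpha = 1.0, 1.0   # in the adapter, alpha becomes stored_alpha / dim only when a factor is itself low-rank
patched = weight + (strength * alpha) * lora_diff
print(patched.shape)         # torch.Size([8, 15])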
142  comfy/weight_adapter/lora.py  Normal file
@ -0,0 +1,142 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # locon mid weights, hopefully the math is fine because I didn't properly test it
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
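The plain LoRA path above is the familiar low-rank update W' = W + strength * (alpha / rank) * (up @ down). A minimal standalone sketch with toy shapes and illustrative values, not part of the patch:

# Toy illustration of the plain LoRA update.
import torch

out_dim, in_dim, rank = 8, 6, 2
weight = torch.randn(out_dim, in_dim)
up = torch.randn(out_dim, rank)     # corresponds to lora_up / lora_B (v[0])
down = torch.randn(rank, in_dim)    # corresponds to lora_down / lora_A (v[1])

stored_alpha = 4.0
alpha = stored_alpha / down.shape[0]  # divided by the rank, as in the adapter
strength = 0.8

lora_diff = torch.mm(up, down)
patched = weight + (strength * alpha) * lora_diff
print(patched.shape)                # torch.Size([8, 6])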
94  comfy/weight_adapter/oft.py  Normal file
@ -0,0 +1,94 @@
import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class OFTAdapter(WeightAdapterBase):
    name = "oft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["OFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        if blocks is not None:
            weights = (blocks, rescale, alpha, dora_scale)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        block_num, block_size, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0: # alpha in oft/boft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(original_weight)
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
                original_weight,
            )
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
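OFT stores a skew-symmetric block per group and turns it into an orthogonal rotation via the Cayley transform R = (I + Q)(I - Q)^-1; the delta is then the rotated weight minus the original. A minimal standalone sketch with toy shapes, not part of the patch, omitting the alpha norm constraint and the optional rescale:

# Toy illustration of the block-diagonal OFT rotation via the Cayley transform.
import torch

block_num, block_size = 4, 3
blocks = torch.randn(block_num, block_size, block_size)
original_weight = torch.randn(block_num, block_size, 5)  # each output block owns a slice of the weight

eye = torch.eye(block_size)
q = blocks - blocks.transpose(1, 2)            # skew-symmetric Q per block
r = (eye + q) @ torch.linalg.inv(eye - q)      # orthogonal rotation per block

strength = 1.0
# Same einsum as the adapter: rotate each block and subtract the identity contribution
lora_diff = torch.einsum("k n m, k n ... -> k m ...", (r * strength) - strength * eye, original_weight)
print(lora_diff.shape)                         # torch.Size([4, 3, 5])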
@ -26,7 +26,30 @@ class QuadrupleCLIPLoader:
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)

+
+class CLIPTextEncodeHiDream:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, clip_l, clip_g, t5xxl, llama):
+
+        tokens = clip.tokenize(clip_g)
+        tokens["l"] = clip.tokenize(clip_l)["l"]
+        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
+        tokens["llama"] = clip.tokenize(llama)["llama"]
+        return (clip.encode_from_tokens_scheduled(tokens), )
+
 NODE_CLASS_MAPPINGS = {
     "QuadrupleCLIPLoader": QuadrupleCLIPLoader,
+    "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
 }
@ -3,7 +3,10 @@ import scipy.ndimage
 import torch
 import comfy.utils
 import node_helpers
+import folder_paths
+import random
+import nodes
 from nodes import MAX_RESOLUTION

 def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
@ -362,6 +365,30 @@ class ThresholdMask:
         mask = (mask > value).float()
         return (mask,)

+
+# Mask Preview - original implement from
+# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
+# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
+class MaskPreview(nodes.SaveImage):
+    def __init__(self):
+        self.output_dir = folder_paths.get_temp_directory()
+        self.type = "temp"
+        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+        self.compress_level = 4
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {"mask": ("MASK",), },
+            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+        }
+
+    FUNCTION = "execute"
+    CATEGORY = "mask"
+
+    def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
+        preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
+        return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
+
+
 NODE_CLASS_MAPPINGS = {
     "LatentCompositeMasked": LatentCompositeMasked,
@ -376,6 +403,7 @@ NODE_CLASS_MAPPINGS = {
     "FeatherMask": FeatherMask,
     "GrowMask": GrowMask,
     "ThresholdMask": ThresholdMask,
+    "MaskPreview": MaskPreview
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
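MaskPreview reuses SaveImage by first expanding the single-channel MASK batch into an IMAGE-shaped tensor. A minimal standalone sketch of that reshape with a toy tensor, not part of the patch:

# Toy illustration of the mask -> 3-channel preview reshape used in MaskPreview.execute.
import torch

mask = torch.rand(2, 64, 64)  # ComfyUI masks are (batch, height, width) in 0..1
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
print(preview.shape)          # torch.Size([2, 64, 64, 3]) -> (batch, height, width, channels), like an IMAGE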
@ -1,6 +1,8 @@
 # Primitive nodes that are evaluated at backend.
 from __future__ import annotations

+import sys
+
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO


@ -23,7 +25,7 @@ class Int(ComfyNodeABC):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypeDict:
         return {
-            "required": {"value": (IO.INT, {"control_after_generate": True})},
+            "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
         }

     RETURN_TYPES = (IO.INT,)
@ -38,7 +40,7 @@ class Float(ComfyNodeABC):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypeDict:
         return {
-            "required": {"value": (IO.FLOAT, {})},
+            "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
         }

     RETURN_TYPES = (IO.FLOAT,)
@ -50,13 +50,15 @@ class SaveWEBM:
             for x in extra_pnginfo:
                 container.metadata[x] = json.dumps(extra_pnginfo[x])

-        codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"}
+        codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
         stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
         stream.width = images.shape[-2]
         stream.height = images.shape[-3]
-        stream.pix_fmt = "yuv420p"
+        stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
         stream.bit_rate = 0
         stream.options = {'crf': str(crf)}
+        if codec == "av1":
+            stream.options["preset"] = "6"

         for frame in images:
             frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
@ -193,9 +193,116 @@ class WanFunInpaintToVideo:
         return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)


+class WanVaceToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
+                             },
+                "optional": {"control_video": ("IMAGE", ),
+                             "control_masks": ("MASK", ),
+                             "reference_image": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
+    RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    EXPERIMENTAL = True
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
+        latent_length = ((length - 1) // 4) + 1
+        if control_video is not None:
+            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            if control_video.shape[0] < length:
+                control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5)
+        else:
+            control_video = torch.ones((length, height, width, 3)) * 0.5
+
+        if reference_image is not None:
+            reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            reference_image = vae.encode(reference_image[:, :, :, :3])
+            reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1)
+
+        if control_masks is None:
+            mask = torch.ones((length, height, width, 1))
+        else:
+            mask = control_masks
+            if mask.ndim == 3:
+                mask = mask.unsqueeze(1)
+            mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1)
+            if mask.shape[0] < length:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0)
+
+        control_video = control_video - 0.5
+        inactive = (control_video * (1 - mask)) + 0.5
+        reactive = (control_video * mask) + 0.5
+
+        inactive = vae.encode(inactive[:, :, :, :3])
+        reactive = vae.encode(reactive[:, :, :, :3])
+        control_video_latent = torch.cat((inactive, reactive), dim=1)
+        if reference_image is not None:
+            control_video_latent = torch.cat((reference_image, control_video_latent), dim=2)
+
+        vae_stride = 8
+        height_mask = height // vae_stride
+        width_mask = width // vae_stride
+        mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
+        mask = mask.permute(2, 4, 0, 1, 3)
+        mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
+        mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
+
+        trim_latent = 0
+        if reference_image is not None:
+            mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :])
+            mask = torch.cat((mask_pad, mask), dim=1)
+            latent_length += reference_image.shape[2]
+            trim_latent = reference_image.shape[2]
+
+        mask = mask.unsqueeze(0)
+        positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
+        negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
+
+        latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent, trim_latent)
+
+
+class TrimVideoLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT",),
+                              "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
+                              }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "op"
+
+    CATEGORY = "latent/video"
+
+    EXPERIMENTAL = True
+
+    def op(self, samples, trim_amount):
+        samples_out = samples.copy()
+
+        s1 = samples["samples"]
+        samples_out["samples"] = s1[:, :, trim_amount:]
+        return (samples_out,)
+
+
 NODE_CLASS_MAPPINGS = {
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
     "WanFunInpaintToVideo": WanFunInpaintToVideo,
     "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
+    "WanVaceToVideo": WanVaceToVideo,
+    "TrimVideoLatent": TrimVideoLatent,
 }
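The least obvious step in WanVaceToVideo is how the pixel-space mask is repacked onto the latent grid: the 8x8 VAE stride is folded into the channel dimension and the time axis is downsampled to the latent length. A minimal standalone sketch with small toy dimensions, not part of the patch:

# Toy illustration of the vace mask repacking: (length, H, W, 1) -> (64, latent_length, H/8, W/8).
import torch

length, height, width = 17, 64, 64
vae_stride = 8
latent_length = ((length - 1) // 4) + 1          # 5 latent frames for 17 pixel frames

mask = torch.ones((length, height, width, 1))
height_mask = height // vae_stride
width_mask = width // vae_stride
mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride)
mask = mask.permute(2, 4, 0, 1, 3)               # move the 8x8 sub-grid to the front
mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask)
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0)
print(mask.shape)                                # torch.Size([64, 5, 8, 8])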
@ -144,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
             input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
         if h[x] == "UNIQUE_ID":
             input_data_all[x] = [unique_id]
+        if h[x] == "AUTH_TOKEN_COMFY_ORG":
+            input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
     return input_data_all, missing_keys

 map_node_over_list = None #Don't hook this please
8  nodes.py
@ -917,7 +917,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@ -927,7 +927,7 @@ class CLIPLoader:

     CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
+    DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5"

     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -945,7 +945,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@ -955,7 +955,7 @@ class DualCLIPLoader:

     CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"

     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -22,4 +22,4 @@ psutil
 kornia>=0.7.1
 spandrel
 soundfile
-av
+av>=14.1.0