converted nodes files starting with "l" letter

This commit is contained in:
bigcat88 2025-07-24 10:18:55 +03:00
parent e5cac06bbe
commit 991de5fc81
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
8 changed files with 556 additions and 2 deletions

View File

@ -657,9 +657,34 @@ class Accumulation(ComfyTypeIO):
accum: list[Any]
Type = AccumulationDict
@comfytype(io_type="LOAD3D_CAMERA")
class Load3DCamera(ComfyTypeIO):
Type = Any # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
class CameraInfo(TypedDict):
position: dict[str, float | int]
target: dict[str, float | int]
zoom: int
cameraType: str
Type = CameraInfo
@comfytype(io_type="LOAD_3D")
class Load3D(ComfyTypeIO):
"""3D models are stored as a dictionary."""
class Model3DDict(TypedDict):
image: str
mask: str
normal: str
camera_info: Load3DCamera.CameraInfo
recording: NotRequired[str]
Type = Model3DDict
@comfytype(io_type="LOAD_3D_ANIMATION")
class Load3DAnimation(Load3D):
...
@comfytype(io_type="PHOTOMAKER")

View File

@ -479,7 +479,7 @@ class PreviewUI3D(_UIOutput):
self.values = values
def as_dict(self):
return {"3d": self.values}
return {"result": self.values}
class PreviewText(_UIOutput):

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import torch
from comfy_api.v3 import io
class InstructPixToPixConditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InstructPixToPixConditioning_V3",
category="conditioning/instructpix2pix",
inputs=[
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Vae.Input(id="vae"),
io.Image.Input(id="pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
concat_latent = vae.encode(pixels)
out_latent = {}
out_latent["samples"] = torch.zeros_like(concat_latent)
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1], out_latent)
NODES_LIST = [
InstructPixToPixConditioning,
]

View File

@ -0,0 +1,180 @@
from __future__ import annotations
import os
from pathlib import Path
import folder_paths
import nodes
from comfy_api.input_impl import VideoFromFile
from comfy_api.v3 import io, ui
def normalize_path(path):
return path.replace("\\", "/")
class Load3D(io.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {".gltf", ".glb", ".obj", ".fbx", ".stl"}
]
return io.Schema(
node_id="Load3D_V3",
display_name="Load 3D _V3",
category="3d",
is_experimental=True,
inputs=[
io.Combo.Input(id="model_file", options=sorted(files), upload=io.UploadType.model),
io.Load3D.Input(id="image"),
io.Int.Input(id="width", default=1024, min=1, max=4096, step=1),
io.Int.Input(id="height", default=1024, min=1, max=4096, step=1),
],
outputs=[
io.Image.Output(display_name="image"),
io.Mask.Output(display_name="mask"),
io.String.Output(display_name="mesh_path"),
io.Image.Output(display_name="normal"),
io.Image.Output(display_name="lineart"),
io.Load3DCamera.Output(display_name="camera_info"),
io.Video.Output(display_name="recording_video"),
],
)
@classmethod
def execute(cls, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image["image"])
mask_path = folder_paths.get_annotated_filepath(image["mask"])
normal_path = folder_paths.get_annotated_filepath(image["normal"])
lineart_path = folder_paths.get_annotated_filepath(image["lineart"])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image["recording"] != "":
recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
video = VideoFromFile(recording_video_path)
return io.NodeOutput(
output_image, output_mask, model_file, normal_image, lineart_image, image["camera_info"], video
)
class Load3DAnimation(io.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {".gltf", ".glb", ".fbx"}
]
return io.Schema(
node_id="Load3DAnimation_V3",
display_name="Load 3D - Animation _V3",
category="3d",
is_experimental=True,
inputs=[
io.Combo.Input(id="model_file", options=sorted(files), upload=io.UploadType.model),
io.Load3DAnimation.Input(id="image"),
io.Int.Input(id="width", default=1024, min=1, max=4096, step=1),
io.Int.Input(id="height", default=1024, min=1, max=4096, step=1),
],
outputs=[
io.Image.Output(display_name="image"),
io.Mask.Output(display_name="mask"),
io.String.Output(display_name="mesh_path"),
io.Image.Output(display_name="normal"),
io.Load3DCamera.Output(display_name="camera_info"),
io.Video.Output(display_name="recording_video"),
],
)
@classmethod
def execute(cls, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image["image"])
mask_path = folder_paths.get_annotated_filepath(image["mask"])
normal_path = folder_paths.get_annotated_filepath(image["normal"])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
video = VideoFromFile(recording_video_path)
return io.NodeOutput(output_image, output_mask, model_file, normal_image, image["camera_info"], video)
class Preview3D(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Preview3D_V3", # frontend expects "Preview3D" to work
display_name="Preview 3D _V3",
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
io.String.Input(id="model_file", default="", multiline=False),
io.Load3DCamera.Input(id="camera_info", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, model_file, camera_info=None):
return io.NodeOutput(ui=ui.PreviewUI3D([model_file, camera_info], cls=cls))
class Preview3DAnimation(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Preview3DAnimation_V3", # frontend expects "Preview3DAnimation" to work
display_name="Preview 3D - Animation _V3",
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
io.String.Input(id="model_file", default="", multiline=False),
io.Load3DCamera.Input(id="camera_info", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, model_file, camera_info=None):
return io.NodeOutput(ui=ui.PreviewUI3D([model_file, camera_info], cls=cls))
NODES_LIST = [
Load3D,
Load3DAnimation,
Preview3D,
Preview3DAnimation,
]

View File

@ -0,0 +1,138 @@
from __future__ import annotations
import logging
import os
from enum import Enum
import torch
import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.v3 import io
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {
"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF,
}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except Exception:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraSave_V3",
display_name="Extract and Save Lora _V3",
category="_for_testing",
is_output_node=True,
inputs=[
io.String.Input(id="filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input(id="rank", default=8, min=1, max=4096, step=1),
io.Combo.Input(id="lora_type", options=list(LORA_TYPES.keys())),
io.Boolean.Input(id="bias_diff", default=True),
io.Model.Input(
id="model_diff", optional=True, tooltip="The ModelSubtract output to be converted to a lora."
),
io.Clip.Input(
id="text_encoder_diff", optional=True, tooltip="The CLIPSubtract output to be converted to a lora."
),
],
outputs=[],
is_experimental=True,
)
@classmethod
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory()
)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(
model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff
)
if text_encoder_diff is not None:
output_sd = calc_lora_model(
text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff
)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return io.NodeOutput()
NODES_LIST = [
LoraSave,
]

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,116 @@
from __future__ import annotations
import torch
from comfy_api.v3 import io
class CLIPTextEncodeLumina2(io.ComfyNode):
SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "
"degree of image-text alignment based on textual prompts or user prompts.",
"alignment": "You are an assistant designed to generate high-quality images with the "
"highest degree of image-text alignment based on textual prompts."
}
SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
"Superior: You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts. "\
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts."
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2_V3",
display_name="CLIP Text Encode for Lumina2 _V3",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
"that can be used to guide the diffusion model towards generating specific images.",
inputs=[
io.Combo.Input(id="system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
io.String.Input(id="user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
io.Clip.Input(id="clip", tooltip="The CLIP model used for encoding the text."),
],
outputs=[
io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
],
)
@classmethod
def execute(cls, system_prompt, user_prompt, clip):
if clip is None:
raise RuntimeError(
"ERROR: clip input is invalid: None\n\n"
"If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
)
system_prompt = cls.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt)
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class RenormCFG(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RenormCFG_V3",
category="advanced/model",
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input(id="renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, cfg_trunc, renorm_cfg):
def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"]
cond_scale = args["cond_scale"]
timestep = args["timestep"]
x_orig = args["input"]
in_channels = model.model.diffusion_model.in_channels
if timestep[0] < cfg_trunc:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
half_rest = cond_rest
if float(renorm_cfg) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(
cond_eps,
dim=tuple(range(1, len(cond_eps.shape))),
keepdim=True
)
max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm)
else:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = cond_eps
half_rest = cond_rest
cfg_result = torch.cat([half_eps, half_rest], dim=1)
# cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale
return x_orig - cfg_result
m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func)
return io.NodeOutput(m)
NODES_LIST = [
CLIPTextEncodeLumina2,
RenormCFG,
]

View File

@ -2321,8 +2321,13 @@ def init_builtin_extra_nodes():
"v3/nodes_gits.py",
"v3/nodes_hidream.py",
"v3/nodes_images.py",
"v3/nodes_ip2p.py",
"v3/nodes_latent.py",
"v3/nodes_load_3d.py",
"v3/nodes_lora_extract.py",
"v3/nodes_lotus.py",
"v3/nodes_lt.py",
"v3/nodes_lumina2.py",
"v3/nodes_mask.py",
"v3/nodes_mochi.py",
"v3/nodes_model_advanced.py",