converted 6 more files

bigcat88 2025-07-25 14:35:04 +03:00
parent 631916dfb2
commit 5a8c426112
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
10 changed files with 2101 additions and 4 deletions

comfy_extras/v3/nodes_audio.py View File

@@ -51,7 +51,7 @@ class EmptyLatentAudio(io.ComfyNode):
             inputs=[
                 io.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
                 io.Int.Input(
-                    id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+                    "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
                 ),
             ],
             outputs=[io.Latent.Output()],

comfy_extras/v3/nodes_hunyuan3d.py View File

@@ -0,0 +1,672 @@
from __future__ import annotations
import json
import os
import struct
import numpy as np
import torch
import comfy.model_management
import comfy.utils
import folder_paths
from comfy.cli_args import args
from comfy.ldm.modules.diffusionmodules.mmdit import (
get_1d_sincos_pos_embed_from_grid_torch,
)
from comfy_api.latest import io
class VOXEL:
def __init__(self, data):
self.data = data
class MESH:
def __init__(self, vertices, faces):
self.vertices = vertices
self.faces = faces
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
binary = (voxels > threshold).float()
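# Emit one quad per exposed voxel face: for each of the 6 axis directions, a
# solid voxel contributes a face wherever its neighbor in that direction is
# empty. Padding with one empty voxel on every side makes the boundary check uniform.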
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
D, H, W = binary.shape
neighbors = torch.tensor([
[0, 0, 1],
[0, 0, -1],
[0, 1, 0],
[0, -1, 0],
[1, 0, 0],
[-1, 0, 0]
], device=device)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
solid_mask = binary.flatten() > 0
solid_indices = voxel_indices[solid_mask]
corner_offsets = [
torch.tensor([
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
], device=device),
torch.tensor([
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
], device=device),
torch.tensor([
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
], device=device),
torch.tensor([
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
], device=device),
torch.tensor([
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
], device=device),
torch.tensor([
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
], device=device)
]
all_vertices = []
all_indices = []
vertex_count = 0
for face_idx, offset in enumerate(neighbors):
neighbor_indices = solid_indices + offset
padded_indices = neighbor_indices + 1
is_exposed = padded[
padded_indices[:, 0],
padded_indices[:, 1],
padded_indices[:, 2]
] == 0
if not is_exposed.any():
continue
exposed_indices = solid_indices[is_exposed]
corners = corner_offsets[face_idx].unsqueeze(0)
face_vertices = exposed_indices.unsqueeze(1) + corners
all_vertices.append(face_vertices.reshape(-1, 3))
num_faces = exposed_indices.shape[0]
face_indices = torch.arange(
vertex_count,
vertex_count + 4 * num_faces,
device=device
).reshape(-1, 4)
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
vertex_count += 4 * num_faces
if len(all_vertices) > 0:
vertices = torch.cat(all_vertices, dim=0)
faces = torch.cat(all_indices, dim=0)
else:
vertices = torch.zeros((1, 3))
faces = torch.zeros((1, 3))
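# Center the mesh on the origin and scale it into [-1, 1]; fliplr swaps the
# (z, y, x) index order of the vertex columns into (x, y, z) coordinates.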
v_min = 0
v_max = max(voxels.shape)
vertices = vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
vertices = vertices / scale
vertices = torch.fliplr(vertices)
return vertices, faces
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
D, H, W = voxels.shape
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
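# SurfaceNets-style extraction: every grid cell whose 8 corners straddle the
# threshold gets a single vertex (the mean of its edge/isosurface intersections);
# quads then connect the vertices of adjacent surface cells.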
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
corner_offsets = torch.tensor([
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)
has_outside = torch.any(~corner_signs, dim=1)
contains_surface = has_inside & has_outside
active_cells = cell_positions[contains_surface]
active_signs = corner_signs[contains_surface]
active_values = corner_values[contains_surface]
if active_cells.shape[0] == 0:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
edges = torch.tensor([
[0, 1], [0, 2], [0, 4], [1, 3],
[1, 5], [2, 3], [2, 6], [3, 7],
[4, 5], [4, 6], [5, 7], [6, 7]
], device=device)
cell_vertices = {}
progress = comfy.utils.ProgressBar(100)
for edge_idx, (e1, e2) in enumerate(edges):
progress.update(1)
crossing = active_signs[:, e1] != active_signs[:, e2]
if not crossing.any():
continue
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
v1 = active_values[cell_indices, e1]
v2 = active_values[cell_indices, e2]
t = torch.zeros_like(v1, device=device)
denom = v2 - v1
valid = denom != 0
t[valid] = (threshold - v1[valid]) / denom[valid]
t[~valid] = 0.5
p1 = corner_offsets[e1].float()
p2 = corner_offsets[e2].float()
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
for i, point in zip(cell_indices.tolist(), intersection):
if i not in cell_vertices:
cell_vertices[i] = []
cell_vertices[i].append(point)
# Calculate the final vertices as the average of intersection points for each cell
vertices = []
vertex_lookup = {}
vert_progress_mod = max(1, round(len(cell_vertices) / 50))  # guard against modulo-by-zero on tiny meshes
for i, points in cell_vertices.items():
if not i % vert_progress_mod:
progress.update(1)
if points:
vertex = torch.stack(points).mean(dim=0)
vertex = vertex + active_cells[i].float()
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
vertices.append(vertex)
if not vertices:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
final_vertices = torch.stack(vertices)
inside_corners_mask = active_signs
outside_corners_mask = ~active_signs
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
for i in range(8):
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
inside_pos /= inside_counts
outside_pos /= outside_counts
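# Approximate the local surface direction as the vector from the centroid of
# outside corners to the centroid of inside corners; its sign against each
# face-pair cross product decides the winding of the emitted quads.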
gradients = inside_pos - outside_pos
pos_dirs = torch.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
], device=device)
cross_products = [
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
for i in range(3) for j in range(i+1, 3)
]
faces = []
all_keys = set(vertex_lookup.keys())
face_progress_mod = max(1, round(len(active_cells) / 38 * 3))  # guard against modulo-by-zero on small grids
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
dir_i = pos_dirs[i]
dir_j = pos_dirs[j]
cross_product = cross_products[pair_idx]
ni_positions = active_cells + dir_i
nj_positions = active_cells + dir_j
diag_positions = active_cells + dir_i + dir_j
alignments = torch.matmul(gradients, cross_product)
valid_quads = []
quad_indices = []
for idx, active_cell in enumerate(active_cells):
if not idx % face_progress_mod:
progress.update(1)
cell_key = tuple(active_cell.tolist())
ni_key = tuple(ni_positions[idx].tolist())
nj_key = tuple(nj_positions[idx].tolist())
diag_key = tuple(diag_positions[idx].tolist())
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
v0 = vertex_lookup[cell_key]
v1 = vertex_lookup[ni_key]
v2 = vertex_lookup[nj_key]
v3 = vertex_lookup[diag_key]
valid_quads.append((v0, v1, v2, v3))
quad_indices.append(idx)
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
cell_idx = quad_indices[q_idx]
if alignments[cell_idx] > 0:
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
else:
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
if faces:
faces = torch.stack(faces)
else:
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
v_min = 0
v_max = max(D, H, W)
final_vertices = final_vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
final_vertices = final_vertices / scale
final_vertices = torch.fliplr(final_vertices)
return final_vertices, faces
def save_glb(vertices, faces, filepath, metadata=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_np = faces.cpu().numpy().astype(np.uint32)
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
buffer_data = vertices_buffer_padded + indices_buffer_padded
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [
{
"byteLength": len(buffer_data)
}
],
"bufferViews": [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
],
"accessors": [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
],
"meshes": [
{
"primitives": [
{
"attributes": {
"POSITION": 0
},
"indices": 1,
"mode": 4 # TRIANGLES
}
]
}
],
"nodes": [
{
"mesh": 0
}
],
"scenes": [
{
"nodes": [0]
}
],
"scene": 0
}
if metadata is not None:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header
# Magic glTF
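# Layout: 4-byte magic "glTF", uint32 version (2), then the total file length:
# 12-byte header + two 8-byte chunk headers + the padded JSON and BIN chunk bodies.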
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class EmptyLatentHunyuan3Dv2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyLatentHunyuan3Dv2_V3",
category="latent/3d",
inputs=[
io.Int.Input("resolution", default=3072, min=1, max=8192),
io.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.")
],
outputs=[
io.Latent.Output()
]
)
@classmethod
def execute(cls, resolution, batch_size):
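# The Hunyuan3D v2 latent is a 1D sequence of shape [batch, 64, resolution], not a spatial 2D map.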
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
class Hunyuan3Dv2Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Hunyuan3Dv2Conditioning_V3",
category="conditioning/video_models",
inputs=[
io.ClipVisionOutput.Input("clip_vision_output")
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative")
]
)
@classmethod
def execute(cls, clip_vision_output):
embeds = clip_vision_output.last_hidden_state
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return io.NodeOutput(positive, negative)
class Hunyuan3Dv2ConditioningMultiView(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView_V3",
category="conditioning/video_models",
inputs=[
io.ClipVisionOutput.Input("front", optional=True),
io.ClipVisionOutput.Input("left", optional=True),
io.ClipVisionOutput.Input("back", optional=True),
io.ClipVisionOutput.Input("right", optional=True)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative")
]
)
@classmethod
def execute(cls, front=None, left=None, back=None, right=None):
all_embeds = [front, left, back, right]
out = []
pos_embeds = None
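# Each supplied view gets the 1D sincos positional embedding of its slot
# (front=0, left=1, back=2, right=3) added before the views are concatenated
# along the sequence dimension; missing views are simply skipped.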
for i, e in enumerate(all_embeds):
if e is not None:
if pos_embeds is None:
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
embeds = torch.cat(out, dim=1)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return io.NodeOutput(positive, negative)
class SaveGLB(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveGLB_V3",
category="3d",
is_output_node=True,
inputs=[
io.Mesh.Input("mesh"),
io.String.Input("filename_prefix", default="mesh/ComfyUI")
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, mesh, filename_prefix):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return io.NodeOutput(ui={"ui": {"3d": results}}) # TODO: do we need an additional type of preview for this?
class VAEDecodeHunyuan3D(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAEDecodeHunyuan3D_V3",
category="latent/3d",
inputs=[
io.Latent.Input("samples"),
io.Vae.Input("vae"),
io.Int.Input("num_chunks", default=8000, min=1000, max=500000),
io.Int.Input("octree_resolution", default=256, min=16, max=512)
],
outputs=[
io.Voxel.Output()
]
)
@classmethod
def execute(cls, vae, samples, num_chunks, octree_resolution):
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return io.NodeOutput(voxels)
class VoxelToMesh(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VoxelToMesh_V3",
category="3d",
inputs=[
io.Voxel.Input("voxel"),
io.Combo.Input("algorithm", options=["surface net", "basic"]),
io.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01)
],
outputs=[
io.Mesh.Output()
]
)
@classmethod
def execute(cls, voxel, algorithm, threshold):
vertices = []
faces = []
if algorithm == "basic":
mesh_function = voxel_to_mesh
elif algorithm == "surface net":
mesh_function = voxel_to_mesh_surfnet
for x in voxel.data:
v, f = mesh_function(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return io.NodeOutput(MESH(torch.stack(vertices), torch.stack(faces)))
class VoxelToMeshBasic(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VoxelToMeshBasic_V3",
category="3d",
inputs=[
io.Voxel.Input("voxel"),
io.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01)
],
outputs=[
io.Mesh.Output()
]
)
@classmethod
def execute(cls, voxel, threshold):
vertices = []
faces = []
for x in voxel.data:
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return io.NodeOutput(MESH(torch.stack(vertices), torch.stack(faces)))
NODES_LIST = [
EmptyLatentHunyuan3Dv2,
Hunyuan3Dv2Conditioning,
Hunyuan3Dv2ConditioningMultiView,
SaveGLB,
VAEDecodeHunyuan3D,
VoxelToMesh,
VoxelToMeshBasic,
]

comfy_extras/v3/nodes_lt.py View File

@@ -127,12 +127,12 @@ class LTXVAddGuide(io.ComfyNode):
                 io.Vae.Input("vae"),
                 io.Latent.Input("latent"),
                 io.Image.Input(
-                    id="image",
+                    "image",
                     tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                             "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
                 ),
                 io.Int.Input(
-                    id="frame_idx",
+                    "frame_idx",
                     default=0,
                     min=-9999,
                     max=9999,

comfy_extras/v3/nodes_mahiro.py View File

@@ -0,0 +1,51 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
from comfy_api.latest import io
class Mahiro(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Mahiro_V3",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。) _V3",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
is_experimental=True,
inputs=[
io.Model.Input("model")
],
outputs=[
io.Model.Output(display_name="patched_model")
]
)
@classmethod
def execute(cls, model):
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
# naive leap
leap = cond_p * scale
# sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
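# Signed square roots compress magnitudes before comparing directions: the
# cosine similarity between the normalized uncond leap and the leap/CFG
# average sets the blend weight between the CFG result and the naive leap.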
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m)
NODES_LIST = [
Mahiro,
]

comfy_extras/v3/nodes_model_advanced.py View File

@@ -17,7 +17,7 @@ class LCM(comfy.model_sampling.EPS):
         x0 = model_input - model_output * sigma
         sigma_data = 0.5
-        scaled_timestep = timestep * 10.0 #timestep_scaling
+        scaled_timestep = timestep * 10.0 # timestep_scaling
         c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
         c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

comfy_extras/v3/nodes_model_merging.py View File

@@ -0,0 +1,422 @@
from __future__ import annotations
import json
import os
import torch
import comfy.model_base
import comfy.model_management
import comfy.model_sampling
import comfy.sd
import comfy.utils
import folder_paths
from comfy.cli_args import args
from comfy_api.latest import io
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
if isinstance(model.model, comfy.model_base.SDXL_instructpix2pix):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
else:
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
elif isinstance(model.model, comfy.model_base.SVD_img2vid):
metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
elif isinstance(model.model, comfy.model_base.SD3):
metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
extra_keys = {}
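# Record the prediction type: EDM v-prediction models store their sigma range
# as tensors, plain v-prediction gets an empty "v_pred" marker key, and zero
# terminal SNR additionally sets "ztsnr".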
model_sampling = model.get_model_object("model_sampling")
if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
extra_keys["v_pred"] = torch.tensor([])
if getattr(model_sampling, "zsnr", False):
extra_keys["ztsnr"] = torch.tensor([])
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CheckpointSave_V3",
display_name="Save Checkpoint _V3",
category="advanced/model_merging",
is_output_node=True,
inputs=[
io.Model.Input("model"),
io.Clip.Input("clip"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI")
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, model, clip, vae, filename_prefix):
save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
class CLIPAdd(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeAdd_V3",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2")
],
outputs=[
io.Clip.Output()
]
)
@classmethod
def execute(cls, clip1, clip2):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0, 1.0)
return io.NodeOutput(m)
class CLIPMergeSimple(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSimple_V3",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01)
],
outputs=[
io.Clip.Output()
]
)
@classmethod
def execute(cls, clip1, clip2, ratio):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return io.NodeOutput(m)
class CLIPSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPSave_V3",
category="advanced/model_merging",
is_output_node=True,
inputs=[
io.Clip.Input("clip"),
io.String.Input("filename_prefix", default="clip/ComfyUI")
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, clip, filename_prefix):
prompt_info = ""
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
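# Split the state dict by known text-encoder prefixes and save each component
# to its own .safetensors file with the prefix stripped; the trailing empty
# prefix catches whatever keys remain.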
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue
p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, folder_paths.get_output_directory())
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return io.NodeOutput()
class CLIPSubtract(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPMergeSubtract_V3",
category="advanced/model_merging",
inputs=[
io.Clip.Input("clip1"),
io.Clip.Input("clip2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
],
outputs=[
io.Clip.Output()
]
)
@classmethod
def execute(cls, clip1, clip2, multiplier):
m = clip1.clone()
kp = clip2.get_key_patches()
for k in kp:
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
continue
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return io.NodeOutput(m)
class ModelAdd(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelMergeAdd_V3",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2")
],
outputs=[
io.Model.Output()
]
)
@classmethod
def execute(cls, model1, model2):
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0, 1.0)
return io.NodeOutput(m)
class ModelMergeBlocks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelMergeBlocks_V3",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01)
],
outputs=[
io.Model.Output()
]
)
@classmethod
def execute(cls, model1, model2, **kwargs):
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
default_ratio = next(iter(kwargs.values()))
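# For each patch key, use the ratio of the longest input name that prefixes
# the key (so "output_blocks." wins over "out" when both match), falling back
# to the first argument's value when nothing matches.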
for k in kp:
ratio = default_ratio
k_unet = k[len("diffusion_model."):]
last_arg_size = 0
for arg in kwargs:
if k_unet.startswith(arg) and last_arg_size < len(arg):
ratio = kwargs[arg]
last_arg_size = len(arg)
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return io.NodeOutput(m)
class ModelMergeSimple(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSimple_V3",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01)
],
outputs=[
io.Model.Output()
]
)
@classmethod
def execute(cls, model1, model2, ratio):
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
return io.NodeOutput(m)
class ModelSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelSave_V3",
category="advanced/model_merging",
is_output_node=True,
inputs=[
io.Model.Input("model"),
io.String.Input("filename_prefix", default="diffusion_models/ComfyUI")
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, model, filename_prefix):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
return io.NodeOutput()
class ModelSubtract(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ModelMergeSubtract_V3",
category="advanced/model_merging",
inputs=[
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
],
outputs=[
io.Model.Output()
]
)
@classmethod
def execute(cls, model1, model2, multiplier):
m = model1.clone()
kp = model2.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - multiplier, multiplier)
return io.NodeOutput(m)
class VAESave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VAESave_V3",
category="advanced/model_merging",
is_output_node=True,
inputs=[
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="vae/ComfyUI_vae")
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, vae, filename_prefix):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
prompt_info = ""
if cls.hidden.prompt is not None:
prompt_info = json.dumps(cls.hidden.prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return io.NodeOutput()
NODES_LIST = [
CheckpointSave,
CLIPAdd,
CLIPMergeSimple,
CLIPSave,
CLIPSubtract,
ModelAdd,
ModelMergeBlocks,
ModelMergeSimple,
ModelSave,
ModelSubtract,
VAESave,
]

comfy_extras/v3/nodes_model_merging_model_specific.py View File

@@ -0,0 +1,399 @@
from __future__ import annotations
from comfy_api.latest import io
from comfy_extras.v3.nodes_model_merging import ModelMergeBlocks
class ModelMergeAuraflow(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("init_x_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("positional_encoding", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("cond_seq_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("register_tokens", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(4):
inputs.append(io.Float.Input(f"double_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(32):
inputs.append(io.Float.Input(f"single_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.extend([
io.Float.Input("modF.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("final_linear.", default=1.0, min=0.0, max=1.0, step=0.01)
])
return io.Schema(
node_id="ModelMergeAuraflow_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeCosmos14B(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(36):
inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeCosmos14B_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeCosmos7B(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(28):
inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeCosmos7B_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeCosmosPredict2_14B(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(36):
inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeCosmosPredict2_14B_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeCosmosPredict2_2B(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(28):
inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeCosmosPredict2_2B_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeFlux1(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("img_in.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("time_in.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("guidance_in", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("vector_in.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("txt_in.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(19):
inputs.append(io.Float.Input(f"double_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(38):
inputs.append(io.Float.Input(f"single_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeFlux1_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeLTXV(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("patchify_proj.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("adaln_single.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("caption_projection.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(28):
inputs.append(io.Float.Input(f"transformer_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.extend([
io.Float.Input("scale_shift_table", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("proj_out.", default=1.0, min=0.0, max=1.0, step=0.01)
])
return io.Schema(
node_id="ModelMergeLTXV_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeMochiPreview(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_frequencies.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t5_y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t5_yproj.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(48):
inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeMochiPreview_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeSD1(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("time_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("label_emb.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(12):
inputs.append(io.Float.Input(f"input_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(3):
inputs.append(io.Float.Input(f"middle_block.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(12):
inputs.append(io.Float.Input(f"output_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("out.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeSD1_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeSD3_2B(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(24):
inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeSD3_2B_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeSD35_Large(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(38):
inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeSD35_Large_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeSDXL(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("time_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("label_emb.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(9):
inputs.append(io.Float.Input(f"input_blocks.{i}", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(3):
inputs.append(io.Float.Input(f"middle_block.{i}", default=1.0, min=0.0, max=1.0, step=0.01))
for i in range(9):
inputs.append(io.Float.Input(f"output_blocks.{i}", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("out.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeSDXL_V3",
category="advanced/model_merging/model_specific",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
class ModelMergeWAN2_1(ModelMergeBlocks):
@classmethod
def define_schema(cls):
inputs = [
io.Model.Input("model1"),
io.Model.Input("model2"),
io.Float.Input("patch_embedding.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("time_embedding.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("time_projection.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("text_embedding.", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("img_emb.", default=1.0, min=0.0, max=1.0, step=0.01)
]
for i in range(40):
inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
inputs.append(io.Float.Input("head.", default=1.0, min=0.0, max=1.0, step=0.01))
return io.Schema(
node_id="ModelMergeWAN2_1_V3",
category="advanced/model_merging/model_specific",
description="1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb.",
inputs=inputs,
outputs=[
io.Model.Output(),
]
)
NODES_LIST = [
ModelMergeAuraflow,
ModelMergeCosmos14B,
ModelMergeCosmos7B,
ModelMergeCosmosPredict2_14B,
ModelMergeCosmosPredict2_2B,
ModelMergeFlux1,
ModelMergeLTXV,
ModelMergeMochiPreview,
ModelMergeSD1,
ModelMergeSD3_2B,
ModelMergeSD35_Large,
ModelMergeSDXL,
ModelMergeWAN2_1,
]

comfy_extras/v3/nodes_stable3d.py View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import torch
import comfy.utils
import nodes
from comfy_api.latest import io
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
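# Four-dimensional Zero123 camera embedding: the polar angle (90 - elevation,
# shifted by -90 degrees and converted to radians), sin and cos of the
# azimuth, and a constant deg2rad(90) radius placeholder.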
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - 90
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableZero123_Conditioning_V3",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableZero123_Conditioning_Batched_V3",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = []
for i in range(batch_size):
cam_embeds.append(camera_embeddings(elevation, azimuth))
elevation += elevation_batch_increment
azimuth += azimuth_batch_increment
cam_embeds = torch.cat(cam_embeds, dim=0)
cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SV3D_Conditioning_V3",
category="conditioning/3d_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=21, min=1, max=4096),
io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
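# Build one camera pose per frame: constant elevation, azimuth stepping in
# equal increments from 0 through a full 360-degree turn.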
azimuth = 0
azimuth_increment = 360 / (max(video_frames, 2) - 1)
elevations = []
azimuths = []
for i in range(video_frames):
elevations.append(elevation)
azimuths.append(azimuth)
azimuth += azimuth_increment
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent})
NODES_LIST = [
StableZero123_Conditioning,
StableZero123_Conditioning_Batched,
SV3D_Conditioning,
]

comfy_extras/v3/nodes_string.py View File

@@ -0,0 +1,380 @@
from __future__ import annotations
import re
from comfy_api.latest import io
class CaseConverter(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CaseConverter_V3",
display_name="Case Converter _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"])
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, mode):
if mode == "UPPERCASE":
result = string.upper()
elif mode == "lowercase":
result = string.lower()
elif mode == "Capitalize":
result = string.capitalize()
elif mode == "Title Case":
result = string.title()
else:
result = string
return io.NodeOutput(result)
class RegexExtract(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RegexExtract_V3",
display_name="Regex Extract _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False),
io.Int.Input("group_index", default=1, min=0, max=100)
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index):
join_delimiter = "\n"
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
if mode == "First Match":
match = re.search(regex_pattern, string, flags)
if match:
result = match.group(0)
else:
result = ""
elif mode == "All Matches":
matches = re.findall(regex_pattern, string, flags)
if matches:
if isinstance(matches[0], tuple):
result = join_delimiter.join([m[0] for m in matches])
else:
result = join_delimiter.join(matches)
else:
result = ""
elif mode == "First Group":
match = re.search(regex_pattern, string, flags)
if match and len(match.groups()) >= group_index:
result = match.group(group_index)
else:
result = ""
elif mode == "All Groups":
matches = re.finditer(regex_pattern, string, flags)
results = []
for match in matches:
if match.groups() and len(match.groups()) >= group_index:
results.append(match.group(group_index))
result = join_delimiter.join(results)
else:
result = ""
except re.error:
result = ""
return io.NodeOutput(result)
class RegexMatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RegexMatch_V3",
display_name="Regex Match _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False)
],
outputs=[
io.Boolean.Output(display_name="matches")
]
)
@classmethod
def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
match = re.search(regex_pattern, string, flags)
result = match is not None
except re.error:
result = False
return io.NodeOutput(result)
class RegexReplace(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RegexReplace_V3",
display_name="Regex Replace _V3",
category="utils/string",
description="Find and replace text using regex patterns.",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.String.Input("replace", multiline=True),
io.Boolean.Input("case_insensitive", default=True, optional=True),
io.Boolean.Input("multiline", default=False, optional=True),
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc.")
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return io.NodeOutput(result)
class StringCompare(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringCompare_V3",
display_name="Compare _V3",
category="utils/string",
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
io.Boolean.Input("case_sensitive", default=True)
],
outputs=[
io.Boolean.Output()
]
)
@classmethod
def execute(cls, string_a, string_b, mode, case_sensitive):
if case_sensitive:
a = string_a
b = string_b
else:
a = string_a.lower()
b = string_b.lower()
if mode == "Equal":
return io.NodeOutput(a == b)
elif mode == "Starts With":
return io.NodeOutput(a.startswith(b))
elif mode == "Ends With":
return io.NodeOutput(a.endswith(b))
class StringConcatenate(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringConcatenate_V3",
display_name="Concatenate _V3",
category="utils/string",
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
io.String.Input("delimiter", multiline=False, default="")
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string_a, string_b, delimiter):
return io.NodeOutput(delimiter.join((string_a, string_b)))
class StringContains(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringContains_V3",
display_name="Contains _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("substring", multiline=True),
io.Boolean.Input("case_sensitive", default=True)
],
outputs=[
io.Boolean.Output(display_name="contains")
]
)
@classmethod
def execute(cls, string, substring, case_sensitive):
if case_sensitive:
contains = substring in string
else:
contains = substring.lower() in string.lower()
return io.NodeOutput(contains)
class StringLength(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringLength_V3",
display_name="Length _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True)
],
outputs=[
io.Int.Output(display_name="length")
]
)
@classmethod
def execute(cls, string):
return io.NodeOutput(len(string))
class StringReplace(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringReplace_V3",
display_name="Replace _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("find", multiline=True),
io.String.Input("replace", multiline=True)
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, find, replace):
return io.NodeOutput(string.replace(find, replace))
class StringSubstring(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringSubstring_V3",
display_name="Substring _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Int.Input("start"),
io.Int.Input("end")
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, start, end):
return io.NodeOutput(string[start:end])
class StringTrim(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringTrim_V3",
display_name="Trim _V3",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["Both", "Left", "Right"])
],
outputs=[
io.String.Output()
]
)
@classmethod
def execute(cls, string, mode):
if mode == "Both":
result = string.strip()
elif mode == "Left":
result = string.lstrip()
elif mode == "Right":
result = string.rstrip()
else:
result = string
return io.NodeOutput(result)
NODES_LIST = [
CaseConverter,
RegexExtract,
RegexMatch,
RegexReplace,
StringCompare,
StringConcatenate,
StringContains,
StringLength,
StringReplace,
StringSubstring,
StringTrim,
]

nodes.py View File

@@ -2316,6 +2316,7 @@ async def init_builtin_extra_nodes():
         "v3/nodes_cond.py",
         "v3/nodes_controlnet.py",
         "v3/nodes_cosmos.py",
+        # "v3/nodes_custom_sampler.py",
         "v3/nodes_differential_diffusion.py",
         "v3/nodes_edit_model.py",
         "v3/nodes_flux.py",
@@ -2323,7 +2324,9 @@ async def init_builtin_extra_nodes():
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
         "v3/nodes_hidream.py",
+        # "v3/nodes_hooks.py",
         "v3/nodes_hunyuan.py",
+        "v3/nodes_hunyuan3d.py",
         "v3/nodes_hypernetwork.py",
         "v3/nodes_hypertile.py",
         "v3/nodes_images.py",
@@ -2334,10 +2337,13 @@ async def init_builtin_extra_nodes():
         "v3/nodes_lotus.py",
         "v3/nodes_lt.py",
         "v3/nodes_lumina2.py",
+        "v3/nodes_mahiro.py",
         "v3/nodes_mask.py",
         "v3/nodes_mochi.py",
         "v3/nodes_model_advanced.py",
         "v3/nodes_model_downscale.py",
+        "v3/nodes_model_merging.py",
+        "v3/nodes_model_merging_model_specific.py",
         "v3/nodes_morphology.py",
         "v3/nodes_optimalsteps.py",
         "v3/nodes_pag.py",
@@ -2352,7 +2358,9 @@ async def init_builtin_extra_nodes():
         "v3/nodes_sd3.py",
         "v3/nodes_sdupscale.py",
         "v3/nodes_slg.py",
+        "v3/nodes_stable3d.py",
         "v3/nodes_stable_cascade.py",
+        "v3/nodes_string.py",
         "v3/nodes_tcfg.py",
         "v3/nodes_tomesd.py",
         "v3/nodes_torch_compile.py",