V3: primitive nodes; additional ruff rules for V3 nodes

This commit is contained in:
bigcat88 2025-07-15 17:40:15 +03:00
parent f687f8af7c
commit c196dd5d0f
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
8 changed files with 242 additions and 96 deletions

View File

@ -65,14 +65,14 @@ class EmptyLatentAudio_V3(io.ComfyNodeV3):
def execute(cls, seconds, batch_size) -> io.NodeOutput: def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2 length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent, "type": "audio"}) return io.NodeOutput({"samples": latent, "type": "audio"})
class LoadAudio_V3(io.ComfyNodeV3): class LoadAudio_V3(io.ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work
display_name="Load Audio _V3", # frontend ignores "display_name" for this node display_name="Load Audio _V3", # frontend ignores "display_name" for this node
category="audio", category="audio",
inputs=[ inputs=[
@ -110,7 +110,7 @@ class PreviewAudio_V3(io.ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
node_id="PreviewAudio_V3", # frontend expects "PreviewAudio" to work node_id="PreviewAudio_V3", # frontend expects "PreviewAudio" to work
display_name="Preview Audio _V3", # frontend ignores "display_name" for this node display_name="Preview Audio _V3", # frontend ignores "display_name" for this node
category="audio", category="audio",
inputs=[ inputs=[
@ -129,7 +129,7 @@ class SaveAudioMP3_V3(io.ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
node_id="SaveAudioMP3_V3", # frontend expects "SaveAudioMP3" to work node_id="SaveAudioMP3_V3", # frontend expects "SaveAudioMP3" to work
display_name="Save Audio(MP3) _V3", # frontend ignores "display_name" for this node display_name="Save Audio(MP3) _V3", # frontend ignores "display_name" for this node
category="audio", category="audio",
inputs=[ inputs=[
@ -150,7 +150,7 @@ class SaveAudioOpus_V3(io.ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
node_id="SaveAudioOpus_V3", # frontend expects "SaveAudioOpus" to work node_id="SaveAudioOpus_V3", # frontend expects "SaveAudioOpus" to work
display_name="Save Audio(Opus) _V3", # frontend ignores "display_name" for this node display_name="Save Audio(Opus) _V3", # frontend ignores "display_name" for this node
category="audio", category="audio",
inputs=[ inputs=[
@ -171,7 +171,7 @@ class SaveAudio_V3(io.ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work
display_name="Save Audio _V3", # frontend ignores "display_name" for this node display_name="Save Audio _V3", # frontend ignores "display_name" for this node
category="audio", category="audio",
inputs=[ inputs=[
@ -203,7 +203,7 @@ class VAEDecodeAudio_V3(io.ComfyNodeV3):
@classmethod @classmethod
def execute(cls, vae, samples) -> io.NodeOutput: def execute(cls, vae, samples) -> io.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1) audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0 std[std < 1.0] = 1.0
audio /= std audio /= std
return io.NodeOutput({"waveform": audio, "sample_rate": 44100}) return io.NodeOutput({"waveform": audio, "sample_rate": 44100})
@ -250,7 +250,7 @@ def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="1
OPUS_RATES = [8000, 12000, 16000, 24000, 48000] OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
results = [] results = []
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}" file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file) output_path = os.path.join(full_output_folder, file)
@ -277,7 +277,7 @@ def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="1
# Create output with specified format # Create output with specified format
output_buffer = BytesIO() output_buffer = BytesIO()
output_container = av.open(output_buffer, mode='w', format=format) output_container = av.open(output_buffer, mode="w", format=format)
# Set metadata on the container # Set metadata on the container
for key, value in metadata.items(): for key, value in metadata.items():
@ -299,19 +299,19 @@ def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="1
elif format == "mp3": elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0": if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1 out_stream.codec_context.qscale = 1
elif quality == "128k": elif quality == "128k":
out_stream.bit_rate = 128000 out_stream.bit_rate = 128000
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
else: # format == "flac": else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate) out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray( frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(), waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format='flt', format="flt",
layout='mono' if waveform.shape[0] == 1 else 'stereo', layout="mono" if waveform.shape[0] == 1 else "stereo",
) )
frame.sample_rate = sample_rate frame.sample_rate = sample_rate
frame.pts = 0 frame.pts = 0
@ -325,7 +325,7 @@ def _save_audio(cls, audio, filename_prefix="ComfyUI", format="flac", quality="1
# Write the output to file # Write the output to file
output_buffer.seek(0) output_buffer.seek(0)
with open(output_path, 'wb') as f: with open(output_path, "wb") as f:
f.write(output_buffer.getbuffer()) f.write(output_buffer.getbuffer())
results.append(ui.SavedResult(file, subfolder, io.FolderType.output)) results.append(ui.SavedResult(file, subfolder, io.FolderType.output))

View File

@ -1,5 +1,5 @@
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
import comfy.utils import comfy.utils
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
from comfy_api.v3 import io from comfy_api.v3 import io
@ -27,11 +27,13 @@ class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
) )
@classmethod @classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]) -> io.NodeOutput: def execute(
cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]
) -> io.NodeOutput:
if strength == 0: if strength == 0:
return io.NodeOutput(positive, negative) return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1,1) control_hint = image.movedim(-1, 1)
cnets = {} cnets = {}
out = [] out = []
@ -40,16 +42,18 @@ class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
for t in conditioning: for t in conditioning:
d = t[1].copy() d = t[1].copy()
prev_cnet = d.get('control', None) prev_cnet = d.get("control", None)
if prev_cnet in cnets: if prev_cnet in cnets:
c_net = cnets[prev_cnet] c_net = cnets[prev_cnet]
else: else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat) c_net = control_net.copy().set_cond_hint(
control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat
)
c_net.set_previous_controlnet(prev_cnet) c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net cnets[prev_cnet] = c_net
d['control'] = c_net d["control"] = c_net
d['control_apply_to_uncond'] = False d["control_apply_to_uncond"] = False
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
out.append(c) out.append(c)
@ -107,7 +111,9 @@ class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
) )
@classmethod @classmethod
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput: def execute(
cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent
) -> io.NodeOutput:
extra_concat = [] extra_concat = []
if control_net.concat_mask: if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
@ -115,7 +121,17 @@ class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3]) image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask] extra_concat = [mask]
return super().execute(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) return super().execute(
positive,
negative,
control_net,
image,
strength,
start_percent,
end_percent,
vae=vae,
extra_concat=extra_concat,
)
NODES_LIST: list[type[io.ComfyNodeV3]] = [ NODES_LIST: list[type[io.ComfyNodeV3]] = [

View File

@ -1,16 +1,16 @@
import hashlib
import json import json
import os import os
import torch
import hashlib
import numpy as np import numpy as np
import torch
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from comfy_api.v3 import io, ui
from comfy.cli_args import args
import folder_paths import folder_paths
import node_helpers import node_helpers
from comfy.cli_args import args
from comfy_api.v3 import io, ui
class SaveImage_V3(io.ComfyNodeV3): class SaveImage_V3(io.ComfyNodeV3):
@ -29,7 +29,8 @@ class SaveImage_V3(io.ComfyNodeV3):
io.String.Input( io.String.Input(
"filename_prefix", "filename_prefix",
default="ComfyUI", default="ComfyUI",
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", tooltip="The prefix for the file to save. This may include formatting information "
"such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
), ),
], ],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
@ -42,8 +43,8 @@ class SaveImage_V3(io.ComfyNodeV3):
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
) )
results = [] results = []
for (batch_number, image) in enumerate(images): for batch_number, image in enumerate(images):
i = 255. * image.cpu().numpy() i = 255.0 * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None metadata = None
if not args.disable_metadata: if not args.disable_metadata:
@ -82,13 +83,13 @@ class SaveAnimatedPNG_V3(io.ComfyNodeV3):
@classmethod @classmethod
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> io.NodeOutput: def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = ( full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]) filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
) )
results = [] results = []
pil_images = [] pil_images = []
for image in images: for image in images:
img = Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8))
pil_images.append(img) pil_images.append(img)
metadata = None metadata = None
@ -96,19 +97,34 @@ class SaveAnimatedPNG_V3(io.ComfyNodeV3):
metadata = PngInfo() metadata = PngInfo()
if cls.hidden.prompt is not None: if cls.hidden.prompt is not None:
metadata.add( metadata.add(
b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(cls.hidden.prompt).encode("latin-1", "strict"), after_idat=True b"comf",
"prompt".encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
after_idat=True,
) )
if cls.hidden.extra_pnginfo is not None: if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo: for x in cls.hidden.extra_pnginfo:
metadata.add( metadata.add(
b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True b"comf",
x.encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
after_idat=True,
) )
file = f"{filename}_{counter:05}_.png" file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) pil_images[0].save(
os.path.join(full_output_folder, file),
pnginfo=metadata,
compress_level=compress_level,
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
)
results.append(ui.SavedResult(file, subfolder, io.FolderType.output)) results.append(ui.SavedResult(file, subfolder, io.FolderType.output))
return io.NodeOutput(ui={"images": results, "animated": (True,) }) return io.NodeOutput(ui={"images": results, "animated": (True,)})
class SaveAnimatedWEBP_V3(io.ComfyNodeV3): class SaveAnimatedWEBP_V3(io.ComfyNodeV3):
@ -136,11 +152,13 @@ class SaveAnimatedWEBP_V3(io.ComfyNodeV3):
@classmethod @classmethod
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> io.NodeOutput: def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> io.NodeOutput:
method = cls.COMPRESS_METHODS.get(method) method = cls.COMPRESS_METHODS.get(method)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
results = [] results = []
pil_images = [] pil_images = []
for image in images: for image in images:
img = Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8))
pil_images.append(img) pil_images.append(img)
metadata = pil_images[0].getexif() metadata = pil_images[0].getexif()
@ -148,7 +166,7 @@ class SaveAnimatedWEBP_V3(io.ComfyNodeV3):
if cls.hidden.prompt is not None: if cls.hidden.prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) metadata[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo is not None: if cls.hidden.extra_pnginfo is not None:
inital_exif = 0x010f inital_exif = 0x010F
for x in cls.hidden.extra_pnginfo: for x in cls.hidden.extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(cls.hidden.extra_pnginfo[x])) metadata[inital_exif] = "{}:{}".format(x, json.dumps(cls.hidden.extra_pnginfo[x]))
inital_exif -= 1 inital_exif -= 1
@ -160,8 +178,9 @@ class SaveAnimatedWEBP_V3(io.ComfyNodeV3):
file = f"{filename}_{counter:05}_.webp" file = f"{filename}_{counter:05}_.webp"
pil_images[i].save( pil_images[i].save(
os.path.join(full_output_folder, file), os.path.join(full_output_folder, file),
save_all=True, duration=int(1000.0/fps), save_all=True,
append_images=pil_images[i + 1:i + num_frames], duration=int(1000.0 / fps),
append_images=pil_images[i + 1 : i + num_frames],
exif=metadata, exif=metadata,
lossless=lossless, lossless=lossless,
quality=quality, quality=quality,
@ -228,12 +247,12 @@ class LoadImage_V3(io.ComfyNodeV3):
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
excluded_formats = ['MPO'] excluded_formats = ["MPO"]
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I': if i.mode == "I":
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
@ -246,14 +265,14 @@ class LoadImage_V3(io.ComfyNodeV3):
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if "A" in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info: elif i.mode == "P" and "transparency" in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.convert("RGBA").getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image)
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0))
@ -270,7 +289,7 @@ class LoadImage_V3(io.ComfyNodeV3):
def fingerprint_inputs(s, image): def fingerprint_inputs(s, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, "rb") as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()
@ -288,8 +307,8 @@ class LoadImageOutput_V3(io.ComfyNodeV3):
node_id="LoadImageOutput_V3", node_id="LoadImageOutput_V3",
display_name="Load Image (from Outputs) _V3", display_name="Load Image (from Outputs) _V3",
description="Load an image from the output folder. " description="Load an image from the output folder. "
"When the refresh button is clicked, the node will update the image list " "When the refresh button is clicked, the node will update the image list "
"and automatically select the first image, allowing for easy iteration.", "and automatically select the first image, allowing for easy iteration.",
category="image", category="image",
inputs=[ inputs=[
io.Combo.Input( io.Combo.Input(
@ -317,12 +336,12 @@ class LoadImageOutput_V3(io.ComfyNodeV3):
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
excluded_formats = ['MPO'] excluded_formats = ["MPO"]
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I': if i.mode == "I":
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
@ -335,12 +354,12 @@ class LoadImageOutput_V3(io.ComfyNodeV3):
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if "A" in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info: elif i.mode == "P" and "transparency" in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.convert("RGBA").getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image)
@ -359,7 +378,7 @@ class LoadImageOutput_V3(io.ComfyNodeV3):
def fingerprint_inputs(s, image): def fingerprint_inputs(s, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, "rb") as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()

View File

@ -0,0 +1,104 @@
from __future__ import annotations
import sys
from comfy_api.v3 import io
class String_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PrimitiveString_V3",
display_name="String _V3",
category="utils/primitive",
inputs=[
io.String.Input("value"),
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class StringMultiline_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PrimitiveStringMultiline_V3",
display_name="String (Multiline) _V3",
category="utils/primitive",
inputs=[
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Int_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PrimitiveInt_V3",
display_name="Int _V3",
category="utils/primitive",
inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
],
outputs=[io.Int.Output()],
)
@classmethod
def execute(cls, value: int) -> io.NodeOutput:
return io.NodeOutput(value)
class Float_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PrimitiveFloat_V3",
display_name="Float _V3",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
],
outputs=[io.Float.Output()],
)
@classmethod
def execute(cls, value: float) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PrimitiveBoolean_V3",
display_name="Boolean _V3",
category="utils/primitive",
inputs=[
io.Boolean.Input("value"),
],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, value: bool) -> io.NodeOutput:
return io.NodeOutput(value)
NODES_LIST: list[type[io.ComfyNodeV3]] = [
String_V3,
StringMultiline_V3,
Int_V3,
Float_V3,
Boolean_V3,
]

View File

@ -1,25 +1,25 @@
""" """
This file is part of ComfyUI. This file is part of ComfyUI.
Copyright (C) 2024 Stability AI Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
import torch import torch
import nodes
import comfy.utils
import comfy.utils
import nodes
from comfy_api.v3 import io from comfy_api.v3 import io
@ -30,7 +30,7 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3):
node_id="StableCascade_EmptyLatentImage_V3", node_id="StableCascade_EmptyLatentImage_V3",
category="latent/stable_cascade", category="latent/stable_cascade",
inputs=[ inputs=[
io.Int.Input("width", default=1024,min=256,max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1), io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096), io.Int.Input("batch_size", default=1, min=1, max=4096),
@ -72,9 +72,9 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
out_width = (width // compression) * vae.downscale_ratio out_width = (width // compression) * vae.downscale_ratio
out_height = (height // compression) * vae.downscale_ratio out_height = (height // compression) * vae.downscale_ratio
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) s = comfy.utils.common_upscale(image.movedim(-1, 1), out_width, out_height, "bicubic", "center").movedim(1, -1)
c_latent = vae.encode(s[:,:,:,:3]) c_latent = vae.encode(s[:, :, :, :3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return io.NodeOutput({"samples": c_latent}, {"samples": b_latent}) return io.NodeOutput({"samples": c_latent}, {"samples": b_latent})
@ -90,7 +90,7 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
io.Latent.Input("stage_c"), io.Latent.Input("stage_c"),
], ],
outputs=[ outputs=[
io.Conditioning.Output(), io.Conditioning.Output(),
], ],
) )
@ -99,7 +99,7 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
c = [] c = []
for t in conditioning: for t in conditioning:
d = t[1].copy() d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples'] d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
return io.NodeOutput(c) return io.NodeOutput(c)
@ -128,7 +128,7 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
width = image.shape[-2] width = image.shape[-2]
height = image.shape[-3] height = image.shape[-3]
batch_size = image.shape[0] batch_size = image.shape[0]
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) controlnet_input = vae.encode(image[:, :, :, :3]).movedim(1, -1)
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])

View File

@ -1,14 +1,13 @@
import hashlib import hashlib
import torch
import numpy as np import numpy as np
import torch
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from comfy_api.v3 import io
import nodes
import folder_paths import folder_paths
import node_helpers import node_helpers
import nodes
from comfy_api.v3 import io
MAX_RESOLUTION = nodes.MAX_RESOLUTION MAX_RESOLUTION = nodes.MAX_RESOLUTION
@ -51,12 +50,12 @@ class WebcamCapture_V3(io.ComfyNodeV3):
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
excluded_formats = ['MPO'] excluded_formats = ["MPO"]
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I': if i.mode == "I":
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
@ -69,12 +68,12 @@ class WebcamCapture_V3(io.ComfyNodeV3):
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if "A" in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info: elif i.mode == "P" and "transparency" in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.convert("RGBA").getchannel("A")).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1.0 - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image)
@ -93,7 +92,7 @@ class WebcamCapture_V3(io.ComfyNodeV3):
def fingerprint_inputs(s, image, width, height, capture_on_queue): def fingerprint_inputs(s, image, width, height, capture_on_queue):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, "rb") as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()

View File

@ -2303,6 +2303,7 @@ def init_builtin_extra_nodes():
"v3/nodes_controlnet.py", "v3/nodes_controlnet.py",
"v3/nodes_images.py", "v3/nodes_images.py",
"v3/nodes_mask.py", "v3/nodes_mask.py",
"v3/nodes_primitive.py",
"v3/nodes_webcam.py", "v3/nodes_webcam.py",
"v3/nodes_stable_cascade.py", "v3/nodes_stable_cascade.py",
] ]

View File

@ -12,6 +12,8 @@ documentation = "https://docs.comfy.org/"
[tool.ruff] [tool.ruff]
lint.select = [ lint.select = [
"E", # pycodestyle errors
"I", # isort
"N805", # invalid-first-argument-name-for-method "N805", # invalid-first-argument-name-for-method
"S307", # suspicious-eval-usage "S307", # suspicious-eval-usage
"S102", # exec "S102", # exec
@ -22,3 +24,8 @@ lint.select = [
"F", "F",
] ]
exclude = ["*.ipynb"] exclude = ["*.ipynb"]
line-length = 120
lint.pycodestyle.ignore-overlong-task-comments = true
[tool.ruff.lint.per-file-ignores]
"!comfy_extras/v3/*" = ["E", "I"] # enable these rules only for V3 nodes