put V1 nodes back

This commit is contained in:
bigcat88 2025-07-10 07:48:45 +03:00
parent 965d2f9b8f
commit d8b91bb84e
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
6 changed files with 268 additions and 25 deletions

View File

@ -3,7 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
import folder_paths
import random
import nodes
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
@ -362,6 +365,30 @@ class ThresholdMask:
mask = (mask > value).float()
return (mask,)
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {
"required": {"mask": ("MASK",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
FUNCTION = "execute"
CATEGORY = "mask"
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked,
@ -376,8 +403,10 @@ NODE_CLASS_MAPPINGS = {
"FeatherMask": FeatherMask,
"GrowMask": GrowMask,
"ThresholdMask": ThresholdMask,
"MaskPreview": MaskPreview
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageToMask": "Convert Image to Mask",
"MaskToImage": "Convert Mask to Image",
}

View File

@ -0,0 +1,37 @@
import nodes
import folder_paths
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(nodes.LoadImage):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("WEBCAM", {}),
"width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"capture_on_queue": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_capture"
CATEGORY = "image"
def load_capture(self, image, **kwargs):
return super().load_image(folder_paths.get_annotated_filepath(image))
@classmethod
def IS_CHANGED(cls, image, width, height, capture_on_queue):
return super().IS_CHANGED(image)
NODE_CLASS_MAPPINGS = {
"WebcamCapture": WebcamCapture,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WebcamCapture": "Webcam Capture",
}

View File

@ -13,12 +13,12 @@ import folder_paths
import node_helpers
class SaveImage(io.ComfyNodeV3):
class SaveImage_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveImage",
display_name="Save Image",
node_id="SaveImage_V3",
display_name="Save Image _V3",
description="Saves the input images to your ComfyUI output directory.",
category="image",
inputs=[
@ -68,12 +68,12 @@ class SaveImage(io.ComfyNodeV3):
return io.NodeOutput(ui={"images": results})
class PreviewImage(io.ComfyNodeV3):
class PreviewImage_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PreviewImage",
display_name="Preview Image",
node_id="PreviewImage_V3",
display_name="Preview Image _V3",
description="Preview the input images.",
category="image",
inputs=[
@ -92,12 +92,12 @@ class PreviewImage(io.ComfyNodeV3):
return io.NodeOutput(ui=ui.PreviewImage(images))
class LoadImage(io.ComfyNodeV3):
class LoadImage_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadImage",
display_name="Load Image",
node_id="LoadImage_V3",
display_name="Load Image _V3",
category="image",
inputs=[
io.Combo.Input(
@ -186,12 +186,12 @@ class LoadImage(io.ComfyNodeV3):
return True
class LoadImageOutput(io.ComfyNodeV3):
class LoadImageOutput_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadImageOutput",
display_name="Load Image (from Outputs)",
node_id="LoadImageOutput_V3",
display_name="Load Image (from Outputs) _V3",
description="Load an image from the output folder. "
"When the refresh button is clicked, the node will update the image list "
"and automatically select the first image, allowing for easy iteration.",
@ -283,8 +283,8 @@ class LoadImageOutput(io.ComfyNodeV3):
NODES_LIST: list[type[io.ComfyNodeV3]] = [
SaveImage,
PreviewImage,
LoadImage,
LoadImageOutput,
SaveImage_V3,
PreviewImage_V3,
LoadImage_V3,
LoadImageOutput_V3,
]

View File

@ -1,7 +1,7 @@
from comfy_api.v3 import io, ui
class MaskPreview(io.ComfyNodeV3):
class MaskPreview_V3(io.ComfyNodeV3):
"""Mask Preview - original implement in ComfyUI_essentials.
https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
@ -11,8 +11,8 @@ class MaskPreview(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="MaskPreview",
display_name="Convert Mask to Image",
node_id="MaskPreview_V3",
display_name="Convert Mask to Image _V3",
category="mask",
inputs=[
io.Mask.Input(
@ -29,4 +29,4 @@ class MaskPreview(io.ComfyNodeV3):
return io.NodeOutput(ui=ui.PreviewMask(masks))
NODES_LIST: list[type[io.ComfyNodeV3]] = [MaskPreview]
NODES_LIST: list[type[io.ComfyNodeV3]] = [MaskPreview_V3]

View File

@ -13,12 +13,12 @@ import node_helpers
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(io.ComfyNodeV3):
class WebcamCapture_V3(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="WebcamCapture",
display_name="Webcam Capture",
node_id="WebcamCapture_V3",
display_name="Webcam Capture _V3",
category="image",
inputs=[
io.Webcam.Input(
@ -114,4 +114,4 @@ class WebcamCapture(io.ComfyNodeV3):
return True
NODES_LIST: list[type[io.ComfyNodeV3]] = [WebcamCapture]
NODES_LIST: list[type[io.ComfyNodeV3]] = [WebcamCapture_V3]

179
nodes.py
View File

@ -8,9 +8,11 @@ import hashlib
import traceback
import math
import time
import random
import logging
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
@ -1548,6 +1550,150 @@ class KSamplerAdvanced:
disable_noise = True
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
class SaveImage:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE", {"tooltip": "The images to save."}),
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "images": results } }
class PreviewImage(SaveImage):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
class LoadImage:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return {"required":
{"image": (sorted(files), {"image_upload": True})},
}
CATEGORY = "image"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageMask:
_color_channels = ["alpha", "red", "green", "blue"]
@ -1599,6 +1745,28 @@ class LoadImageMask:
return True
class LoadImageOutput(LoadImage):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("COMBO", {
"image_upload": True,
"image_folder": "output",
"remote": {
"route": "/internal/files/output",
"refresh_button": True,
"control_after_refresh": "first",
},
}),
}
}
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True
FUNCTION = "load_image"
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@ -1781,7 +1949,11 @@ NODE_CLASS_MAPPINGS = {
"LatentUpscaleBy": LatentUpscaleBy,
"LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"LoadImageOutput": LoadImageOutput,
"ImageScale": ImageScale,
"ImageScaleBy": ImageScaleBy,
"ImageInvert": ImageInvert,
@ -1878,7 +2050,11 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)",
"LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
@ -2088,6 +2264,7 @@ def init_builtin_extra_nodes():
"nodes_align_your_steps.py",
"nodes_attention_multiply.py",
"nodes_advanced_samplers.py",
"nodes_webcam.py",
"nodes_audio.py",
"nodes_sd3.py",
"nodes_gits.py",