fixes, corrections; ported MaskPreview, WebcamCapture and LoadImageOutput nodes

This commit is contained in:
bigcat88 2025-07-09 11:09:19 +03:00
parent 1eb1a44883
commit fefb24cc33
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
8 changed files with 469 additions and 291 deletions

View File

@ -209,8 +209,6 @@ class WidgetInputV3(InputV3):
}) })
def get_io_type_V1(self): def get_io_type_V1(self):
if isinstance(self, Combo.Input):
return self.as_value_type_v1()
return self.widgetType if self.widgetType is not None else super().get_io_type_V1() return self.widgetType if self.widgetType is not None else super().get_io_type_V1()
@ -411,18 +409,7 @@ class Combo(ComfyType):
self.remote = remote self.remote = remote
self.default: str self.default: str
def as_dict_V1(self): def get_io_type_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
"image_upload": self.image_upload,
"image_folder": self.image_folder.value if self.image_folder else None,
"content_types": self.content_types if self.content_types else None,
"remote": self.remote.as_dict() if self.remote else None,
})
def as_value_type_v1(self):
if getattr(self, "image_folder"): if getattr(self, "image_folder"):
if self.image_folder == FolderType.input: if self.image_folder == FolderType.input:
target_dir = folder_paths.get_input_directory() target_dir = folder_paths.get_input_directory()
@ -434,6 +421,18 @@ class Combo(ComfyType):
if self.content_types is None: if self.content_types is None:
return files return files
return sorted(folder_paths.filter_files_content_types(files, self.content_types)) return sorted(folder_paths.filter_files_content_types(files, self.content_types))
return super().get_io_type_V1()
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
"image_upload": self.image_upload,
"image_folder": self.image_folder.value if self.image_folder else None,
"content_types": self.content_types if self.content_types else None,
"remote": self.remote.as_dict() if self.remote else None,
})
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
@ -463,6 +462,20 @@ class MultiCombo(ComfyType):
class Image(ComfyTypeIO): class Image(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@comfytype(io_type="WEBCAM")
class Webcam(ComfyTypeIO):
Type = str
class Input(WidgetInputV3):
"""Webcam input."""
Type = str
def __init__(
self, id: str, display_name: str=None, optional=False,
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
@comfytype(io_type="MASK") @comfytype(io_type="MASK")
class Mask(ComfyTypeIO): class Mask(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@ -1121,7 +1134,7 @@ class ComfyNodeV3:
type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {}) type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {})
# TODO: what parameters should be carried over? # TODO: what parameters should be carried over?
type_clone.SCHEMA = c_type.SCHEMA type_clone.SCHEMA = c_type.SCHEMA
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) if hidden_inputs is not None else None
# TODO: add anything we would want to expose inside node's execute function # TODO: add anything we would want to expose inside node's execute function
return type_clone return type_clone

View File

@ -3,10 +3,7 @@ import scipy.ndimage
import torch import torch
import comfy.utils import comfy.utils
import node_helpers import node_helpers
import folder_paths
import random
import nodes
from nodes import MAX_RESOLUTION from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
@ -365,30 +362,6 @@ class ThresholdMask:
mask = (mask > value).float() mask = (mask > value).float()
return (mask,) return (mask,)
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
return {
"required": {"mask": ("MASK",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
FUNCTION = "execute"
CATEGORY = "mask"
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentCompositeMasked": LatentCompositeMasked, "LatentCompositeMasked": LatentCompositeMasked,
@ -403,10 +376,8 @@ NODE_CLASS_MAPPINGS = {
"FeatherMask": FeatherMask, "FeatherMask": FeatherMask,
"GrowMask": GrowMask, "GrowMask": GrowMask,
"ThresholdMask": ThresholdMask, "ThresholdMask": ThresholdMask,
"MaskPreview": MaskPreview
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"ImageToMask": "Convert Image to Mask", "ImageToMask": "Convert Image to Mask",
"MaskToImage": "Convert Mask to Image",
} }

View File

@ -1,37 +0,0 @@
import nodes
import folder_paths
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(nodes.LoadImage):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("WEBCAM", {}),
"width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"capture_on_queue": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_capture"
CATEGORY = "image"
def load_capture(self, image, **kwargs):
return super().load_image(folder_paths.get_annotated_filepath(image))
@classmethod
def IS_CHANGED(cls, image, width, height, capture_on_queue):
return super().IS_CHANGED(image)
NODE_CLASS_MAPPINGS = {
"WebcamCapture": WebcamCapture,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WebcamCapture": "Webcam Capture",
}

View File

@ -0,0 +1,283 @@
import json
import os
import torch
import hashlib
import numpy as np
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
from comfy_api.v3 import io, ui
from comfy.cli_args import args
import folder_paths
import node_helpers
class SaveImage(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveImage",
display_name="Save Image",
description="Saves the input images to your ComfyUI output directory.",
category="image",
inputs=[
io.Image.Input(
"images",
display_name="images",
tooltip="The images to save.",
),
io.String.Input(
"filename_prefix",
default="ComfyUI",
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, filename_prefix="ComfyUI"):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
"", folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
results = []
for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if cls.hidden.prompt is not None:
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
results.append({
"filename": file,
"subfolder": subfolder,
"type": "output",
})
counter += 1
return io.NodeOutput(ui={"images": results})
class PreviewImage(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PreviewImage",
display_name="Preview Image",
description="Preview the input images.",
category="image",
inputs=[
io.Image.Input(
"images",
display_name="images",
tooltip="The images to preview.",
),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images):
return io.NodeOutput(ui=ui.PreviewImage(images))
class LoadImage(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadImage",
display_name="Load Image",
category="image",
inputs=[
io.Combo.Input(
"image",
display_name="image",
image_upload=True,
image_folder=io.FolderType.input,
content_types=["image"],
),
],
outputs=[
io.Image.Output(
"IMAGE",
),
io.Mask.Output(
"MASK",
),
],
)
@classmethod
def execute(cls, image) -> io.NodeOutput:
img = node_helpers.pillow(Image.open, folder_paths.get_annotated_filepath(image))
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return io.NodeOutput(output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageOutput(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadImageOutput",
display_name="Load Image (from Outputs)",
description="Load an image from the output folder. "
"When the refresh button is clicked, the node will update the image list "
"and automatically select the first image, allowing for easy iteration.",
category="image",
inputs=[
io.Combo.Input(
"image",
display_name="image",
image_upload=True,
image_folder=io.FolderType.output,
content_types=["image"],
remote=io.RemoteOptions(
route="/internal/files/output",
refresh_button=True,
control_after_refresh="first",
),
),
],
outputs=[
io.Image.Output(
"IMAGE",
),
io.Mask.Output(
"MASK",
),
],
)
@classmethod
def execute(cls, image) -> io.NodeOutput:
img = node_helpers.pillow(Image.open, folder_paths.get_annotated_filepath(image))
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return io.NodeOutput(output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
NODES_LIST: list[type[io.ComfyNodeV3]] = [
SaveImage,
PreviewImage,
LoadImage,
LoadImageOutput,
]

View File

@ -0,0 +1,32 @@
from comfy_api.v3 import io, ui
class MaskPreview(io.ComfyNodeV3):
"""Mask Preview - original implement in ComfyUI_essentials.
https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
Upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
"""
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="MaskPreview",
display_name="Convert Mask to Image",
category="mask",
inputs=[
io.Mask.Input(
"masks",
display_name="masks",
),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, masks):
return io.NodeOutput(ui=ui.PreviewMask(masks))
NODES_LIST: list[type[io.ComfyNodeV3]] = [MaskPreview]

View File

@ -0,0 +1,118 @@
import hashlib
import torch
import numpy as np
from PIL import Image, ImageOps, ImageSequence
from comfy_api.v3 import io
import nodes
import folder_paths
import node_helpers
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="WebcamCapture",
display_name="Webcam Capture",
category="image",
inputs=[
io.Webcam.Input(
"image",
display_name="image",
),
io.Int.Input(
"width",
display_name="width",
default=0,
min=0,
max=MAX_RESOLUTION,
step=1,
),
io.Int.Input(
"height",
display_name="height",
default=0,
min=0,
max=MAX_RESOLUTION,
step=1,
),
io.Boolean.Input(
"capture_on_queue",
default=True,
),
],
outputs=[
io.Image.Output(
"IMAGE",
),
],
)
@classmethod
def execute(cls, image, **kwargs) -> io.NodeOutput:
img = node_helpers.pillow(Image.open, folder_paths.get_annotated_filepath(image))
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return io.NodeOutput(output_image, output_mask)
@classmethod
def IS_CHANGED(s, image, width, height, capture_on_queue):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
NODES_LIST: list[type[io.ComfyNodeV3]] = [WebcamCapture]

View File

@ -321,7 +321,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
elif isinstance(r, NodeOutput): elif isinstance(r, NodeOutput):
# V3 # V3
if r.ui is not None: if r.ui is not None:
uis.append(r.ui.as_dict()) if isinstance(r.ui, dict):
uis.append(r.ui)
else:
uis.append(r.ui.as_dict())
if r.expand is not None: if r.expand is not None:
has_subgraph = True has_subgraph = True
new_graph = r.expand new_graph = r.expand

213
nodes.py
View File

@ -8,11 +8,9 @@ import hashlib
import traceback import traceback
import math import math
import time import time
import random
import logging import logging
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -1551,181 +1549,6 @@ class KSamplerAdvanced:
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
class SaveImage(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="SaveImage",
display_name="Save Image",
description="Saves the input images to your ComfyUI output directory.",
category="image",
inputs=[
io.Image.Input(
"images",
display_name="images",
tooltip="The images to save.",
),
io.String.Input(
"filename_prefix",
default="ComfyUI",
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
def __init__(self):
super().__init__()
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
def execute(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type,
})
counter += 1
return { "ui": { "images": results } }
class PreviewImage(SaveImage):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="PreviewImage",
display_name="Preview Image",
description="Preview the input images.",
category="image",
inputs=[
io.Image.Input(
"images",
display_name="images",
tooltip="The images to preview.",
),
],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
def __init__(self):
super().__init__()
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
class LoadImage(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="LoadImage",
display_name="Load Image",
category="image",
inputs=[
io.Combo.Input(
"image",
display_name="image",
image_upload=True,
image_folder=io.FolderType.input,
content_types=["image"],
),
],
outputs=[
io.Image.Output(
"IMAGE",
),
io.Mask.Output(
"MASK",
),
],
)
@classmethod
def execute(cls, image) -> io.NodeOutput:
img = node_helpers.pillow(Image.open, folder_paths.get_annotated_filepath(image))
output_images = []
output_masks = []
w, h = None, None
excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
w = image.size[0]
h = image.size[1]
if image.size[0] != w or image.size[1] != h:
continue
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return io.NodeOutput(output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageMask: class LoadImageMask:
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@classmethod @classmethod
@ -1776,28 +1599,6 @@ class LoadImageMask:
return True return True
class LoadImageOutput(LoadImage):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("COMBO", {
"image_upload": True,
"image_folder": "output",
"remote": {
"route": "/internal/files/output",
"refresh_button": True,
"control_after_refresh": "first",
},
}),
}
}
DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration."
EXPERIMENTAL = True
FUNCTION = "load_image"
class ImageScale: class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@ -1980,11 +1781,7 @@ NODE_CLASS_MAPPINGS = {
"LatentUpscaleBy": LatentUpscaleBy, "LatentUpscaleBy": LatentUpscaleBy,
"LatentFromBatch": LatentFromBatch, "LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch, "RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask, "LoadImageMask": LoadImageMask,
"LoadImageOutput": LoadImageOutput,
"ImageScale": ImageScale, "ImageScale": ImageScale,
"ImageScaleBy": ImageScaleBy, "ImageScaleBy": ImageScaleBy,
"ImageInvert": ImageInvert, "ImageInvert": ImageInvert,
@ -2081,11 +1878,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentFromBatch" : "Latent From Batch", "LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch", "RepeatLatentBatch": "Repeat Latent Batch",
# Image # Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)", "LoadImageMask": "Load Image (as Mask)",
"LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image", "ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By", "ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageUpscaleWithModel": "Upscale Image (using Model)",
@ -2295,7 +2088,6 @@ def init_builtin_extra_nodes():
"nodes_align_your_steps.py", "nodes_align_your_steps.py",
"nodes_attention_multiply.py", "nodes_attention_multiply.py",
"nodes_advanced_samplers.py", "nodes_advanced_samplers.py",
"nodes_webcam.py",
"nodes_audio.py", "nodes_audio.py",
"nodes_sd3.py", "nodes_sd3.py",
"nodes_gits.py", "nodes_gits.py",
@ -2330,6 +2122,9 @@ def init_builtin_extra_nodes():
"nodes_tcfg.py" "nodes_tcfg.py"
"nodes_v3_test.py", "nodes_v3_test.py",
"nodes_v1_test.py", "nodes_v1_test.py",
"v3/nodes_images.py",
"v3/nodes_mask.py",
"v3/nodes_webcam.py",
] ]
import_failed = [] import_failed = []