diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py new file mode 100644 index 000000000..4dfb0b93e --- /dev/null +++ b/comfy_extras/nodes_mask.py @@ -0,0 +1,263 @@ +import torch + +from nodes import MAX_RESOLUTION + +class LatentCompositeMasked: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "destination": ("LATENT",), + "source": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }, + "optional": { + "mask": ("MASK",), + } + } + RETURN_TYPES = ("LATENT",) + FUNCTION = "composite" + + CATEGORY = "latent" + + def composite(self, destination, source, x, y, mask = None): + output = destination.copy() + destination = destination["samples"].clone() + source = source["samples"] + + x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8)) + y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8)) + + left, top = (x // 8, y // 8) + right, bottom = (left + source.shape[3], top + source.shape[2],) + + + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.clone() + mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + + output["samples"] = destination + + return (output,) + +class MaskToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" + + def mask_to_image(self, mask): + result = mask[None, :, :, None].expand(-1, -1, -1, 3) + return (result,) + +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue"],), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue"] + mask = image[0, :, :, channels.index(channel)] + return (mask,) + +class SolidMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "solid" + + def solid(self, value, width, height): + out = torch.full((height, width), value, dtype=torch.float32, device="cpu") + return (out,) + +class InvertMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "invert" + + def invert(self, mask): + out = 1.0 - mask + return (out,) + +class CropMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "crop" + + def crop(self, mask, x, y, width, height): + out = mask[y:y + height, x:x + width] + return (out,) + +class MaskComposite: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "destination": ("MASK",), + "source": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "operation": (["multiply", "add", "subtract"],), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "combine" + + def combine(self, destination, source, x, y, operation): + output = destination.clone() + + left, top = (x, y,) + right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0])) + visible_width, visible_height = (right - left, bottom - top,) + + source_portion = source[:visible_height, :visible_width] + destination_portion = destination[top:bottom, left:right] + + match operation: + case "multiply": + output[top:bottom, left:right] = destination_portion * source_portion + case "add": + output[top:bottom, left:right] = destination_portion + source_portion + case "subtract": + output[top:bottom, left:right] = destination_portion - source_portion + + output = torch.clamp(output, 0.0, 1.0) + + return (output,) + +class FeatherMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "feather" + + def feather(self, mask, left, top, right, bottom): + output = mask.clone() + + left = min(left, output.shape[1]) + right = min(right, output.shape[1]) + top = min(top, output.shape[0]) + bottom = min(bottom, output.shape[0]) + + for x in range(left): + feather_rate = (x + 1.0) / left + output[:, x] *= feather_rate + + for x in range(right): + feather_rate = (x + 1) / right + output[:, -x] *= feather_rate + + for y in range(top): + feather_rate = (y + 1) / top + output[y, :] *= feather_rate + + for y in range(bottom): + feather_rate = (y + 1) / bottom + output[-y, :] *= feather_rate + + return (output,) + + + +NODE_CLASS_MAPPINGS = { + "LatentCompositeMasked": LatentCompositeMasked, + "MaskToImage": MaskToImage, + "ImageToMask": ImageToMask, + "SolidMask": SolidMask, + "InvertMask": InvertMask, + "CropMask": CropMask, + "MaskComposite": MaskComposite, + "FeatherMask": FeatherMask, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", +} diff --git a/nodes.py b/nodes.py index b81d16015..b68c8ef43 100644 --- a/nodes.py +++ b/nodes.py @@ -872,7 +872,7 @@ class SaveImage: "filename": file, "subfolder": subfolder, "type": self.type - }); + }) counter += 1 return { "ui": { "images": results } } @@ -933,7 +933,7 @@ class LoadImageMask: "channel": (["alpha", "red", "green", "blue"], ),} } - CATEGORY = "image" + CATEGORY = "mask" RETURN_TYPES = ("MASK",) FUNCTION = "load_image" @@ -1193,3 +1193,4 @@ def init_custom_nodes(): load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3764c9848..2b3603419 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -159,27 +159,31 @@ app.registerExtension({ const r = origOnInputDblClick ? origOnInputDblClick.apply(this, arguments) : undefined; const input = this.inputs[slot]; - if (input.widget && !input[ignoreDblClick]) { - const node = LiteGraph.createNode("PrimitiveNode"); - app.graph.add(node); - - // Calculate a position that wont directly overlap another node - const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]]; - while (isNodeAtPos(pos)) { - pos[1] += LiteGraph.NODE_TITLE_HEIGHT; - } - - node.pos = pos; - node.connect(0, this, slot); - node.title = input.name; - - // Prevent adding duplicates due to triple clicking - input[ignoreDblClick] = true; - setTimeout(() => { - delete input[ignoreDblClick]; - }, 300); + if (!input.widget || !input[ignoreDblClick])// Not a widget input or already handled input + { + if (!(input.type in ComfyWidgets)) return r;//also Not a ComfyWidgets input (do nothing) } + // Create a primitive node + const node = LiteGraph.createNode("PrimitiveNode"); + app.graph.add(node); + + // Calculate a position that wont directly overlap another node + const pos = [this.pos[0] - node.size[0] - 30, this.pos[1]]; + while (isNodeAtPos(pos)) { + pos[1] += LiteGraph.NODE_TITLE_HEIGHT; + } + + node.pos = pos; + node.connect(0, this, slot); + node.title = input.name; + + // Prevent adding duplicates due to triple clicking + input[ignoreDblClick] = true; + setTimeout(() => { + delete input[ignoreDblClick]; + }, 300); + return r; }; }, @@ -233,7 +237,9 @@ app.registerExtension({ // Fires before the link is made allowing us to reject it if it isn't valid // No widget, we cant connect - if (!input.widget) return false; + if (!input.widget) { + if (!(input.type in ComfyWidgets)) return false; + } if (this.outputs[slot].links?.length) { return this.#isValidConnection(input); @@ -252,9 +258,17 @@ app.registerExtension({ const input = theirNode.inputs[link.target_slot]; if (!input) return; - const widget = input.widget; - const { type, linkType } = getWidgetType(widget.config); + var _widget; + if (!input.widget) { + if (!(input.type in ComfyWidgets)) return; + _widget = { "name": input.name, "config": [input.type, {}] }//fake widget + } else { + _widget = input.widget; + } + + const widget = _widget; + const { type, linkType } = getWidgetType(widget.config); // Update our output to restrict to the widget type this.outputs[0].type = linkType; this.outputs[0].name = type; @@ -274,7 +288,7 @@ app.registerExtension({ if (type in ComfyWidgets) { widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget; } else { - widget = this.addWidget(type, "value", null, () => {}, {}); + widget = this.addWidget(type, "value", null, () => { }, {}); } if (node?.widgets && widget) {