# Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 16:26:39 +00:00).
# 227 lines, 8.5 KiB, Python.
from __future__ import annotations
|
|
|
|
from enum import Enum
|
|
|
|
import torch
|
|
|
|
import comfy.utils
|
|
from comfy_api.v3 import io
|
|
|
|
|
|
def resize_mask(mask, shape):
    """Bilinearly resample ``mask`` to a spatial size of ``(shape[0], shape[1])``.

    The last two dimensions of ``mask`` are treated as its spatial size; every
    leading dimension is folded into a single batch dimension.

    Args:
        mask: Tensor whose last two dimensions are the size to resample from.
        shape: Indexable whose first two entries give the target size; any
            further entries are ignored.

    Returns:
        Tensor of shape ``(N, shape[0], shape[1])``.
    """
    src_rows, src_cols = mask.shape[-2], mask.shape[-1]
    # interpolate() needs an explicit channel dimension, so add one of size 1.
    as_nchw = mask.reshape((-1, 1, src_rows, src_cols))
    target_size = (shape[0], shape[1])
    resampled = torch.nn.functional.interpolate(as_nchw, size=target_size, mode="bilinear")
    # Drop the channel dimension that was added above.
    return resampled.squeeze(1)
|
|
|
|
|
|
class PorterDuffMode(Enum):
    """Blend modes understood by ``porter_duff_composite``.

    Member names are used as the options of the ``mode`` combo input of the
    compositing node in this file, in definition order, so neither the names
    nor their order should be changed casually.
    """

    ADD = 0  # clamped sum of source and destination
    CLEAR = 1  # all-zero result
    DARKEN = 2
    # destination-centred modes
    DST = 3
    DST_ATOP = 4
    DST_IN = 5
    DST_OUT = 6
    DST_OVER = 7
    # blending modes
    LIGHTEN = 8
    MULTIPLY = 9
    OVERLAY = 10
    SCREEN = 11
    # source-centred modes
    SRC = 12
    SRC_ATOP = 13
    SRC_IN = 14
    SRC_OUT = 15
    SRC_OVER = 16
    XOR = 17
|
|
|
|
|
|
def porter_duff_composite(
    src_image: torch.Tensor,
    src_alpha: torch.Tensor,
    dst_image: torch.Tensor,
    dst_alpha: torch.Tensor,
    mode: PorterDuffMode,
):
    """Composite a source over/with a destination using a Porter-Duff mode.

    Both alpha arguments arrive as masks (``1 - alpha``) and the returned
    alpha is a mask again. Internally everything is computed on
    alpha-premultiplied colors and un-premultiplied before returning.

    Args:
        src_image: Source colors.
        src_alpha: Source mask; must broadcast against ``src_image``.
        dst_image: Destination colors.
        dst_alpha: Destination mask; must broadcast against ``dst_image``.
        mode: The blend mode to apply.

    Returns:
        ``(out_image, out_mask)``; ``out_image`` is clamped to ``[0, 1]``.
        ``(None, None)`` when ``mode`` matches no known member.
    """
    # Masks -> alphas.
    sa = 1 - src_alpha
    da = 1 - dst_alpha
    # Premultiplied colors; every formula below assumes these.
    s = src_image * sa
    d = dst_image * da

    def _union_alpha():
        # Coverage of "source or destination".
        return sa + da - sa * da

    def _product_alpha():
        # Coverage of "source and destination".
        return sa * da

    modes = PorterDuffMode
    if mode == modes.ADD:
        out_alpha = torch.clamp(sa + da, 0, 1)
        out_image = torch.clamp(s + d, 0, 1)
    elif mode == modes.CLEAR:
        out_alpha = torch.zeros_like(da)
        out_image = torch.zeros_like(d)
    elif mode == modes.DARKEN:
        out_alpha = _union_alpha()
        out_image = (1 - da) * s + (1 - sa) * d + torch.min(s, d)
    elif mode == modes.DST:
        out_alpha = da
        out_image = d
    elif mode == modes.DST_ATOP:
        out_alpha = sa
        out_image = sa * d + (1 - da) * s
    elif mode == modes.DST_IN:
        out_alpha = _product_alpha()
        out_image = d * sa
    elif mode == modes.DST_OUT:
        out_alpha = (1 - sa) * da
        out_image = (1 - sa) * d
    elif mode == modes.DST_OVER:
        out_alpha = da + (1 - da) * sa
        out_image = d + (1 - da) * s
    elif mode == modes.LIGHTEN:
        out_alpha = _union_alpha()
        out_image = (1 - da) * s + (1 - sa) * d + torch.max(s, d)
    elif mode == modes.MULTIPLY:
        out_alpha = _product_alpha()
        out_image = s * d
    elif mode == modes.OVERLAY:
        out_alpha = _union_alpha()
        out_image = torch.where(
            2 * d < da,
            2 * s * d,
            sa * da - 2 * (da - s) * (sa - d),
        )
    elif mode == modes.SCREEN:
        out_alpha = _union_alpha()
        out_image = s + d - s * d
    elif mode == modes.SRC:
        out_alpha = sa
        out_image = s
    elif mode == modes.SRC_ATOP:
        out_alpha = da
        out_image = da * s + (1 - sa) * d
    elif mode == modes.SRC_IN:
        out_alpha = _product_alpha()
        out_image = s * da
    elif mode == modes.SRC_OUT:
        out_alpha = (1 - da) * sa
        out_image = (1 - da) * s
    elif mode == modes.SRC_OVER:
        out_alpha = sa + (1 - sa) * da
        out_image = s + (1 - sa) * d
    elif mode == modes.XOR:
        out_alpha = (1 - da) * sa + (1 - sa) * da
        out_image = (1 - da) * s + (1 - sa) * d
    else:
        return None, None

    # Undo the premultiplication; where alpha is (near) zero the color is
    # set to zero instead of dividing.
    straight = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
    straight = torch.clamp(straight, 0, 1)
    # Alpha -> mask.
    out_mask = 1 - out_alpha
    return straight, out_mask
|
|
|
|
|
|
class JoinImageWithAlpha(io.ComfyNodeV3):
    """Node that attaches a mask to an image as an extra alpha channel."""

    @classmethod
    def define_schema(cls):
        """Declare the node: an image and a mask in, one image out."""
        node_inputs = [
            io.Image.Input("image"),
            io.Mask.Input("alpha"),
        ]
        node_outputs = [io.Image.Output()]
        return io.SchemaV3(
            node_id="JoinImageWithAlpha_V3",
            display_name="Join Image with Alpha _V3",
            category="mask/compositing",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
        """Concatenate the first three image channels with the inverted mask.

        The mask is resized to the image's spatial size and inverted
        (``1 - mask``) before being appended along the last dimension.
        Only as many samples as the shorter of the two inputs are produced.
        """
        # The sample count is taken from the inputs as given, before resizing.
        sample_count = min(len(image), len(alpha))
        inverted = 1.0 - resize_mask(alpha, image.shape[1:])
        joined = [
            torch.cat((image[idx][:, :, :3], inverted[idx].unsqueeze(2)), dim=2)
            for idx in range(sample_count)
        ]
        return io.NodeOutput(torch.stack(joined))
|
|
|
|
|
|
def _resize_hwc(tensor, width, height):
    """Bicubically resize one channels-last tensor to ``height`` x ``width``."""
    # common_upscale is given a batched, channels-first tensor: add a batch
    # dimension and move channels forward, then undo both on the way out.
    batched = tensor.unsqueeze(0).permute(0, 3, 1, 2)
    resized = comfy.utils.common_upscale(batched, width, height, upscale_method='bicubic', crop='center')
    return resized.permute(0, 2, 3, 1).squeeze(0)


class PorterDuffImageComposite(io.ComfyNodeV3):
    """Node that composites a source image onto a destination image."""

    @classmethod
    def define_schema(cls):
        """Declare the node: two image/mask pairs plus a blend mode in; image and mask out."""
        return io.SchemaV3(
            node_id="PorterDuffImageComposite_V3",
            display_name="Porter-Duff Image Composite _V3",
            category="mask/compositing",
            inputs=[
                io.Image.Input("source"),
                io.Mask.Input("source_alpha"),
                io.Image.Input("destination"),
                io.Mask.Input("destination_alpha"),
                io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
            ],
            outputs=[io.Image.Output(), io.Mask.Output()],
        )

    @classmethod
    def execute(
        cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode
    ) -> io.NodeOutput:
        """Composite each source sample with the matching destination sample.

        Only as many samples as the shortest of the four inputs are processed.
        For every sample the destination image sets the working size: the
        destination mask is resized to it, then the source image, and finally
        the source mask is resized to the destination mask.

        Args:
            source: Batch of source images.
            source_alpha: Batch of source masks.
            destination: Batch of destination images.
            destination_alpha: Batch of destination masks.
            mode: Name of a ``PorterDuffMode`` member.

        Returns:
            ``io.NodeOutput`` holding the stacked composited images and masks.

        Raises:
            KeyError: If ``mode`` is not the name of a ``PorterDuffMode`` member.
            AssertionError: If a source and destination sample differ in
                channel count.
        """
        # Resolve the mode once, up front: it is the same for every sample,
        # and a bad name should fail before any resizing work is done.
        blend_mode = PorterDuffMode[mode]

        batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
        out_images = []
        out_alphas = []

        for i in range(batch_size):
            src_image = source[i]
            dst_image = destination[i]

            # inputs need to have same number of channels
            assert src_image.shape[2] == dst_image.shape[2], (
                f"source has {src_image.shape[2]} channels but destination has {dst_image.shape[2]}"
            )

            # Give the masks a trailing channel dimension so they line up with the images.
            src_alpha = source_alpha[i].unsqueeze(2)
            dst_alpha = destination_alpha[i].unsqueeze(2)

            if dst_alpha.shape[:2] != dst_image.shape[:2]:
                dst_alpha = _resize_hwc(dst_alpha, dst_image.shape[1], dst_image.shape[0])
            if src_image.shape != dst_image.shape:
                src_image = _resize_hwc(src_image, dst_image.shape[1], dst_image.shape[0])
            if src_alpha.shape != dst_alpha.shape:
                src_alpha = _resize_hwc(src_alpha, dst_alpha.shape[1], dst_alpha.shape[0])

            out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, blend_mode)

            out_images.append(out_image)
            # Drop the trailing channel dimension added to the mask above.
            out_alphas.append(out_alpha.squeeze(2))

        return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
|
|
|
|
|
class SplitImageWithAlpha(io.ComfyNodeV3):
    """Node that separates an image into its color channels and a mask."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one image in; an image and a mask out."""
        node_inputs = [
            io.Image.Input("image"),
        ]
        node_outputs = [io.Image.Output(), io.Mask.Output()]
        return io.SchemaV3(
            node_id="SplitImageWithAlpha_V3",
            display_name="Split Image with Alpha _V3",
            category="mask/compositing",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, image: torch.Tensor) -> io.NodeOutput:
        """Return the first three channels and the inverted fourth channel.

        Samples with three or fewer channels are treated as fully opaque
        (alpha of one), which yields an all-zero mask for them.
        """
        color_parts = []
        alpha_parts = []
        for sample in image:
            color_parts.append(sample[:, :, :3])
            if sample.shape[2] > 3:
                alpha_parts.append(sample[:, :, 3])
            else:
                # No alpha channel present: use ones shaped like a single channel.
                alpha_parts.append(torch.ones_like(sample[:, :, 0]))
        # Alpha -> mask.
        return io.NodeOutput(torch.stack(color_parts), 1.0 - torch.stack(alpha_parts))
|
|
|
|
|
|
# The node classes defined in this file, gathered in one list
# (presumably what the node loader reads — confirm against the loader).
NODES_LIST = [JoinImageWithAlpha, PorterDuffImageComposite, SplitImageWithAlpha]
|