mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 16:26:39 +00:00)

converted to V3 schema

This commit is contained in:
parent 941dea9439
commit 9208b4a7c1
@@ -662,6 +662,12 @@ class Accumulation(ComfyTypeIO):

class Load3DCamera(ComfyTypeIO):
    Type = Any  # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type


@comfytype(io_type="PHOTOMAKER")
class Photomaker(ComfyTypeIO):
    Type = Any


@comfytype(io_type="POINT")
class Point(ComfyTypeIO):
    Type = Any  # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
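For readers new to the V3 io layer, this hunk shows how socket types are declared: a @comfytype decorator supplies the wire-level io_type string, and the class body pins the Python payload type. A hypothetical extra type following the same shape (the name and io_type string below are illustrative, not part of this commit):

# Hypothetical, for illustration only; "EXAMPLE" / Example are not in this commit.
@comfytype(io_type="EXAMPLE")
class Example(ComfyTypeIO):
    Type = Any  # payload carried by sockets of this io_type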
34  comfy_extras/v3/nodes_edit_model.py  Normal file
@@ -0,0 +1,34 @@
from __future__ import annotations

import node_helpers
from comfy_api.v3 import io


class ReferenceLatent(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ReferenceLatent_V3",
            category="advanced/conditioning/edit_models",
            description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ]
        )

    @classmethod
    def execute(cls, conditioning, latent=None):
        if latent is not None:
            conditioning = node_helpers.conditioning_set_values(
                conditioning, {"reference_latents": [latent["samples"]]}, append=True
            )
        return io.NodeOutput(conditioning)


NODES_LIST = [
    ReferenceLatent,
]
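Because node_helpers.conditioning_set_values is called with append=True, chaining several ReferenceLatent nodes accumulates reference latents instead of overwriting them, which is what the node description promises. A minimal sketch, assuming cond is an existing conditioning and latent_a / latent_b are latent dicts:

# Sketch only: two chained applications accumulate two reference latents.
cond = node_helpers.conditioning_set_values(
    cond, {"reference_latents": [latent_a["samples"]]}, append=True
)
cond = node_helpers.conditioning_set_values(
    cond, {"reference_latents": [latent_b["samples"]]}, append=True
)
# cond now carries both references, in chain order.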
71  comfy_extras/v3/nodes_hidream.py  Normal file
@@ -0,0 +1,71 @@
from __future__ import annotations

import comfy.model_management
import comfy.sd
import folder_paths
from comfy_api.v3 import io


class CLIPTextEncodeHiDream(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="CLIPTextEncodeHiDream_V3",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
                io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
                io.String.Input("llama", multiline=True, dynamic_prompts=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ]
        )

    @classmethod
    def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
        tokens = clip.tokenize(clip_g)
        tokens["l"] = clip.tokenize(clip_l)["l"]
        tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
        tokens["llama"] = clip.tokenize(llama)["llama"]
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class QuadrupleCLIPLoader(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="QuadrupleCLIPLoader_V3",
            category="advanced/loaders",
            description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
            inputs=[
                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
                io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
            ],
            outputs=[
                io.Clip.Output(),
            ]
        )

    @classmethod
    def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
        clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
        clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
        return io.NodeOutput(
            comfy.sd.load_clip(
                ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4],
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
            )
        )


NODES_LIST = [
    CLIPTextEncodeHiDream,
    QuadrupleCLIPLoader,
]
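Taken together, the two nodes form the HiDream text path: QuadrupleCLIPLoader produces one Clip object holding the four encoders named in its [Recipes] description, and that object feeds the clip input of CLIPTextEncodeHiDream, which tokenizes each prompt stream with the matching tokenizer before a single scheduled encode. A hedged sketch of driving the loader directly (file names are placeholders for whatever sits in the local text_encoders folder; inside a graph the sockets are wired instead):

# Placeholder file names, for illustration only.
out = QuadrupleCLIPLoader.execute(
    "long_clip_l.safetensors",
    "long_clip_g.safetensors",
    "t5xxl.safetensors",
    "llama_8b_3.1_instruct.safetensors",
)
# out is an io.NodeOutput wrapping the loaded Clip; in a workflow its
# Clip output connects to CLIPTextEncodeHiDream's "clip" input.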
38  comfy_extras/v3/nodes_mochi.py  Normal file
@@ -0,0 +1,38 @@
from __future__ import annotations

import torch

import comfy.model_management
import nodes
from comfy_api.v3 import io


class EmptyMochiLatentVideo(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="EmptyMochiLatentVideo_V3",
            category="latent/video",
            inputs=[
                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size=1):
        latent = torch.zeros(
            [batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8],
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput({"samples": latent})


NODES_LIST = [
    EmptyMochiLatentVideo,
]
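The latent shape encodes the compression implied by the formula: 12 channels, roughly 6x temporal (with one extra frame) and 8x spatial. Worked through at the node defaults:

# Worked example at the defaults (width=848, height=480, length=25):
#   temporal: ((25 - 1) // 6) + 1 = 5 latent frames
#   spatial:  480 // 8 = 60 and 848 // 8 = 106
# -> torch.zeros([1, 12, 5, 60, 106])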
387  comfy_extras/v3/nodes_model_advanced.py  Normal file
@@ -0,0 +1,387 @@
from __future__ import annotations

import torch

import comfy.latent_formats
import comfy.model_sampling
import comfy.sd
import node_helpers
import nodes
from comfy_api.v3 import io


class LCM(comfy.model_sampling.EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        x0 = model_input - model_output * sigma

        sigma_data = 0.5
        scaled_timestep = timestep * 10.0  # timestep_scaling

        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

        return c_out * x0 + c_skip * model_input


class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
    original_timesteps = 50

    def __init__(self, model_config=None, zsnr=None):
        super().__init__(model_config, zsnr=zsnr)

        self.skip_steps = self.num_timesteps // self.original_timesteps

        sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
        for x in range(self.original_timesteps):
            sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]

        self.set_sigmas(sigmas_valid)

    def timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)

    def sigma(self, timestep):
        t = torch.clamp(
            ((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(),
            min=0,
            max=(len(self.sigmas) - 1),
        )
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp().to(timestep.device)
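For orientation, the distilled schedule above keeps every skip_steps-th sigma of the parent discrete schedule. A worked sketch, assuming the stock 1000-timestep schedule:

# Worked example, assuming num_timesteps = 1000 in the parent schedule:
#   skip_steps = 1000 // 50 = 20
#   the 50 kept sigmas come from parent indices 999, 979, ..., 19
#   timestep() snaps a sigma to this grid and reports the parent index:
#       argmin_index * 20 + (20 - 1)  ->  one of {19, 39, ..., 999}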
class ModelComputeDtype(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelComputeDtype_V3",
            category="advanced/debug/model",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("dtype", options=["default", "fp32", "fp16", "bf16"]),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, dtype):
        m = model.clone()
        m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
        return io.NodeOutput(m)
class ModelSamplingContinuousEDM(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingContinuousEDM_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input(
                    "sampling", options=["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"]
                ),
                io.Float.Input("sigma_max", default=120.0, min=0.0, max=1000.0, step=0.001, round=False),
                io.Float.Input("sigma_min", default=0.002, min=0.0, max=1000.0, step=0.001, round=False),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, sampling, sigma_max, sigma_min):
        m = model.clone()

        sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
        latent_format = None
        sigma_data = 1.0
        if sampling == "eps":
            sampling_type = comfy.model_sampling.EPS
        elif sampling == "edm":
            sampling_type = comfy.model_sampling.EDM
            sigma_data = 0.5
        elif sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION
        elif sampling == "edm_playground_v2.5":
            sampling_type = comfy.model_sampling.EDM
            sigma_data = 0.5
            latent_format = comfy.latent_formats.SDXL_Playground_2_5()
        elif sampling == "cosmos_rflow":
            sampling_type = comfy.model_sampling.COSMOS_RFLOW
            sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
        m.add_object_patch("model_sampling", model_sampling)
        if latent_format is not None:
            m.add_object_patch("latent_format", latent_format)
        return io.NodeOutput(m)
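Every ModelSampling* node in this file uses the same composition trick: pick a schedule base class and a prediction-type mixin, fuse them into a throwaway class, and patch the result onto a cloned model so the original is untouched. A minimal sketch of the pattern, using classes that appear above:

# Minimal sketch of the pattern shared by the nodes in this file.
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM  # the schedule
sampling_type = comfy.model_sampling.V_PREDICTION                # the parameterization

class ModelSamplingAdvanced(sampling_base, sampling_type):
    pass  # all behavior comes from the two parents

m = model.clone()  # patch a clone, never the input model
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))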
class ModelSamplingContinuousV(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingContinuousV_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("sampling", options=["v_prediction"]),
                io.Float.Input("sigma_max", default=500.0, min=0.0, max=1000.0, step=0.001, round=False),
                io.Float.Input("sigma_min", default=0.03, min=0.0, max=1000.0, step=0.001, round=False),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, sampling, sigma_max, sigma_min):
        m = model.clone()

        sigma_data = 1.0
        if sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)
class ModelSamplingDiscrete(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingDiscrete_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("sampling", options=["eps", "v_prediction", "lcm", "x0", "img_to_img"]),
                io.Boolean.Input("zsnr", default=False),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, sampling, zsnr):
        m = model.clone()

        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
        if sampling == "eps":
            sampling_type = comfy.model_sampling.EPS
        elif sampling == "v_prediction":
            sampling_type = comfy.model_sampling.V_PREDICTION
        elif sampling == "lcm":
            sampling_type = LCM
            sampling_base = ModelSamplingDiscreteDistilled
        elif sampling == "x0":
            sampling_type = comfy.model_sampling.X0
        elif sampling == "img_to_img":
            sampling_type = comfy.model_sampling.IMG_TO_IMG

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)

        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)
class ModelSamplingFlux(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingFlux_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("max_shift", default=1.15, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.5, min=0.0, max=100.0, step=0.01),
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, max_shift, base_shift, width, height):
        m = model.clone()

        x1 = 256
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b

        sampling_base = comfy.model_sampling.ModelSamplingFlux
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)
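The shift is linear in the latent token count, calibrated so two reference sizes hit base_shift and max_shift exactly (x1 = 256 tokens, x2 = 4096 tokens). Worked at the defaults:

# Worked example at the defaults (base_shift=0.5, max_shift=1.15, 1024x1024):
#   mm = (1.15 - 0.5) / (4096 - 256)      # slope per latent token
#   b  = 0.5 - mm * 256                   # so 256 tokens -> base_shift
#   tokens = 1024 * 1024 / (8 * 8 * 2 * 2) = 4096
#   shift = 4096 * mm + b = 1.15          # 1024x1024 lands exactly on max_shift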
class ModelSamplingSD3(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingSD3_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("shift", default=3.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, shift, multiplier: int | float = 1000):
        m = model.clone()

        sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift, multiplier=multiplier)
        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)


class ModelSamplingAuraFlow(ModelSamplingSD3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingAuraFlow_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("shift", default=1.73, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, shift, multiplier: int | float = 1.0):
        return super().execute(model, shift, multiplier)
class ModelSamplingStableCascade(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ModelSamplingStableCascade_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("shift", default=2.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, shift):
        m = model.clone()

        sampling_base = comfy.model_sampling.StableCascadeSampling
        sampling_type = comfy.model_sampling.EPS

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift)
        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)
class RescaleCFG(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="RescaleCFG_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("multiplier", default=0.7, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, multiplier):
        def rescale_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            sigma = args["sigma"]
            sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
            x_orig = args["input"]

            # rescale cfg has to be done on v-pred model output
            x = x_orig / (sigma * sigma + 1.0)
            cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
            uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)

            # rescalecfg
            x_cfg = uncond + cond_scale * (cond - uncond)
            ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
            ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)

            x_rescaled = x_cfg * (ro_pos / ro_cfg)
            x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg

            return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)

        m = model.clone()
        m.set_model_sampler_cfg_function(rescale_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    ModelSamplingAuraFlow,
    ModelComputeDtype,
    ModelSamplingContinuousEDM,
    ModelSamplingContinuousV,
    ModelSamplingDiscrete,
    ModelSamplingFlux,
    ModelSamplingSD3,
    ModelSamplingStableCascade,
    RescaleCFG,
]
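Reading rescale_cfg back out of the algebra: the closure converts the model output into v-prediction space, applies a standard-deviation rescale, then converts back. Reduced to its three core steps (a paraphrase of the code above, not extra behavior):

# 1. ordinary CFG:        x_cfg = uncond + cond_scale * (cond - uncond)
# 2. std matching:        x_rescaled = x_cfg * (std(cond) / std(x_cfg))
#    (std taken per-sample over C, H, W via dim=(1, 2, 3))
# 3. blend by multiplier: x_final = multiplier * x_rescaled + (1 - multiplier) * x_cfg
# The surrounding sigma terms map to and from v-pred space, which is why the
# in-code comment notes the rescale has to happen on v-pred model output.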
68  comfy_extras/v3/nodes_model_downscale.py  Normal file
@@ -0,0 +1,68 @@
from __future__ import annotations

import comfy.utils
from comfy_api.v3 import io


class PatchModelAddDownscale(io.ComfyNodeV3):
    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="PatchModelAddDownscale_V3",
            display_name="PatchModelAddDownscale (Kohya Deep Shrink) _V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
                io.Boolean.Input("downscale_after_skip", default=True),
                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(
        cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method
    ):
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        def input_block_patch(h, transformer_options):
            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()
                if sigma <= sigma_start and sigma >= sigma_end:
                    h = comfy.utils.common_upscale(
                        h,
                        round(h.shape[-1] * (1.0 / downscale_factor)),
                        round(h.shape[-2] * (1.0 / downscale_factor)),
                        downscale_method,
                        "disabled",
                    )
            return h

        def output_block_patch(h, hsp, transformer_options):
            if h.shape[2] != hsp.shape[2]:
                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return io.NodeOutput(m)


NODES_LIST = [
    PatchModelAddDownscale,
]
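The patch only fires inside a sigma window: percent_to_sigma maps sampling progress (0.0 = start, highest noise) to a noise level, so with the defaults the hidden states entering block 3 are downscaled during roughly the first 35% of the denoise and pass through untouched afterwards. A worked sketch of the geometry (sizes illustrative):

# Illustrative sizes: with downscale_factor=2.0, a hidden state with spatial
# size 128x128 entering the patched block becomes
#   round(128 * (1.0 / 2.0)) = 64  ->  64x64
# output_block_patch later upscales h back to hsp's spatial size so the skip
# connection shapes line up again.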
205  comfy_extras/v3/nodes_photomaker.py  Normal file
@@ -0,0 +1,205 @@
from __future__ import annotations

import torch
import torch.nn as nn

import comfy.clip_model
import comfy.clip_vision
import comfy.model_management
import comfy.ops
import comfy.utils
import folder_paths
from comfy_api.v3 import io

# code for model from:
# https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = {
    "hidden_size": 1024,
    "image_size": 224,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "hidden_act": "quick_gelu",
    "model_type": "clip_vision_model",
}


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = operations.LayerNorm(in_dim)
        self.fc1 = operations.Linear(in_dim, hidden_dim)
        self.fc2 = operations.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x
class FuseModule(nn.Module):
    def __init__(self, embed_dim, operations):
        super().__init__()
        self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
        self.layer_norm = operations.LayerNorm(embed_dim)

    def fuse_fn(self, prompt_embeds, id_embeds):
        stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
        stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
        stacked_id_embeds = self.mlp2(stacked_id_embeds)
        stacked_id_embeds = self.layer_norm(stacked_id_embeds)
        return stacked_id_embeds

    def forward(
        self,
        prompt_embeds,
        id_embeds,
        class_tokens_mask,
    ) -> torch.Tensor:
        # id_embeds shape: [b, max_num_inputs, 1, 2048]
        id_embeds = id_embeds.to(prompt_embeds.dtype)
        num_inputs = class_tokens_mask.sum().unsqueeze(0)  # TODO: check for training case
        batch_size, max_num_inputs = id_embeds.shape[:2]
        # seq_length: 77
        seq_length = prompt_embeds.shape[1]
        # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
        flat_id_embeds = id_embeds.view(
            -1, id_embeds.shape[-2], id_embeds.shape[-1]
        )
        # valid_id_mask [b*max_num_inputs]
        valid_id_mask = (
            torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
            < num_inputs[:, None]
        )
        valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]

        prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
        class_tokens_mask = class_tokens_mask.view(-1)
        valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
        # slice out the image token embeddings
        image_token_embeds = prompt_embeds[class_tokens_mask]
        stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
        assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
        prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
        return prompt_embeds.view(batch_size, seq_length, -1)
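FuseModule swaps the prompt embeddings at the positions flagged by class_tokens_mask for fused (prompt + ID) embeddings and leaves everything else alone. A small worked picture, assuming one ID image and the usual 77-token prompt:

# Illustrative: one ID image whose trigger token sits at position 2.
#   class_tokens_mask = [False, False, True, False, ..., False]   # length 77
#   num_inputs = mask.sum() = 1, so exactly one fused embedding is produced
#   masked_scatter_ writes it into position 2; positions 0-1 and 3-76 keep
#   their original prompt embeddings.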
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
    def __init__(self):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        dtype = comfy.model_management.text_encoder_dtype(self.load_device)

        super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast)
        self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False)
        self.fuse_module = FuseModule(2048, comfy.ops.manual_cast)

    def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
        b, num_inputs, c, h, w = id_pixel_values.shape
        id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)

        shared_id_embeds = self.vision_model(id_pixel_values)[2]
        id_embeds = self.visual_projection(shared_id_embeds)
        id_embeds_2 = self.visual_projection_2(shared_id_embeds)

        id_embeds = id_embeds.view(b, num_inputs, 1, -1)
        id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)

        id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
        return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
class PhotoMakerEncode(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="PhotoMakerEncode_V3",
            category="_for_testing/photomaker",
            inputs=[
                io.Photomaker.Input("photomaker"),
                io.Image.Input("image"),
                io.Clip.Input("clip"),
                io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, photomaker, image, clip, text):
        special_token = "photomaker"
        pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
        try:
            index = text.split(" ").index(special_token) + 1
        except ValueError:
            index = -1
        tokens = clip.tokenize(text, return_word_ids=True)
        out_tokens = {}
        for k in tokens:
            out_tokens[k] = []
            for t in tokens[k]:
                f = list(filter(lambda x: x[2] != index, t))
                while len(f) < len(t):
                    f.append(t[-1])
                out_tokens[k].append(f)

        cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)

        if index > 0:
            token_index = index - 1
            num_id_images = 1
            class_tokens_mask = [True if token_index <= i < token_index + num_id_images else False for i in range(77)]
            out = photomaker(
                id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
                class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0),
            )
        else:
            out = cond

        return io.NodeOutput([[out, {"pooled_output": pooled}]])


class PhotoMakerLoader(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="PhotoMakerLoader_V3",
            category="_for_testing/photomaker",
            inputs=[
                io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
            ],
            outputs=[
                io.Photomaker.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, photomaker_model_name):
        photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
        photomaker_model = PhotoMakerIDEncoder()
        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
        if "id_encoder" in data:
            data = data["id_encoder"]
        photomaker_model.load_state_dict(data)
        return io.NodeOutput(photomaker_model)


NODES_LIST = [
    PhotoMakerEncode,
    PhotoMakerLoader,
]
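PhotoMakerEncode finds the trigger word by a plain whitespace split, so "photomaker" must appear as a standalone word in the prompt. Worked at the default text:

# Worked example with the default text "photograph of photomaker":
#   "photograph of photomaker".split(" ").index("photomaker") = 2  ->  index = 3
#   token_index = index - 1 = 2, num_id_images = 1
#   class_tokens_mask[i] is True only for i == 2
# If the trigger word is missing, index = -1 and the ID embeddings are
# skipped entirely (the plain text conditioning is returned).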
33  comfy_extras/v3/nodes_pixart.py  Normal file
@@ -0,0 +1,33 @@
from __future__ import annotations

import nodes
from comfy_api.v3 import io


class CLIPTextEncodePixArtAlpha(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="CLIPTextEncodePixArtAlpha_V3",
            category="advanced/conditioning",
            description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
            inputs=[
                io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.String.Input("text", multiline=True, dynamic_prompts=True),
                io.Clip.Input("clip"),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, text, clip):
        tokens = clip.tokenize(text)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}))


NODES_LIST = [
    CLIPTextEncodePixArtAlpha,
]
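The only thing separating this node from a plain CLIP text encode is the add_dict:

# encode_from_tokens_scheduled(tokens, add_dict={"width": 1024, "height": 1024})
# attaches the target resolution to every scheduled conditioning entry; per the
# description, PixArt Alpha consumes it and PixArt Sigma ignores it.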
255  comfy_extras/v3/nodes_post_processing.py  Normal file
@@ -0,0 +1,255 @@
from __future__ import annotations

import math

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import comfy.model_management
import comfy.utils
import node_helpers
from comfy_api.v3 import io


def gaussian_kernel(kernel_size: int, sigma: float, device=None):
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()
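gaussian_kernel evaluates a 2-D Gaussian on a [-1, 1] x [-1, 1] grid and divides by the sum, so the weights always total one and blurring preserves mean brightness. A quick check that runs as-is next to the function above:

k = gaussian_kernel(5, 1.0)
print(k.shape)          # torch.Size([5, 5])
print(float(k.sum()))   # 1.0 up to float rounding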
class Blend(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ImageBlend_V3",
            category="image/postprocessing",
            inputs=[
                io.Image.Input("image1"),
                io.Image.Input("image2"),
                io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01),
                io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
        image1, image2 = node_helpers.image_alpha_fix(image1, image2)
        image2 = image2.to(image1.device)
        if image1.shape != image2.shape:
            image2 = image2.permute(0, 3, 1, 2)
            image2 = comfy.utils.common_upscale(
                image2, image1.shape[2], image1.shape[1], upscale_method="bicubic", crop="center"
            )
            image2 = image2.permute(0, 2, 3, 1)

        blended_image = cls.blend_mode(image1, image2, blend_mode)
        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
        blended_image = torch.clamp(blended_image, 0, 1)
        return io.NodeOutput(blended_image)

    @classmethod
    def blend_mode(cls, img1, img2, mode):
        if mode == "normal":
            return img2
        elif mode == "multiply":
            return img1 * img2
        elif mode == "screen":
            return 1 - (1 - img1) * (1 - img2)
        elif mode == "overlay":
            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
        elif mode == "soft_light":
            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1))
        elif mode == "difference":
            return img1 - img2
        raise ValueError(f"Unsupported blend mode: {mode}")

    @classmethod
    def g(cls, x):
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
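Each blend mode is a pixel-wise formula on values in [0, 1]; scalar spot checks make the behavior easy to eyeball:

# Scalar spot checks of the formulas above:
#   multiply: 0.2 * 0.5 = 0.1                  (never brighter than either input)
#   screen:   1 - (1 - 0.2) * (1 - 0.5) = 0.6  (never darker than either input)
# "difference" (img1 - img2) can go negative; the torch.clamp(..., 0, 1) at the
# end of execute() is what brings it back into range.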
class Blur(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ImageBlur_V3",
            category="image/postprocessing",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float):
        if blur_radius == 0:
            return io.NodeOutput(image)

        image = image.to(comfy.model_management.get_torch_device())
        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W) we use (B, H, W, C)
        padded_image = F.pad(image, (blur_radius, blur_radius, blur_radius, blur_radius), "reflect")
        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:, :, blur_radius:-blur_radius, blur_radius:-blur_radius]
        blurred = blurred.permute(0, 2, 3, 1)

        return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
class ImageScaleToTotalPixels(io.ComfyNodeV3):
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ImageScaleToTotalPixels_V3",
            category="image/upscaling",
            inputs=[
                io.Image.Input("image"),
                io.Combo.Input("upscale_method", options=cls.upscale_methods),
                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, image, upscale_method, megapixels):
        samples = image.movedim(-1, 1)
        total = int(megapixels * 1024 * 1024)

        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
        width = round(samples.shape[3] * scale_by)
        height = round(samples.shape[2] * scale_by)

        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
        return io.NodeOutput(s.movedim(1, -1))
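scale_by preserves the aspect ratio while hitting the requested pixel budget. Worked example:

# Worked example: megapixels=1.0 on a 512x512 input
#   total    = int(1.0 * 1024 * 1024) = 1048576
#   scale_by = sqrt(1048576 / (512 * 512)) = 2.0
#   width, height = 1024, 1024
# For non-square inputs the per-side round() can land a few pixels off the
# exact budget.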
class Quantize(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ImageQuantize_V3",
            category="image/postprocessing",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input("colors", default=256, min=1, max=256, step=1),
                io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @staticmethod
    def bayer(im, pal_im, order):
        def normalized_bayer_matrix(n):
            if n == 0:
                return np.zeros((1, 1), "float32")
            q = 4 ** n
            m = q * normalized_bayer_matrix(n - 1)
            return np.bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / q

        num_colors = len(pal_im.getpalette()) // 3
        spread = 2 * 256 / num_colors
        bayer_n = int(math.log2(order))
        bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)

        result = torch.from_numpy(np.array(im).astype(np.float32))
        tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
        th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
        tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
        result.add_(tiled_matrix[:result.shape[0], :result.shape[1]]).clamp_(0, 255)
        result = result.to(dtype=torch.uint8)

        im = Image.fromarray(result.cpu().numpy())
        return im.quantize(palette=pal_im, dither=Image.Dither.NONE)

    @classmethod
    def execute(cls, image: torch.Tensor, colors: int, dither: str):
        batch_size, height, width, _ = image.shape
        result = torch.zeros_like(image)

        for b in range(batch_size):
            im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')

            pal_im = im.quantize(colors=colors)  # Required as described in https://github.com/python-pillow/Pillow/issues/5836

            if dither == "none":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
            elif dither == "floyd-steinberg":
                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
            elif dither.startswith("bayer"):
                order = int(dither.split('-')[-1])
                quantized_image = cls.bayer(im, pal_im, order)

            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
            result[b] = quantized_array

        return io.NodeOutput(result)
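normalized_bayer_matrix builds the classic recursive Bayer threshold matrix with entries centered on zero. Worked for the smallest dithered case, order 2 (n = 1):

# Worked example, n = 1 (the 2x2 Bayer matrix):
#   q = 4, m = 4 * zeros((1, 1)) = [[0]]
#   bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / 4
#     = [[-0.375,  0.125],
#        [ 0.375, -0.125]]
# spread = 2 * 256 / num_colors then stretches these offsets to the palette
# step size before they are added to the image prior to quantization.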
class Sharpen(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="ImageSharpen_V3",
            category="image/postprocessing",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
                io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma: float, alpha: float):
        if sharpen_radius == 0:
            return io.NodeOutput(image)

        batch_size, height, width, channels = image.shape
        image = image.to(comfy.model_management.get_torch_device())

        kernel_size = sharpen_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha * 10)
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W) we use (B, H, W, C)
        tensor_image = F.pad(tensor_image, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), "reflect")
        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)

        return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))


NODES_LIST = [
    Blend,
    Blur,
    ImageScaleToTotalPixels,
    Quantize,
    Sharpen,
]
8  nodes.py
@@ -2314,18 +2314,26 @@ def init_builtin_extra_nodes():
         "v3/nodes_controlnet.py",
         "v3/nodes_cosmos.py",
         "v3/nodes_differential_diffusion.py",
+        "v3/nodes_edit_model.py",
         "v3/nodes_flux.py",
         "v3/nodes_freelunch.py",
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
+        "v3/nodes_hidream.py",
         "v3/nodes_images.py",
         "v3/nodes_latent.py",
         "v3/nodes_lt.py",
         "v3/nodes_mask.py",
+        "v3/nodes_mochi.py",
+        "v3/nodes_model_advanced.py",
+        "v3/nodes_model_downscale.py",
         "v3/nodes_morphology.py",
         "v3/nodes_optimalsteps.py",
         "v3/nodes_pag.py",
         "v3/nodes_perpneg.py",
+        "v3/nodes_photomaker.py",
+        "v3/nodes_pixart.py",
+        "v3/nodes_post_processing.py",
         "v3/nodes_preview_any.py",
         "v3/nodes_primitive.py",
         "v3/nodes_rebatch.py",