diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py
index bbc886b84..4e9f987be 100644
--- a/comfy_api/v3/io.py
+++ b/comfy_api/v3/io.py
@@ -662,6 +662,12 @@ class Accumulation(ComfyTypeIO):
 class Load3DCamera(ComfyTypeIO):
     Type = Any  # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
 
+
+@comfytype(io_type="PHOTOMAKER")
+class Photomaker(ComfyTypeIO):
+    Type = Any
+
+
 @comfytype(io_type="POINT")
 class Point(ComfyTypeIO):
     Type = Any  # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
diff --git a/comfy_extras/v3/nodes_edit_model.py b/comfy_extras/v3/nodes_edit_model.py
new file mode 100644
index 000000000..ba43fbdb8
--- /dev/null
+++ b/comfy_extras/v3/nodes_edit_model.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import node_helpers
+from comfy_api.v3 import io
+
+
+class ReferenceLatent(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ReferenceLatent_V3",
+            category="advanced/conditioning/edit_models",
+            description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+                io.Latent.Input("latent", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, conditioning, latent=None):
+        if latent is not None:
+            conditioning = node_helpers.conditioning_set_values(
+                conditioning, {"reference_latents": [latent["samples"]]}, append=True
+            )
+        return io.NodeOutput(conditioning)
+
+
+NODES_LIST = [
+    ReferenceLatent,
+]
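Note on the `append=True` flag above: it is what makes chained `ReferenceLatent` nodes accumulate multiple reference images instead of overwriting the previous one. A minimal sketch of that accumulation semantics (a simplified stand-in for `node_helpers.conditioning_set_values`, assuming the usual `[tensor, options]` conditioning layout; tensor shapes are illustrative only):

```python
import torch

def set_values(conditioning, values, append=False):
    # Simplified stand-in: conditioning is a list of [tensor, options] pairs.
    out = []
    for tensor, options in conditioning:
        options = options.copy()
        for k, v in values.items():
            # With append=True, extend any existing list instead of replacing it.
            options[k] = options.get(k, []) + v if append else v
        out.append([tensor, options])
    return out

cond = [[torch.zeros(1, 77, 4096), {}]]
cond = set_values(cond, {"reference_latents": [torch.zeros(1, 16, 64, 64)]}, append=True)
cond = set_values(cond, {"reference_latents": [torch.zeros(1, 16, 64, 64)]}, append=True)
assert len(cond[0][1]["reference_latents"]) == 2  # two chained ReferenceLatent nodes
```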
+ ] + ) + + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + return io.NodeOutput( + comfy.sd.load_clip( + ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + ) + ) + + +NODES_LIST = [ + CLIPTextEncodeHiDream, + QuadrupleCLIPLoader, +] diff --git a/comfy_extras/v3/nodes_mochi.py b/comfy_extras/v3/nodes_mochi.py new file mode 100644 index 000000000..7dca58fc5 --- /dev/null +++ b/comfy_extras/v3/nodes_mochi.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import torch + +import comfy.model_management +import nodes +from comfy_api.v3 import io + + +class EmptyMochiLatentVideo(io.ComfyNodeV3): + @classmethod + def define_schema(cls): + return io.SchemaV3( + node_id="EmptyMochiLatentVideo_V3", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, length, batch_size=1): + latent = torch.zeros( + [batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +NODES_LIST = [ + EmptyMochiLatentVideo, +] diff --git a/comfy_extras/v3/nodes_model_advanced.py b/comfy_extras/v3/nodes_model_advanced.py new file mode 100644 index 000000000..73e8bac6d --- /dev/null +++ b/comfy_extras/v3/nodes_model_advanced.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +import torch + +import comfy.latent_formats +import comfy.model_sampling +import comfy.sd +import node_helpers +import nodes +from comfy_api.v3 import io + + +class LCM(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + x0 = model_input - model_output * sigma + + sigma_data = 0.5 + scaled_timestep = timestep * 10.0 #timestep_scaling + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * x0 + c_skip * model_input + + +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): + original_timesteps = 50 + + def __init__(self, model_config=None, zsnr=None): + super().__init__(model_config, zsnr=zsnr) + + self.skip_steps = self.num_timesteps // self.original_timesteps + + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] + + self.set_sigmas(sigmas_valid) + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return (dists.abs().argmin(dim=0).view(sigma.shape) 
diff --git a/comfy_extras/v3/nodes_model_advanced.py b/comfy_extras/v3/nodes_model_advanced.py
new file mode 100644
index 000000000..73e8bac6d
--- /dev/null
+++ b/comfy_extras/v3/nodes_model_advanced.py
@@ -0,0 +1,387 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.latent_formats
+import comfy.model_sampling
+import comfy.sd
+import node_helpers
+import nodes
+from comfy_api.v3 import io
+
+
+class LCM(comfy.model_sampling.EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        x0 = model_input - model_output * sigma
+
+        sigma_data = 0.5
+        scaled_timestep = timestep * 10.0  # timestep_scaling
+
+        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+        return c_out * x0 + c_skip * model_input
+
+
+class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
+    original_timesteps = 50
+
+    def __init__(self, model_config=None, zsnr=None):
+        super().__init__(model_config, zsnr=zsnr)
+
+        self.skip_steps = self.num_timesteps // self.original_timesteps
+
+        sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
+        for x in range(self.original_timesteps):
+            sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
+
+        self.set_sigmas(sigmas_valid)
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
+
+    def sigma(self, timestep):
+        t = torch.clamp(
+            ((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(),
+            min=0,
+            max=(len(self.sigmas) - 1),
+        )
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp().to(timestep.device)
+
+
+class ModelComputeDtype(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelComputeDtype_V3",
+            category="advanced/debug/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input("dtype", options=["default", "fp32", "fp16", "bf16"]),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, dtype):
+        m = model.clone()
+        m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
+        return io.NodeOutput(m)
+
+
+class ModelSamplingContinuousEDM(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingContinuousEDM_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input(
+                    "sampling", options=["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"]
+                ),
+                io.Float.Input("sigma_max", default=120.0, min=0.0, max=1000.0, step=0.001, round=False),
+                io.Float.Input("sigma_min", default=0.002, min=0.0, max=1000.0, step=0.001, round=False),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
+        latent_format = None
+        sigma_data = 1.0
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "edm":
+            sampling_type = comfy.model_sampling.EDM
+            sigma_data = 0.5
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+        elif sampling == "edm_playground_v2.5":
+            sampling_type = comfy.model_sampling.EDM
+            sigma_data = 0.5
+            latent_format = comfy.latent_formats.SDXL_Playground_2_5()
+        elif sampling == "cosmos_rflow":
+            sampling_type = comfy.model_sampling.COSMOS_RFLOW
+            sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
+        m.add_object_patch("model_sampling", model_sampling)
+        if latent_format is not None:
+            m.add_object_patch("latent_format", latent_format)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingContinuousV(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingContinuousV_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input("sampling", options=["v_prediction"]),
+                io.Float.Input("sigma_max", default=500.0, min=0.0, max=1000.0, step=0.001, round=False),
+                io.Float.Input("sigma_min", default=0.03, min=0.0, max=1000.0, step=0.001, round=False),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        sigma_data = 1.0
+        if sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingDiscrete(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingDiscrete_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input("sampling", options=["eps", "v_prediction", "lcm", "x0", "img_to_img"]),
+                io.Boolean.Input("zsnr", default=False),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, sampling, zsnr):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+        elif sampling == "lcm":
+            sampling_type = LCM
+            sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "x0":
+            sampling_type = comfy.model_sampling.X0
+        elif sampling == "img_to_img":
+            sampling_type = comfy.model_sampling.IMG_TO_IMG
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
+
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingFlux(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingFlux_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("max_shift", default=1.15, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.5, min=0.0, max=100.0, step=0.01),
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, max_shift, base_shift, width, height):
+        m = model.clone()
+
+        x1 = 256
+        x2 = 4096
+        mm = (max_shift - base_shift) / (x2 - x1)
+        b = base_shift - mm * x1
+        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
+
+        sampling_base = comfy.model_sampling.ModelSamplingFlux
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingSD3(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingSD3_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=3.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift, multiplier: int | float = 1000):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift, multiplier=multiplier)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingAuraFlow(ModelSamplingSD3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingAuraFlow_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=1.73, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift, multiplier: int | float = 1.0):
+        return super().execute(model, shift, multiplier)
+
+
+class ModelSamplingStableCascade(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ModelSamplingStableCascade_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=2.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.StableCascadeSampling
+        sampling_type = comfy.model_sampling.EPS
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class RescaleCFG(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="RescaleCFG_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("multiplier", default=0.7, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, multiplier):
+        def rescale_cfg(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+            sigma = args["sigma"]
+            sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
+            x_orig = args["input"]
+
+            # rescale cfg has to be done on v-pred model output
+            x = x_orig / (sigma * sigma + 1.0)
+            cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
+            uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
+
+            # rescalecfg
+            x_cfg = uncond + cond_scale * (cond - uncond)
+            ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
+            ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
+
+            x_rescaled = x_cfg * (ro_pos / ro_cfg)
+            x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
+
+            return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(rescale_cfg)
+        return io.NodeOutput(m)
+
+
+NODES_LIST = [
+    ModelSamplingAuraFlow,
+    ModelComputeDtype,
+    ModelSamplingContinuousEDM,
+    ModelSamplingContinuousV,
+    ModelSamplingDiscrete,
+    ModelSamplingFlux,
+    ModelSamplingSD3,
+    ModelSamplingStableCascade,
+    RescaleCFG,
+]
diff --git a/comfy_extras/v3/nodes_model_downscale.py b/comfy_extras/v3/nodes_model_downscale.py
new file mode 100644
index 000000000..4adde9840
--- /dev/null
+++ b/comfy_extras/v3/nodes_model_downscale.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import comfy.utils
+from comfy_api.v3 import io
+
+
+class PatchModelAddDownscale(io.ComfyNodeV3):
+    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
+
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="PatchModelAddDownscale_V3",
+            display_name="PatchModelAddDownscale (Kohya Deep Shrink) _V3",
+            category="model_patches/unet",
+            inputs=[
+                io.Model.Input("model"),
+                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
+                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
+                io.Boolean.Input("downscale_after_skip", default=True),
+                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
+                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(
+        cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method
+    ):
+        model_sampling = model.get_model_object("model_sampling")
+        sigma_start = model_sampling.percent_to_sigma(start_percent)
+        sigma_end = model_sampling.percent_to_sigma(end_percent)
+
+        def input_block_patch(h, transformer_options):
+            if transformer_options["block"][1] == block_number:
+                sigma = transformer_options["sigmas"][0].item()
+                if sigma <= sigma_start and sigma >= sigma_end:
+                    h = comfy.utils.common_upscale(
+                        h,
+                        round(h.shape[-1] * (1.0 / downscale_factor)),
+                        round(h.shape[-2] * (1.0 / downscale_factor)),
+                        downscale_method,
+                        "disabled",
+                    )
+            return h
+
+        def output_block_patch(h, hsp, transformer_options):
+            if h.shape[2] != hsp.shape[2]:
+                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
+            return h, hsp
+
+        m = model.clone()
+        if downscale_after_skip:
+            m.set_model_input_block_patch_after_skip(input_block_patch)
+        else:
+            m.set_model_input_block_patch(input_block_patch)
+        m.set_model_output_block_patch(output_block_patch)
+        return io.NodeOutput(m)
+
+
+NODES_LIST = [
+    PatchModelAddDownscale,
+]
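A note on the gating in `input_block_patch`: `percent_to_sigma` maps a schedule fraction to a noise level, and because sigmas decrease over the course of sampling, the active window is `sigma_end <= sigma <= sigma_start` (high noise first). A self-contained sketch of the gate and resize arithmetic, with `F.interpolate` standing in for `comfy.utils.common_upscale` and hypothetical window values:

```python
import torch
import torch.nn.functional as F

def maybe_downscale(h, sigma, sigma_start, sigma_end, downscale_factor=2.0):
    # Fires only inside the scheduled window; sigmas shrink as sampling proceeds.
    if sigma_end <= sigma <= sigma_start:
        new_w = round(h.shape[-1] * (1.0 / downscale_factor))
        new_h = round(h.shape[-2] * (1.0 / downscale_factor))
        h = F.interpolate(h, size=(new_h, new_w), mode="bicubic")  # stand-in for common_upscale
    return h

h = torch.randn(1, 320, 64, 64)
print(maybe_downscale(h, sigma=10.0, sigma_start=14.6, sigma_end=4.0).shape)  # [1, 320, 32, 32]
print(maybe_downscale(h, sigma=1.0, sigma_start=14.6, sigma_end=4.0).shape)   # unchanged
```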
diff --git a/comfy_extras/v3/nodes_photomaker.py b/comfy_extras/v3/nodes_photomaker.py
new file mode 100644
index 000000000..e51f8a65d
--- /dev/null
+++ b/comfy_extras/v3/nodes_photomaker.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+import comfy.clip_model
+import comfy.clip_vision
+import comfy.model_management
+import comfy.ops
+import comfy.utils
+import folder_paths
+from comfy_api.v3 import io
+
+# code for model from:
+# https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
+VISION_CONFIG_DICT = {
+    "hidden_size": 1024,
+    "image_size": 224,
+    "intermediate_size": 4096,
+    "num_attention_heads": 16,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "patch_size": 14,
+    "projection_dim": 768,
+    "hidden_act": "quick_gelu",
+    "model_type": "clip_vision_model",
+}
+
+
+class MLP(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops):
+        super().__init__()
+        if use_residual:
+            assert in_dim == out_dim
+        self.layernorm = operations.LayerNorm(in_dim)
+        self.fc1 = operations.Linear(in_dim, hidden_dim)
+        self.fc2 = operations.Linear(hidden_dim, out_dim)
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        residual = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        if self.use_residual:
+            x = x + residual
+        return x
+
+
+class FuseModule(nn.Module):
+    def __init__(self, embed_dim, operations):
+        super().__init__()
+        self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
+        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
+        self.layer_norm = operations.LayerNorm(embed_dim)
+
+    def fuse_fn(self, prompt_embeds, id_embeds):
+        stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
+        stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
+        stacked_id_embeds = self.mlp2(stacked_id_embeds)
+        stacked_id_embeds = self.layer_norm(stacked_id_embeds)
+        return stacked_id_embeds
+
+    def forward(
+        self,
+        prompt_embeds,
+        id_embeds,
+        class_tokens_mask,
+    ) -> torch.Tensor:
+        # id_embeds shape: [b, max_num_inputs, 1, 2048]
+        id_embeds = id_embeds.to(prompt_embeds.dtype)
+        num_inputs = class_tokens_mask.sum().unsqueeze(0)  # TODO: check for training case
+        batch_size, max_num_inputs = id_embeds.shape[:2]
+        # seq_length: 77
+        seq_length = prompt_embeds.shape[1]
+        # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
+        flat_id_embeds = id_embeds.view(
+            -1, id_embeds.shape[-2], id_embeds.shape[-1]
+        )
+        # valid_id_mask [b*max_num_inputs]
+        valid_id_mask = (
+            torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
+            < num_inputs[:, None]
+        )
+        valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
+
+        prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
+        class_tokens_mask = class_tokens_mask.view(-1)
+        valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
+        # slice out the image token embeddings
+        image_token_embeds = prompt_embeds[class_tokens_mask]
+        stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
+        assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
+        prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
+        return prompt_embeds.view(batch_size, seq_length, -1)
+
+
+class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
+    def __init__(self):
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+
+        super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast)
+        self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False)
+        self.fuse_module = FuseModule(2048, comfy.ops.manual_cast)
+
+    def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
+        b, num_inputs, c, h, w = id_pixel_values.shape
+        id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
+
+        shared_id_embeds = self.vision_model(id_pixel_values)[2]
+        id_embeds = self.visual_projection(shared_id_embeds)
+        id_embeds_2 = self.visual_projection_2(shared_id_embeds)
+
+        id_embeds = id_embeds.view(b, num_inputs, 1, -1)
+        id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
+
+        id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
+        return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
+
+
+class PhotoMakerEncode(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="PhotoMakerEncode_V3",
+            category="_for_testing/photomaker",
+            inputs=[
+                io.Photomaker.Input("photomaker"),
+                io.Image.Input("image"),
+                io.Clip.Input("clip"),
+                io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, photomaker, image, clip, text):
+        special_token = "photomaker"
+        pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
+        try:
+            index = text.split(" ").index(special_token) + 1
+        except ValueError:
+            index = -1
+        tokens = clip.tokenize(text, return_word_ids=True)
+        out_tokens = {}
+        for k in tokens:
+            out_tokens[k] = []
+            for t in tokens[k]:
+                f = list(filter(lambda x: x[2] != index, t))
+                while len(f) < len(t):
+                    f.append(t[-1])
+                out_tokens[k].append(f)
+
+        cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
+
+        if index > 0:
+            token_index = index - 1
+            num_id_images = 1
+            class_tokens_mask = [True if token_index <= i < token_index + num_id_images else False for i in range(77)]
+            out = photomaker(
+                id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
+                class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0),
+            )
+        else:
+            out = cond
+
+        return io.NodeOutput([[out, {"pooled_output": pooled}]])
+
+
+class PhotoMakerLoader(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="PhotoMakerLoader_V3",
+            category="_for_testing/photomaker",
+            inputs=[
+                io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
+            ],
+            outputs=[
+                io.Photomaker.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, photomaker_model_name):
+        photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
+        photomaker_model = PhotoMakerIDEncoder()
+        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
+        if "id_encoder" in data:
+            data = data["id_encoder"]
+        photomaker_model.load_state_dict(data)
+        return io.NodeOutput(photomaker_model)
+
+
+NODES_LIST = [
+    PhotoMakerEncode,
+    PhotoMakerLoader,
+]
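The trigger-word handling in `PhotoMakerEncode.execute` is easier to see in isolation: the node finds the word `photomaker` in the prompt, filters its token out of the sequence, and builds a boolean mask marking where the fused ID embedding gets scattered back in. A stripped-down sketch (plain Python, with a hypothetical 8-token sequence instead of the real 77):

```python
text = "photograph of photomaker man"
special_token = "photomaker"

# 1-based word index of the trigger word, as in the node
try:
    index = text.split(" ").index(special_token) + 1
except ValueError:
    index = -1
print(index)  # 3

# Boolean mask over the token sequence marking the ID-embedding slot
token_index = index - 1
num_id_images = 1
seq_len = 8  # hypothetical; the node uses 77
class_tokens_mask = [token_index <= i < token_index + num_id_images for i in range(seq_len)]
print(class_tokens_mask)  # [False, False, True, False, False, False, False, False]
```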
diff --git a/comfy_extras/v3/nodes_pixart.py b/comfy_extras/v3/nodes_pixart.py
new file mode 100644
index 000000000..fc489e1c3
--- /dev/null
+++ b/comfy_extras/v3/nodes_pixart.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import nodes
+from comfy_api.v3 import io
+
+
+class CLIPTextEncodePixArtAlpha(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="CLIPTextEncodePixArtAlpha_V3",
+            category="advanced/conditioning",
+            description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
+            inputs=[
+                io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
+                io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
+                io.String.Input("text", multiline=True, dynamic_prompts=True),
+                io.Clip.Input("clip"),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, text, clip):
+        tokens = clip.tokenize(text)
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}))
+
+
+NODES_LIST = [
+    CLIPTextEncodePixArtAlpha,
+]
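Compared with a plain `CLIPTextEncode`, the only addition here is `add_dict`, which merges the resolution keys into each conditioning entry's options dict. A toy illustration of that merge (placeholder strings stand in for the real tensors):

```python
def add_conditioning_values(conditioning, add_dict):
    # Sketch: merge extra keys into the options dict of every [tensor, options] entry.
    return [[t, {**options, **add_dict}] for t, options in conditioning]

cond = [["<cond tensor>", {"pooled_output": "<pooled tensor>"}]]
print(add_conditioning_values(cond, {"width": 1024, "height": 1024}))
# [['<cond tensor>', {'pooled_output': '<pooled tensor>', 'width': 1024, 'height': 1024}]]
```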
diff --git a/comfy_extras/v3/nodes_post_processing.py b/comfy_extras/v3/nodes_post_processing.py
new file mode 100644
index 000000000..46af1ad09
--- /dev/null
+++ b/comfy_extras/v3/nodes_post_processing.py
@@ -0,0 +1,255 @@
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+import comfy.model_management
+import comfy.utils
+import node_helpers
+from comfy_api.v3 import io
+
+
+def gaussian_kernel(kernel_size: int, sigma: float, device=None):
+    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
+    d = torch.sqrt(x * x + y * y)
+    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+    return g / g.sum()
+
+
+class Blend(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ImageBlend_V3",
+            category="image/postprocessing",
+            inputs=[
+                io.Image.Input("image1"),
+                io.Image.Input("image2"),
+                io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01),
+                io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
+        image1, image2 = node_helpers.image_alpha_fix(image1, image2)
+        image2 = image2.to(image1.device)
+        if image1.shape != image2.shape:
+            image2 = image2.permute(0, 3, 1, 2)
+            image2 = comfy.utils.common_upscale(
+                image2, image1.shape[2], image1.shape[1], upscale_method="bicubic", crop="center"
+            )
+            image2 = image2.permute(0, 2, 3, 1)
+
+        blended_image = cls.blend_mode(image1, image2, blend_mode)
+        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
+        blended_image = torch.clamp(blended_image, 0, 1)
+        return io.NodeOutput(blended_image)
+
+    @classmethod
+    def blend_mode(cls, img1, img2, mode):
+        if mode == "normal":
+            return img2
+        elif mode == "multiply":
+            return img1 * img2
+        elif mode == "screen":
+            return 1 - (1 - img1) * (1 - img2)
+        elif mode == "overlay":
+            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
+        elif mode == "soft_light":
+            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1))
+        elif mode == "difference":
+            return img1 - img2
+        raise ValueError(f"Unsupported blend mode: {mode}")
+
+    @classmethod
+    def g(cls, x):
+        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
+
+
+class Blur(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ImageBlur_V3",
+            category="image/postprocessing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
+                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float):
+        if blur_radius == 0:
+            return io.NodeOutput(image)
+
+        image = image.to(comfy.model_management.get_torch_device())
+        batch_size, height, width, channels = image.shape
+
+        kernel_size = blur_radius * 2 + 1
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
+
+        image = image.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W) we use (B, H, W, C)
+        padded_image = F.pad(image, (blur_radius, blur_radius, blur_radius, blur_radius), "reflect")
+        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:, :, blur_radius:-blur_radius, blur_radius:-blur_radius]
+        blurred = blurred.permute(0, 2, 3, 1)
+
+        return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
+
+
+class ImageScaleToTotalPixels(io.ComfyNodeV3):
+    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+    crop_methods = ["disabled", "center"]
+
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ImageScaleToTotalPixels_V3",
+            category="image/upscaling",
+            inputs=[
+                io.Image.Input("image"),
+                io.Combo.Input("upscale_method", options=cls.upscale_methods),
+                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image, upscale_method, megapixels):
+        samples = image.movedim(-1, 1)
+        total = int(megapixels * 1024 * 1024)
+
+        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+        width = round(samples.shape[3] * scale_by)
+        height = round(samples.shape[2] * scale_by)
+
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        return io.NodeOutput(s.movedim(1, -1))
+
+
+class Quantize(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ImageQuantize_V3",
+            category="image/postprocessing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Int.Input("colors", default=256, min=1, max=256, step=1),
+                io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @staticmethod
+    def bayer(im, pal_im, order):
+        def normalized_bayer_matrix(n):
+            if n == 0:
+                return np.zeros((1, 1), "float32")
+            q = 4 ** n
+            m = q * normalized_bayer_matrix(n - 1)
+            return np.bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / q
+
+        num_colors = len(pal_im.getpalette()) // 3
+        spread = 2 * 256 / num_colors
+        bayer_n = int(math.log2(order))
+        bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
+
+        result = torch.from_numpy(np.array(im).astype(np.float32))
+        tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
+        th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
+        tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
+        result.add_(tiled_matrix[:result.shape[0], :result.shape[1]]).clamp_(0, 255)
+        result = result.to(dtype=torch.uint8)
+
+        im = Image.fromarray(result.cpu().numpy())
+        return im.quantize(palette=pal_im, dither=Image.Dither.NONE)
+
+    @classmethod
+    def execute(cls, image: torch.Tensor, colors: int, dither: str):
+        batch_size, height, width, _ = image.shape
+        result = torch.zeros_like(image)
+
+        for b in range(batch_size):
+            im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
+
+            pal_im = im.quantize(colors=colors)  # Required as described in https://github.com/python-pillow/Pillow/issues/5836
+
+            if dither == "none":
+                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
+            elif dither == "floyd-steinberg":
+                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
+            elif dither.startswith("bayer"):
+                order = int(dither.split('-')[-1])
+                quantized_image = cls.bayer(im, pal_im, order)
+
+            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
+            result[b] = quantized_array
+
+        return io.NodeOutput(result)
+
+
+class Sharpen(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="ImageSharpen_V3",
+            category="image/postprocessing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
+                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
+                io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma: float, alpha: float):
+        if sharpen_radius == 0:
+            return io.NodeOutput(image)
+
+        batch_size, height, width, channels = image.shape
+        image = image.to(comfy.model_management.get_torch_device())
+
+        kernel_size = sharpen_radius * 2 + 1
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha * 10)
+        center = kernel_size // 2
+        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
+        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
+
+        tensor_image = image.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W) we use (B, H, W, C)
+        tensor_image = F.pad(tensor_image, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), "reflect")
+        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
+        sharpened = sharpened.permute(0, 2, 3, 1)
+
+        result = torch.clamp(sharpened, 0, 1)
+
+        return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
+
+
+NODES_LIST = [
+    Blend,
+    Blur,
+    ImageScaleToTotalPixels,
+    Quantize,
+    Sharpen,
+]
diff --git a/nodes.py b/nodes.py
index 39cd5a2f7..f260c3f34 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2314,18 +2314,26 @@ def init_builtin_extra_nodes():
         "v3/nodes_controlnet.py",
         "v3/nodes_cosmos.py",
         "v3/nodes_differential_diffusion.py",
+        "v3/nodes_edit_model.py",
         "v3/nodes_flux.py",
         "v3/nodes_freelunch.py",
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
+        "v3/nodes_hidream.py",
         "v3/nodes_images.py",
         "v3/nodes_latent.py",
         "v3/nodes_lt.py",
         "v3/nodes_mask.py",
+        "v3/nodes_mochi.py",
+        "v3/nodes_model_advanced.py",
+        "v3/nodes_model_downscale.py",
         "v3/nodes_morphology.py",
         "v3/nodes_optimalsteps.py",
         "v3/nodes_pag.py",
         "v3/nodes_perpneg.py",
+        "v3/nodes_photomaker.py",
+        "v3/nodes_pixart.py",
+        "v3/nodes_post_processing.py",
         "v3/nodes_preview_any.py",
         "v3/nodes_primitive.py",
         "v3/nodes_rebatch.py",
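All of the new files follow the same registration pattern: the module exposes a `NODES_LIST`, and `init_builtin_extra_nodes()` loads it once the filename is added to the list above. For reference, the minimal shape of such a file (a hypothetical node, not part of this PR, mirroring the structure used throughout):

```python
from __future__ import annotations

from comfy_api.v3 import io


class MyNode(io.ComfyNodeV3):  # hypothetical example node
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="MyNode_V3",
            category="utils",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image):
        return io.NodeOutput(image)


NODES_LIST = [MyNode]
```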