diff --git a/comfy_extras/v3/nodes_latent.py b/comfy_extras/v3/nodes_latent.py
new file mode 100644
index 000000000..54c96e072
--- /dev/null
+++ b/comfy_extras/v3/nodes_latent.py
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.utils
+import comfy_extras.nodes_post_processing
+from comfy_api.v3 import io
+
+
+def reshape_latent_to(target_shape, latent, repeat_batch=True):
+    if latent.shape[1:] != target_shape[1:]:
+        latent = comfy.utils.common_upscale(
+            latent, target_shape[-1], target_shape[-2], "bilinear", "center"
+        )
+    if repeat_batch:
+        return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
+    return latent
+
+
+class LatentAdd(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentAdd_V3",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input(id="samples1"),
+                io.Latent.Input(id="samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+        samples_out["samples"] = s1 + s2
+        return io.NodeOutput(samples_out)
+
+
+class LatentApplyOperation(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentApplyOperation_V3",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Latent.Input(id="samples"),
+                io.LatentOperation.Input(id="operation"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, operation):
+        samples_out = samples.copy()
+
+        s1 = samples["samples"]
+        samples_out["samples"] = operation(latent=s1)
+        return io.NodeOutput(samples_out)
+
+
+class LatentApplyOperationCFG(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentApplyOperationCFG_V3",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input(id="model"),
+                io.LatentOperation.Input(id="operation"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, operation):
+        m = model.clone()
+
+        def pre_cfg_function(args):
+            conds_out = args["conds_out"]
+            if len(conds_out) == 2:
+                conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
+            else:
+                conds_out[0] = operation(latent=conds_out[0])
+            return conds_out
+
+        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
+        return io.NodeOutput(m)
+
+
+class LatentBatch(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentBatch_V3",
+            category="latent/batch",
+            inputs=[
+                io.Latent.Input(id="samples1"),
+                io.Latent.Input(id="samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2):
+        samples_out = samples1.copy()
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
+        s = torch.cat((s1, s2), dim=0)
+        samples_out["samples"] = s
+        samples_out["batch_index"] = (samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) +
+                                      samples2.get("batch_index", [x for x in range(0, s2.shape[0])]))
+        return io.NodeOutput(samples_out)
+
+
+class LatentBatchSeedBehavior(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentBatchSeedBehavior_V3",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input(id="samples"),
+                io.Combo.Input(id="seed_behavior", options=["random", "fixed"], default="fixed"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, seed_behavior):
+        samples_out = samples.copy()
+        latent = samples["samples"]
+        if seed_behavior == "random":
+            if 'batch_index' in samples_out:
+                samples_out.pop('batch_index')
+        elif seed_behavior == "fixed":
+            batch_number = samples_out.get("batch_index", [0])[0]
+            samples_out["batch_index"] = [batch_number] * latent.shape[0]
+
+        return io.NodeOutput(samples_out)
+
+
+class LatentInterpolate(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentInterpolate_V3",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input(id="samples1"),
+                io.Latent.Input(id="samples2"),
+                io.Float.Input(id="ratio", default=1.0, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2, ratio):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+
+        m1 = torch.linalg.vector_norm(s1, dim=(1))
+        m2 = torch.linalg.vector_norm(s2, dim=(1))
+
+        s1 = torch.nan_to_num(s1 / m1)
+        s2 = torch.nan_to_num(s2 / m2)
+
+        t = (s1 * ratio + s2 * (1.0 - ratio))
+        mt = torch.linalg.vector_norm(t, dim=(1))
+        st = torch.nan_to_num(t / mt)
+
+        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
+        return io.NodeOutput(samples_out)
+
+
+class LatentMultiply(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentMultiply_V3",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input(id="samples"),
+                io.Float.Input(id="multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, multiplier):
+        samples_out = samples.copy()
+
+        s1 = samples["samples"]
+        samples_out["samples"] = s1 * multiplier
+        return io.NodeOutput(samples_out)
+
+
+class LatentOperationSharpen(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentOperationSharpen_V3",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Int.Input(id="sharpen_radius", default=9, min=1, max=31, step=1),
+                io.Float.Input(id="sigma", default=1.0, min=0.1, max=10.0, step=0.1),
+                io.Float.Input(id="alpha", default=0.1, min=0.0, max=5.0, step=0.01),
+            ],
+            outputs=[
+                io.LatentOperation.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, sharpen_radius, sigma, alpha):
+        def sharpen(latent, **kwargs):
+            luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
+            normalized_latent = latent / luminance
+            channels = latent.shape[1]
+
+            kernel_size = sharpen_radius * 2 + 1
+            kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
+            center = kernel_size // 2
+
+            kernel *= alpha * -10
+            kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
+
+            padded_image = torch.nn.functional.pad(
+                normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), "reflect"
+            )
+            sharpened = torch.nn.functional.conv2d(
+                padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels
+            )[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
+
+            return luminance * sharpened
+        return io.NodeOutput(sharpen)
+
+
+class LatentOperationTonemapReinhard(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentOperationTonemapReinhard_V3",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Float.Input(id="multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.LatentOperation.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, multiplier):
+        def tonemap_reinhard(latent, **kwargs):
+            latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
+            normalized_latent = latent / latent_vector_magnitude
+
+            mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
+            std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
+
+            top = (std * 5 + mean) * multiplier
+
+            #reinhard
+            latent_vector_magnitude *= (1.0 / top)
+            new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
+            new_magnitude *= top
+
+            return normalized_latent * new_magnitude
+        return io.NodeOutput(tonemap_reinhard)
+
+
+class LatentSubtract(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LatentSubtract_V3",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input(id="samples1"),
+                io.Latent.Input(id="samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+        samples_out["samples"] = s1 - s2
+        return io.NodeOutput(samples_out)
+
+
+NODES_LIST = [
+    LatentAdd,
+    LatentApplyOperation,
+    LatentApplyOperationCFG,
+    LatentBatch,
+    LatentBatchSeedBehavior,
+    LatentInterpolate,
+    LatentMultiply,
+    LatentOperationSharpen,
+    LatentOperationTonemapReinhard,
+    LatentSubtract,
+]
diff --git a/nodes.py b/nodes.py
index 50956edb8..9be58e027 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2319,6 +2319,7 @@ def init_builtin_extra_nodes():
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
         "v3/nodes_images.py",
+        "v3/nodes_latent.py",
         "v3/nodes_mask.py",
         "v3/nodes_preview_any.py",
         "v3/nodes_primitive.py",