diff --git a/comfy_extras/v3/nodes_ace.py b/comfy_extras/v3/nodes_ace.py index 988e0ed5a..fdad8f800 100644 --- a/comfy_extras/v3/nodes_ace.py +++ b/comfy_extras/v3/nodes_ace.py @@ -52,6 +52,6 @@ class EmptyAceStepLatentAudio(io.ComfyNode): NODES_LIST: list[type[io.ComfyNode]] = [ - TextEncodeAceStepAudio, EmptyAceStepLatentAudio, + TextEncodeAceStepAudio, ] diff --git a/comfy_extras/v3/nodes_advanced_samplers.py b/comfy_extras/v3/nodes_advanced_samplers.py index 91512effb..ecbe7094f 100644 --- a/comfy_extras/v3/nodes_advanced_samplers.py +++ b/comfy_extras/v3/nodes_advanced_samplers.py @@ -122,7 +122,7 @@ class SamplerEulerCFGpp(io.ComfyNode): return io.NodeOutput(sampler) -NODES_LIST = [ - SamplerLCMUpscale, +NODES_LIST: list[type[io.ComfyNode]] = [ SamplerEulerCFGpp, + SamplerLCMUpscale, ] diff --git a/comfy_extras/v3/nodes_align_your_steps.py b/comfy_extras/v3/nodes_align_your_steps.py index c2a211c99..acb71c631 100644 --- a/comfy_extras/v3/nodes_align_your_steps.py +++ b/comfy_extras/v3/nodes_align_your_steps.py @@ -5,6 +5,18 @@ import torch from comfy_api.latest import io + +def loglinear_interp(t_steps, num_steps): + """Performs log-linear interpolation of a given array of decreasing numbers.""" + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + return np.exp(new_ys)[::-1].copy() + + NOISE_LEVELS = { "SD1": [ 14.6146412293, @@ -36,17 +48,6 @@ NOISE_LEVELS = { } -def loglinear_interp(t_steps, num_steps): - """Performs log-linear interpolation of a given array of decreasing numbers.""" - xs = np.linspace(0, 1, len(t_steps)) - ys = np.log(t_steps[::-1]) - - new_xs = np.linspace(0, 1, num_steps) - new_ys = np.interp(new_xs, xs, ys) - - return np.exp(new_ys)[::-1].copy() - - class AlignYourStepsScheduler(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: @@ -78,6 +79,6 @@ class AlignYourStepsScheduler(io.ComfyNode): return io.NodeOutput(torch.FloatTensor(sigmas)) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ AlignYourStepsScheduler, ] diff --git a/comfy_extras/v3/nodes_apg.py b/comfy_extras/v3/nodes_apg.py index 961bdddb3..f9fc208d0 100644 --- a/comfy_extras/v3/nodes_apg.py +++ b/comfy_extras/v3/nodes_apg.py @@ -93,6 +93,6 @@ class APG(io.ComfyNode): return io.NodeOutput(m) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ APG, ] diff --git a/comfy_extras/v3/nodes_attention_multiply.py b/comfy_extras/v3/nodes_attention_multiply.py index 9fb86714f..dfcb85568 100644 --- a/comfy_extras/v3/nodes_attention_multiply.py +++ b/comfy_extras/v3/nodes_attention_multiply.py @@ -131,9 +131,9 @@ class UNetTemporalAttentionMultiply(io.ComfyNode): return io.NodeOutput(m) -NODES_LIST = [ - UNetSelfAttentionMultiply, - UNetCrossAttentionMultiply, +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPAttentionMultiply, + UNetCrossAttentionMultiply, + UNetSelfAttentionMultiply, UNetTemporalAttentionMultiply, ] diff --git a/comfy_extras/v3/nodes_audio.py b/comfy_extras/v3/nodes_audio.py index 994863a42..089c2cb73 100644 --- a/comfy_extras/v3/nodes_audio.py +++ b/comfy_extras/v3/nodes_audio.py @@ -3,6 +3,7 @@ from __future__ import annotations import hashlib import os +import av import torch import torchaudio @@ -12,6 +13,28 @@ import node_helpers from comfy_api.latest import io, ui +class EmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyLatentAudio_V3", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", 
default=47.6, min=1.0, max=1000.0, step=0.1), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> io.NodeOutput: + length = round((seconds * 44100 / 2048) / 2) * 2 + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) + + class ConditioningStableAudio(io.ComfyNode): @classmethod def define_schema(cls): @@ -42,83 +65,71 @@ class ConditioningStableAudio(io.ComfyNode): ) -class EmptyLatentAudio(io.ComfyNode): +class VAEEncodeAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="EmptyLatentAudio_V3", + node_id="VAEEncodeAudio_V3", category="latent/audio", inputs=[ - io.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), - io.Int.Input( - id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." - ), + io.Audio.Input("audio"), + io.Vae.Input("vae"), ], outputs=[io.Latent.Output()], ) @classmethod - def execute(cls, seconds, batch_size) -> io.NodeOutput: - length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": latent, "type": "audio"}) + def execute(cls, vae, audio) -> io.NodeOutput: + sample_rate = audio["sample_rate"] + if 44100 != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) + else: + waveform = audio["waveform"] + return io.NodeOutput({"samples": vae.encode(waveform.movedim(1, -1))}) -class LoadAudio(io.ComfyNode): +class VAEDecodeAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work - display_name="Load Audio _V3", # frontend ignores "display_name" for this node - category="audio", + node_id="VAEDecodeAudio_V3", + category="latent/audio", inputs=[ - io.Combo.Input("audio", upload=io.UploadType.audio, options=cls.get_files_options()), + io.Latent.Input("samples"), + io.Vae.Input("vae"), ], outputs=[io.Audio.Output()], ) @classmethod - def get_files_options(cls) -> list[str]: - input_dir = folder_paths.get_input_directory() - return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])) - - @classmethod - def execute(cls, audio) -> io.NodeOutput: - waveform, sample_rate = torchaudio.load(folder_paths.get_annotated_filepath(audio)) - return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}) - - @classmethod - def fingerprint_inputs(s, audio): - image_path = folder_paths.get_annotated_filepath(audio) - m = hashlib.sha256() - with open(image_path, "rb") as f: - m.update(f.read()) - return m.digest().hex() - - @classmethod - def validate_inputs(s, audio): - if not folder_paths.exists_annotated_filepath(audio): - return "Invalid audio file: {}".format(audio) - return True + def execute(cls, vae, samples) -> io.NodeOutput: + audio = vae.decode(samples["samples"]).movedim(-1, 1) + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + return io.NodeOutput({"waveform": audio, "sample_rate": 44100}) -class PreviewAudio(io.ComfyNode): +class SaveAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="PreviewAudio_V3", # frontend expects 
"PreviewAudio" to work - display_name="Preview Audio _V3", # frontend ignores "display_name" for this node + node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work + display_name="Save Audio _V3", # frontend ignores "display_name" for this node category="audio", inputs=[ io.Audio.Input("audio"), + io.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, audio) -> io.NodeOutput: - return io.NodeOutput(ui=ui.PreviewAudio(audio, cls=cls)) + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> io.NodeOutput: + return io.NodeOutput( + ui=ui.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) class SaveAudioMP3(io.ComfyNode): @@ -171,71 +182,99 @@ class SaveAudioOpus(io.ComfyNode): ) -class SaveAudio(io.ComfyNode): +class PreviewAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="SaveAudio_V3", # frontend expects "SaveAudio" to work - display_name="Save Audio _V3", # frontend ignores "display_name" for this node + node_id="PreviewAudio_V3", # frontend expects "PreviewAudio" to work + display_name="Preview Audio _V3", # frontend ignores "display_name" for this node category="audio", inputs=[ io.Audio.Input("audio"), - io.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> io.NodeOutput: - return io.NodeOutput( - ui=ui.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) - ) + def execute(cls, audio) -> io.NodeOutput: + return io.NodeOutput(ui=ui.PreviewAudio(audio, cls=cls)) -class VAEDecodeAudio(io.ComfyNode): +class LoadAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="VAEDecodeAudio_V3", - category="latent/audio", + node_id="LoadAudio_V3", # frontend expects "LoadAudio" to work + display_name="Load Audio _V3", # frontend ignores "display_name" for this node + category="audio", inputs=[ - io.Latent.Input("samples"), - io.Vae.Input("vae"), + io.Combo.Input("audio", upload=io.UploadType.audio, options=cls.get_files_options()), ], outputs=[io.Audio.Output()], ) @classmethod - def execute(cls, vae, samples) -> io.NodeOutput: - audio = vae.decode(samples["samples"]).movedim(-1, 1) - std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 - std[std < 1.0] = 1.0 - audio /= std - return io.NodeOutput({"waveform": audio, "sample_rate": 44100}) - - -class VAEEncodeAudio(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="VAEEncodeAudio_V3", - category="latent/audio", - inputs=[ - io.Audio.Input("audio"), - io.Vae.Input("vae"), - ], - outputs=[io.Latent.Output()], - ) + def get_files_options(cls) -> list[str]: + input_dir = folder_paths.get_input_directory() + return sorted(folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])) @classmethod - def execute(cls, vae, audio) -> io.NodeOutput: - sample_rate = audio["sample_rate"] - if 44100 != sample_rate: - waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) - else: - waveform = audio["waveform"] - return io.NodeOutput({"samples": vae.encode(waveform.movedim(1, -1))}) + def load(cls, filepath: str) -> tuple[torch.Tensor, int]: + with av.open(filepath) as af: + if not af.streams.audio: + raise 
ValueError("No audio stream found in the file.") + + stream = af.streams.audio[0] + sr = stream.codec_context.sample_rate + n_channels = stream.channels + + frames = [] + length = 0 + for frame in af.decode(streams=stream.index): + buf = torch.from_numpy(frame.to_ndarray()) + if buf.shape[0] != n_channels: + buf = buf.view(-1, n_channels).t() + + frames.append(buf) + length += buf.shape[1] + + if not frames: + raise ValueError("No audio frames decoded.") + + wav = torch.cat(frames, dim=1) + wav = cls.f32_pcm(wav) + return wav, sr + + @classmethod + def f32_pcm(cls, wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + @classmethod + def execute(cls, audio) -> io.NodeOutput: + waveform, sample_rate = cls.load(folder_paths.get_annotated_filepath(audio)) + return io.NodeOutput({"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}) + + @classmethod + def fingerprint_inputs(s, audio): + image_path = folder_paths.get_annotated_filepath(audio) + m = hashlib.sha256() + with open(image_path, "rb") as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def validate_inputs(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True NODES_LIST: list[type[io.ComfyNode]] = [ @@ -243,9 +282,9 @@ NODES_LIST: list[type[io.ComfyNode]] = [ EmptyLatentAudio, LoadAudio, PreviewAudio, + SaveAudio, SaveAudioMP3, SaveAudioOpus, - SaveAudio, VAEDecodeAudio, VAEEncodeAudio, ] diff --git a/comfy_extras/v3/nodes_camera_trajectory.py b/comfy_extras/v3/nodes_camera_trajectory.py index edc159591..40fb1dcf9 100644 --- a/comfy_extras/v3/nodes_camera_trajectory.py +++ b/comfy_extras/v3/nodes_camera_trajectory.py @@ -212,6 +212,6 @@ class WanCameraEmbedding(io.ComfyNode): return io.NodeOutput(control_camera_video, width, height, length) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ WanCameraEmbedding, ] diff --git a/comfy_extras/v3/nodes_canny.py b/comfy_extras/v3/nodes_canny.py index e24b0df38..0b68db381 100644 --- a/comfy_extras/v3/nodes_canny.py +++ b/comfy_extras/v3/nodes_canny.py @@ -27,6 +27,6 @@ class Canny(io.ComfyNode): return io.NodeOutput(img_out) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ Canny, ] diff --git a/comfy_extras/v3/nodes_cfg.py b/comfy_extras/v3/nodes_cfg.py index 66ec27f9a..e8e84a2bd 100644 --- a/comfy_extras/v3/nodes_cfg.py +++ b/comfy_extras/v3/nodes_cfg.py @@ -5,6 +5,7 @@ import torch from comfy_api.latest import io +# https://github.com/WeichenFan/CFG-Zero-star def optimized_scale(positive, negative): positive_flat = positive.reshape(positive.shape[0], -1) negative_flat = negative.reshape(negative.shape[0], -1) @@ -21,6 +22,36 @@ def optimized_scale(positive, negative): return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) +class CFGZeroStar(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGZeroStar_V3", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output(display_name="patched_model")], + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: + m = model.clone() + + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] 
+ uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return io.NodeOutput(m) + + class CFGNorm(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: @@ -52,37 +83,7 @@ class CFGNorm(io.ComfyNode): return io.NodeOutput(m) -class CFGZeroStar(io.ComfyNode): - @classmethod - def define_schema(cls) -> io.Schema: - return io.Schema( - node_id="CFGZeroStar_V3", - category="advanced/guidance", - inputs=[ - io.Model.Input("model"), - ], - outputs=[io.Model.Output(display_name="patched_model")], - ) - - @classmethod - def execute(cls, model) -> io.NodeOutput: - m = model.clone() - - def cfg_zero_star(args): - guidance_scale = args['cond_scale'] - x = args['input'] - cond_p = args['cond_denoised'] - uncond_p = args['uncond_denoised'] - out = args["denoised"] - alpha = optimized_scale(x - cond_p, x - uncond_p) - - return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) - - m.set_model_sampler_post_cfg_function(cfg_zero_star) - return io.NodeOutput(m) - - -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CFGNorm, CFGZeroStar, ] diff --git a/comfy_extras/v3/nodes_clip_sdxl.py b/comfy_extras/v3/nodes_clip_sdxl.py index 54b83dc16..3d05b7595 100644 --- a/comfy_extras/v3/nodes_clip_sdxl.py +++ b/comfy_extras/v3/nodes_clip_sdxl.py @@ -4,6 +4,31 @@ import nodes from comfy_api.latest import io +class CLIPTextEncodeSDXLRefiner(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXLRefiner_V3", + category="advanced/conditioning", + inputs=[ + io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, ascore, width, height, text, clip) -> io.NodeOutput: + tokens = clip.tokenize(text) + conditioning = clip.encode_from_tokens_scheduled( + tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height} + ) + return io.NodeOutput(conditioning) + + class CLIPTextEncodeSDXL(io.ComfyNode): @classmethod def define_schema(cls): @@ -48,32 +73,7 @@ class CLIPTextEncodeSDXL(io.ComfyNode): return io.NodeOutput(conditioning) -class CLIPTextEncodeSDXLRefiner(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="CLIPTextEncodeSDXLRefiner_V3", - category="advanced/conditioning", - inputs=[ - io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), - io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), - io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), - io.String.Input("text", multiline=True, dynamic_prompts=True), - io.Clip.Input("clip"), - ], - outputs=[io.Conditioning.Output()], - ) - - @classmethod - def execute(cls, ascore, width, height, text, clip) -> io.NodeOutput: - tokens = clip.tokenize(text) - conditioning = clip.encode_from_tokens_scheduled( - tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height} - ) - return io.NodeOutput(conditioning) - - -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner, ] diff --git 
a/comfy_extras/v3/nodes_compositing.py b/comfy_extras/v3/nodes_compositing.py index cfe195148..b1e59ec78 100644 --- a/comfy_extras/v3/nodes_compositing.py +++ b/comfy_extras/v3/nodes_compositing.py @@ -112,32 +112,6 @@ def porter_duff_composite( return out_image, out_alpha -class JoinImageWithAlpha(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="JoinImageWithAlpha_V3", - display_name="Join Image with Alpha _V3", - category="mask/compositing", - inputs=[ - io.Image.Input("image"), - io.Mask.Input("alpha"), - ], - outputs=[io.Image.Output()], - ) - - @classmethod - def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: - batch_size = min(len(image), len(alpha)) - out_images = [] - - alpha = 1.0 - resize_mask(alpha, image.shape[1:]) - for i in range(batch_size): - out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2)) - - return io.NodeOutput(torch.stack(out_images)) - - class PorterDuffImageComposite(io.ComfyNode): @classmethod def define_schema(cls): @@ -219,7 +193,33 @@ class SplitImageWithAlpha(io.ComfyNode): return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) -NODES_LIST = [ +class JoinImageWithAlpha(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="JoinImageWithAlpha_V3", + display_name="Join Image with Alpha _V3", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + io.Mask.Input("alpha"), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: + batch_size = min(len(image), len(alpha)) + out_images = [] + + alpha = 1.0 - resize_mask(alpha, image.shape[1:]) + for i in range(batch_size): + out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2)) + + return io.NodeOutput(torch.stack(out_images)) + + +NODES_LIST: list[type[io.ComfyNode]] = [ JoinImageWithAlpha, PorterDuffImageComposite, SplitImageWithAlpha, diff --git a/comfy_extras/v3/nodes_cond.py b/comfy_extras/v3/nodes_cond.py index 9d3181886..2ce343500 100644 --- a/comfy_extras/v3/nodes_cond.py +++ b/comfy_extras/v3/nodes_cond.py @@ -54,7 +54,7 @@ class T5TokenizerOptions(io.ComfyNode): return io.NodeOutput(clip) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPTextEncodeControlnet, T5TokenizerOptions, ] diff --git a/comfy_extras/v3/nodes_controlnet.py b/comfy_extras/v3/nodes_controlnet.py index 4788113a4..a4656fad2 100644 --- a/comfy_extras/v3/nodes_controlnet.py +++ b/comfy_extras/v3/nodes_controlnet.py @@ -3,6 +3,33 @@ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES from comfy_api.latest import io +class SetUnionControlNetType(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SetUnionControlNetType_V3", + category="conditioning/controlnet", + inputs=[ + io.ControlNet.Input("control_net"), + io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), + ], + outputs=[ + io.ControlNet.Output(), + ], + ) + + @classmethod + def execute(cls, control_net, type) -> io.NodeOutput: + control_net = control_net.copy() + type_number = UNION_CONTROLNET_TYPES.get(type, -1) + if type_number >= 0: + control_net.set_extra_arg("control_type", [type_number]) + else: + control_net.set_extra_arg("control_type", []) + + return io.NodeOutput(control_net) + + class ControlNetApplyAdvanced(io.ComfyNode): @classmethod def define_schema(cls): @@ -60,33 +87,6 @@ class ControlNetApplyAdvanced(io.ComfyNode): return 
io.NodeOutput(out[0], out[1]) -class SetUnionControlNetType(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SetUnionControlNetType_V3", - category="conditioning/controlnet", - inputs=[ - io.ControlNet.Input("control_net"), - io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), - ], - outputs=[ - io.ControlNet.Output(), - ], - ) - - @classmethod - def execute(cls, control_net, type) -> io.NodeOutput: - control_net = control_net.copy() - type_number = UNION_CONTROLNET_TYPES.get(type, -1) - if type_number >= 0: - control_net.set_extra_arg("control_type", [type_number]) - else: - control_net.set_extra_arg("control_type", []) - - return io.NodeOutput(control_net) - - class ControlNetInpaintingAliMamaApply(ControlNetApplyAdvanced): @classmethod def define_schema(cls): diff --git a/comfy_extras/v3/nodes_cosmos.py b/comfy_extras/v3/nodes_cosmos.py index 9779e0ffe..a32c192e8 100644 --- a/comfy_extras/v3/nodes_cosmos.py +++ b/comfy_extras/v3/nodes_cosmos.py @@ -9,6 +9,29 @@ import nodes from comfy_api.latest import io +class EmptyCosmosLatentVideo(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyCosmosLatentVideo_V3", + category="latent/video", + inputs=[ + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + latent = torch.zeros( + [batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device() + ) + return io.NodeOutput({"samples": latent}) + + def vae_encode_with_padding(vae, image, width, height, length, padding=0): pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixel_len = min(pixels.shape[0], length) @@ -116,30 +139,7 @@ class CosmosPredict2ImageToVideoLatent(io.ComfyNode): return io.NodeOutput(out_latent) -class EmptyCosmosLatentVideo(io.ComfyNode): - @classmethod - def define_schema(cls) -> io.Schema: - return io.Schema( - node_id="EmptyCosmosLatentVideo_V3", - category="latent/video", - inputs=[ - io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), - io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), - io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), - io.Int.Input("batch_size", default=1, min=1, max=4096), - ], - outputs=[io.Latent.Output()], - ) - - @classmethod - def execute(cls, width, height, length, batch_size) -> io.NodeOutput: - latent = torch.zeros( - [batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device() - ) - return io.NodeOutput({"samples": latent}) - - -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CosmosImageToVideoLatent, CosmosPredict2ImageToVideoLatent, EmptyCosmosLatentVideo, diff --git a/comfy_extras/v3/nodes_custom_sampler.py b/comfy_extras/v3/nodes_custom_sampler.py new file mode 100644 index 000000000..dca18b6ad --- /dev/null +++ b/comfy_extras/v3/nodes_custom_sampler.py @@ -0,0 +1,1035 @@ +from __future__ import annotations + +import math + +import torch + +import comfy.sample +import 
comfy.samplers +import comfy.utils +import latent_preview +import node_helpers +from comfy.k_diffusion import sa_solver +from comfy.k_diffusion import sampling as k_diffusion_sampling +from comfy_api.latest import io + + +class BasicScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return io.NodeOutput(torch.FloatTensor([])) + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] + return io.NodeOutput(sigmas) + + +class KarrasScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + +class ExponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min): + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) + return io.NodeOutput(sigmas) + + +class PolyexponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + +class LaplaceScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler_V3", + 
category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta): + sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) + return io.NodeOutput(sigmas) + + +class SDTurboScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + sigmas = model.get_model_object("model_sampling").sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return io.NodeOutput(sigmas) + + +class BetaSamplingScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, model, steps, alpha, beta): + sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) + return io.NodeOutput(sigmas) + + +class VPScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VPScheduler_V3", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s): + sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) + return io.NodeOutput(sigmas) + + +class SplitSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, step): + sigmas1 = sigmas[:step + 1] + sigmas2 = sigmas[step:] + return io.NodeOutput(sigmas1, sigmas2) + + +class SplitSigmasDenoise(io.ComfyNode): + @classmethod + def 
define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, denoise): + steps = max(sigmas.shape[-1] - 1, 0) + total_steps = round(steps * denoise) + sigmas1 = sigmas[:-(total_steps)] + sigmas2 = sigmas[-(total_steps + 1):] + return io.NodeOutput(sigmas1, sigmas2) + + +class FlipSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas): + if len(sigmas) == 0: + return io.NodeOutput(sigmas) + + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return io.NodeOutput(sigmas) + + +class SetFirstSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas, sigma): + sigmas = sigmas.clone() + sigmas[0] = sigma + return io.NodeOutput(sigmas) + + +class ExtendIntermediateSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[ + io.Sigmas.Output(), + ] + ) + + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + if start_at_sigma < 0: + start_at_sigma = float("inf") + + interpolator = { + 'linear': lambda x: x, + 'cosine': lambda x: torch.sin(x*math.pi/2), + 'sine': lambda x: 1 - torch.cos(x*math.pi/2) + }[spacing] + + # linear space for our interpolation function + x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1] + computed_spacing = interpolator(x) + + extended_sigmas = [] + for i in range(len(sigmas) - 1): + sigma_current = sigmas[i] + sigma_next = sigmas[i+1] + + extended_sigmas.append(sigma_current) + + if end_at_sigma <= sigma_current <= start_at_sigma: + interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current + extended_sigmas.extend(interpolated_steps.tolist()) + + # Add the last sigma value + if len(sigmas) > 0: + extended_sigmas.append(sigmas[-1]) + + extended_sigmas = torch.FloatTensor(extended_sigmas) + + return io.NodeOutput(extended_sigmas) + + +class SamplingPercentToSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma_V3", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, 
tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[ + io.Float.Output(display_name="sigma_value"), + ] + ) + + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma): + model_sampling = model.get_model_object("model_sampling") + sigma_val = model_sampling.percent_to_sigma(sampling_percent) + if return_actual_sigma: + if sampling_percent == 0.0: + sigma_val = model_sampling.sigma_max.item() + elif sampling_percent == 1.0: + sigma_val = model_sampling.sigma_min.item() + return io.NodeOutput(sigma_val) + + +class KSamplerSelect(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, sampler_name): + sampler = comfy.samplers.sampler_object(sampler_name) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_3M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_2M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_2m_sde" + else: + sampler_name = "dpmpp_2m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + return io.NodeOutput(sampler) + + +class SamplerDPMPP_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise, r, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_sde" + else: + sampler_name = "dpmpp_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + return io.NodeOutput(sampler) + + +class 
SamplerDPMPP_2S_Ancestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerEulerAncestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerEulerAncestralCFGPP(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP_V3", + display_name="SamplerEulerAncestralCFG++ _V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, eta, s_noise): + sampler = comfy.samplers.ksampler( + "euler_ancestral_cfg_pp", + {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + +class SamplerLMS(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=4, min=1, max=100), + ], + outputs=[ + io.Sampler.Output() + ] + ) + + @classmethod + def execute(cls, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return io.NodeOutput(sampler) + + +class SamplerDPMAdaptative(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": 
atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return io.NodeOutput(sampler) + + +class SamplerER_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise): + if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): + eta = 0 + s_noise = 0 + + def reverse_time_sde_noise_scaler(x): + return x ** (eta + 1) + + if solver_type == "ER-SDE": + # Use the default one in sample_er_sde() + noise_scaler = None + else: + noise_scaler = reverse_time_sde_noise_scaler + + sampler_name = "er_sde" + sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) + return io.NodeOutput(sampler) + + +class SamplerSASolver(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver_V3", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[ + io.Sampler.Output(), + ] + ) + + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + model_sampling = model.get_model_object("model_sampling") + start_sigma = model_sampling.percent_to_sigma(sde_start_percent) + end_sigma = model_sampling.percent_to_sigma(sde_end_percent) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta) + + sampler_name = "sa_solver" + sampler = comfy.samplers.ksampler( + sampler_name, + { + "tau_func": tau_func, + "s_noise": s_noise, + "predictor_order": predictor_order, + "corrector_order": corrector_order, + "use_pece": use_pece, + "simple_order_2": simple_order_2, + }, + ) + return io.NodeOutput(sampler) + + +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + 
+ +class SamplerCustom(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom_V3", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent["samples"] = latent_image + + if not add_noise: + noise = Noise_EmptyNoise().generate_noise(latent) + else: + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + + +class BasicGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return io.NodeOutput(guider) + + +class CFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CFGGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, positive, negative, cfg): + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return io.NodeOutput(guider) + + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2, nested=False): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + self.nested = nested + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + 
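+        # set_conds() stored positive/middle/negative. In "nested" style, cfg1
+        # guides positive against middle and cfg2 then guides that result
+        # against negative; in "regular" style, cfg2 guides middle against
+        # negative and cfg1 adds the (positive - middle) delta on top.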
negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + positive_cond = self.conds.get("positive", None) + + if self.nested: + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond) + return out[0] + self.cfg2 * (pred_text - out[0]) + else: + if model_options.get("disable_cfg1_optimization", False) is False: + if math.isclose(self.cfg2, 1.0): + negative_cond = None + if math.isclose(self.cfg1, 1.0): + middle_cond = None + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + + +class DualCFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider_V3", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[ + io.Guider.Output(), + ] + ) + + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) + return io.NodeOutput(guider) + + +class DisableNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DisableNoise_V3", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[ + io.Noise.Output(), + ] + ) + + @classmethod + def execute(cls): + return io.NodeOutput(Noise_EmptyNoise()) + + +class RandomNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RandomNoise_V3", + category="sampling/custom_sampling/noise", + inputs=[ + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + ], + outputs=[ + io.Noise.Output(), + ] + ) + + @classmethod + def execute(cls, noise_seed): + return io.NodeOutput(Noise_RandomNoise(noise_seed)) + + +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced_V3", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent["samples"] = latent_image + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = 
latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + +class AddNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AddNoise_V3", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, model, noise, sigmas, latent_image): + if len(sigmas) == 0: + return io.NodeOutput(latent_image) + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return io.NodeOutput(out) + + +NODES_LIST: list[type[io.ComfyNode]] = [ + AddNoise, + BasicGuider, + BasicScheduler, + BetaSamplingScheduler, + CFGGuider, + DisableNoise, + DualCFGGuider, + ExponentialScheduler, + ExtendIntermediateSigmas, + FlipSigmas, + KarrasScheduler, + KSamplerSelect, + LaplaceScheduler, + PolyexponentialScheduler, + RandomNoise, + SamplerCustom, + SamplerCustomAdvanced, + SamplerDPMAdaptative, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_SDE, + SamplerER_SDE, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerSASolver, + SamplingPercentToSigma, + SDTurboScheduler, + SetFirstSigma, + SplitSigmas, + SplitSigmasDenoise, + VPScheduler, +] diff --git a/comfy_extras/v3/nodes_differential_diffusion.py b/comfy_extras/v3/nodes_differential_diffusion.py index 6eb8cacbc..b4e5ecdc5 100644 --- a/comfy_extras/v3/nodes_differential_diffusion.py +++ b/comfy_extras/v3/nodes_differential_diffusion.py @@ -45,6 +45,6 @@ class DifferentialDiffusion(io.ComfyNode): return (denoise_mask >= threshold).to(denoise_mask.dtype) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ DifferentialDiffusion, ] diff --git a/comfy_extras/v3/nodes_edit_model.py b/comfy_extras/v3/nodes_edit_model.py index 79dd672e3..b6164dc6a 100644 --- a/comfy_extras/v3/nodes_edit_model.py +++ b/comfy_extras/v3/nodes_edit_model.py @@ -29,6 +29,6 @@ class ReferenceLatent(io.ComfyNode): return io.NodeOutput(conditioning) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ ReferenceLatent, ] diff --git a/comfy_extras/v3/nodes_flux.py b/comfy_extras/v3/nodes_flux.py 
index 3967fc4ad..f068f7b98 100644 --- a/comfy_extras/v3/nodes_flux.py +++ b/comfy_extras/v3/nodes_flux.py @@ -49,28 +49,6 @@ class CLIPTextEncodeFlux(io.ComfyNode): return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance})) - -class FluxDisableGuidance(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="FluxDisableGuidance_V3", - category="advanced/conditioning/flux", - description="This node completely disables the guidance embed on Flux and Flux like models", - inputs=[ - io.Conditioning.Input("conditioning"), - ], - outputs=[ - io.Conditioning.Output(), - ], - ) - - @classmethod - def execute(cls, conditioning): - c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) - return io.NodeOutput(c) - - class FluxGuidance(io.ComfyNode): @classmethod def define_schema(cls): @@ -91,6 +69,25 @@ class FluxGuidance(io.ComfyNode): c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) return io.NodeOutput(c) +class FluxDisableGuidance(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxDisableGuidance_V3", + category="advanced/conditioning/flux", + description="This node completely disables the guidance embed on Flux and Flux like models", + inputs=[ + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, conditioning): + c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) + return io.NodeOutput(c) class FluxKontextImageScale(io.ComfyNode): @classmethod @@ -117,7 +114,7 @@ class FluxKontextImageScale(io.ComfyNode): return io.NodeOutput(image) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPTextEncodeFlux, FluxDisableGuidance, FluxGuidance, diff --git a/comfy_extras/v3/nodes_freelunch.py b/comfy_extras/v3/nodes_freelunch.py index fe3e2c9dd..7467a1f88 100644 --- a/comfy_extras/v3/nodes_freelunch.py +++ b/comfy_extras/v3/nodes_freelunch.py @@ -125,7 +125,7 @@ class FreeU_V2(io.ComfyNode): return io.NodeOutput(m) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ FreeU, FreeU_V2, ] diff --git a/comfy_extras/v3/nodes_fresca.py b/comfy_extras/v3/nodes_fresca.py index e9057fca5..c4115c84c 100644 --- a/comfy_extras/v3/nodes_fresca.py +++ b/comfy_extras/v3/nodes_fresca.py @@ -105,6 +105,6 @@ class FreSca(io.ComfyNode): return io.NodeOutput(m) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ FreSca, ] diff --git a/comfy_extras/v3/nodes_gits.py b/comfy_extras/v3/nodes_gits.py index 2efb34763..4d500d789 100644 --- a/comfy_extras/v3/nodes_gits.py +++ b/comfy_extras/v3/nodes_gits.py @@ -371,6 +371,6 @@ class GITSScheduler(io.ComfyNode): return io.NodeOutput(torch.FloatTensor(sigmas)) -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ GITSScheduler, ] diff --git a/comfy_extras/v3/nodes_hidream.py b/comfy_extras/v3/nodes_hidream.py index 8afd3bb13..a7c733774 100644 --- a/comfy_extras/v3/nodes_hidream.py +++ b/comfy_extras/v3/nodes_hidream.py @@ -6,33 +6,6 @@ import folder_paths from comfy_api.latest import io -class CLIPTextEncodeHiDream(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="CLIPTextEncodeHiDream_V3", - category="advanced/conditioning", - inputs=[ - io.Clip.Input("clip"), - io.String.Input("clip_l", multiline=True, dynamic_prompts=True), - io.String.Input("clip_g", multiline=True, dynamic_prompts=True), - io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), - 
io.String.Input("llama", multiline=True, dynamic_prompts=True), - ], - outputs=[ - io.Conditioning.Output(), - ] - ) - - @classmethod - def execute(cls, clip, clip_l, clip_g, t5xxl, llama): - tokens = clip.tokenize(clip_g) - tokens["l"] = clip.tokenize(clip_l)["l"] - tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - tokens["llama"] = clip.tokenize(llama)["llama"] - return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) - - class QuadrupleCLIPLoader(io.ComfyNode): @classmethod def define_schema(cls): @@ -65,7 +38,34 @@ class QuadrupleCLIPLoader(io.ComfyNode): ) -NODES_LIST = [ +class CLIPTextEncodeHiDream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHiDream_V3", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.String.Input("llama", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, llama): + tokens = clip.tokenize(clip_g) + tokens["l"] = clip.tokenize(clip_l)["l"] + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + tokens["llama"] = clip.tokenize(llama)["llama"] + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPTextEncodeHiDream, QuadrupleCLIPLoader, ] diff --git a/comfy_extras/v3/nodes_hunyuan.py b/comfy_extras/v3/nodes_hunyuan.py index 1c2262a0e..4ad737d7b 100644 --- a/comfy_extras/v3/nodes_hunyuan.py +++ b/comfy_extras/v3/nodes_hunyuan.py @@ -7,16 +7,6 @@ import node_helpers import nodes from comfy_api.latest import io -PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( - "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." - "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n" -) class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -68,6 +58,51 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): return io.NodeOutput({"samples":latent}) +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. 
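# --- editor's note (not part of the patch): CLIPTextEncodeHiDream, moved below
# QuadrupleCLIPLoader by this hunk, builds one token dict from four prompts.
# A sketch with a hypothetical tokenizer stub showing the merge shape; the real
# clip.tokenize returns a dict keyed by encoder name ("g", "l", "t5xxl", "llama").
def tokenize_stub(text):
    return {"g": [text], "l": [text], "t5xxl": [text], "llama": [text]}

tokens = tokenize_stub("a photo of a cat")           # clip_g supplies the base dict
tokens["l"] = tokenize_stub("sharp focus")["l"]      # then each stream is overlaid
tokens["t5xxl"] = tokenize_stub("a cat")["t5xxl"]
tokens["llama"] = tokenize_stub("a cat")["llama"]
# --- end editor's note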
camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + + +class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeHunyuanVideo_ImageToVideo_V3", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.ClipVisionOutput.Input("clip_vision_output"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Int.Input( + "image_interleave", + default=2, + min=1, + max=512, + tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.", + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_vision_output, prompt, image_interleave): + tokens = clip.tokenize( + prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, + image_embeds=clip_vision_output.mm_projected, + image_interleave=image_interleave, + ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + class HunyuanImageToVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -126,40 +161,7 @@ class HunyuanImageToVideo(io.ComfyNode): return io.NodeOutput(positive, out_latent) -class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="TextEncodeHunyuanVideo_ImageToVideo_V3", - category="advanced/conditioning", - inputs=[ - io.Clip.Input("clip"), - io.ClipVisionOutput.Input("clip_vision_output"), - io.String.Input("prompt", multiline=True, dynamic_prompts=True), - io.Int.Input( - "image_interleave", - default=2, - min=1, - max=512, - tooltip="How much the image influences things vs the text prompt. 
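# --- editor's note (not part of the patch): the template above is consumed by
# TextEncodeHunyuanVideo_ImageToVideo via clip.tokenize(prompt, llama_template=...);
# its "{}" slot receives the user prompt. Sketch with a shortened stand-in template:
TEMPLATE_STUB = (
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
filled = TEMPLATE_STUB.format("a red fox running through snow")
assert "user" in filled and "assistant" in filled
# image_embeds (clip_vision_output.mm_projected) and image_interleave then control
# how the projected vision tokens are woven between these text tokens.
# --- end editor's note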
Higher number means more influence from the text prompt.", - ), - ], - outputs=[ - io.Conditioning.Output(), - ], - ) - - @classmethod - def execute(cls, clip, clip_vision_output, prompt, image_interleave): - tokens = clip.tokenize( - prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, - image_embeds=clip_vision_output.mm_projected, - image_interleave=image_interleave, - ) - return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) - - -NODES_LIST = [ +NODES_LIST: list[type[io.ComfyNode]] = [ CLIPTextEncodeHunyuanDiT, EmptyHunyuanLatentVideo, HunyuanImageToVideo, diff --git a/comfy_extras/v3/nodes_hunyuan3d.py b/comfy_extras/v3/nodes_hunyuan3d.py new file mode 100644 index 000000000..a4594c4c2 --- /dev/null +++ b/comfy_extras/v3/nodes_hunyuan3d.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import json +import os +import struct + +import numpy as np +import torch + +import comfy.model_management +import folder_paths +from comfy.cli_args import args +from comfy.ldm.modules.diffusionmodules.mmdit import ( + get_1d_sincos_pos_embed_from_grid_torch, +) +from comfy_api.latest import io + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class MESH: + def __init__(self, vertices, faces): + self.vertices = vertices + self.faces = faces + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) + + vertex_count += 4 * num_faces + + if len(all_vertices) > 0: + vertices = 
torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + vertices = torch.fliplr(vertices) + return vertices, faces + +def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + D, H, W = voxels.shape + + padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0) + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + corner_offsets = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] + ], device=device) + + corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) + for c, (dz, dy, dx) in enumerate(corner_offsets): + corner_values[:, c] = padded[ + cell_positions[:, 0] + dz, + cell_positions[:, 1] + dy, + cell_positions[:, 2] + dx + ] + + corner_signs = corner_values > threshold + has_inside = torch.any(corner_signs, dim=1) + has_outside = torch.any(~corner_signs, dim=1) + contains_surface = has_inside & has_outside + + active_cells = cell_positions[contains_surface] + active_signs = corner_signs[contains_surface] + active_values = corner_values[contains_surface] + + if active_cells.shape[0] == 0: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + edges = torch.tensor([ + [0, 1], [0, 2], [0, 4], [1, 3], + [1, 5], [2, 3], [2, 6], [3, 7], + [4, 5], [4, 6], [5, 7], [6, 7] + ], device=device) + + cell_vertices = {} + progress = comfy.utils.ProgressBar(100) + + for edge_idx, (e1, e2) in enumerate(edges): + progress.update(1) + crossing = active_signs[:, e1] != active_signs[:, e2] + if not crossing.any(): + continue + + cell_indices = torch.nonzero(crossing, as_tuple=True)[0] + + v1 = active_values[cell_indices, e1] + v2 = active_values[cell_indices, e2] + + t = torch.zeros_like(v1, device=device) + denom = v2 - v1 + valid = denom != 0 + t[valid] = (threshold - v1[valid]) / denom[valid] + t[~valid] = 0.5 + + p1 = corner_offsets[e1].float() + p2 = corner_offsets[e2].float() + + intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0)) + + for i, point in zip(cell_indices.tolist(), intersection): + if i not in cell_vertices: + cell_vertices[i] = [] + cell_vertices[i].append(point) + + # Calculate the final vertices as the average of intersection points for each cell + vertices = [] + vertex_lookup = {} + + vert_progress_mod = round(len(cell_vertices)/50) + + for i, points in cell_vertices.items(): + if not i % vert_progress_mod: + progress.update(1) + + if points: + vertex = torch.stack(points).mean(dim=0) + vertex = vertex + active_cells[i].float() + vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices) + vertices.append(vertex) + + if not vertices: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + final_vertices = torch.stack(vertices) + + inside_corners_mask = active_signs + outside_corners_mask = ~active_signs + + inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float() + outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float() + + 
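# --- editor's note (not part of the patch): the exposure test in voxel_to_mesh
# above, reduced to a runnable sketch. A face is emitted only when the neighbour
# cell, looked up in the zero-padded grid, is empty; one isolated solid voxel
# therefore exposes 6 faces = 24 quad vertices = 12 triangles.
import torch

binary = torch.zeros(3, 3, 3)
binary[1, 1, 1] = 1.0
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)

solid = torch.nonzero(binary)                       # (1, 3): the single solid cell
neighbors = torch.tensor([[0, 0, 1], [0, 0, -1], [0, 1, 0],
                          [0, -1, 0], [1, 0, 0], [-1, 0, 0]])
exposed = 0
for offset in neighbors:
    idx = solid + offset + 1                        # +1 accounts for the padding
    exposed += int((padded[idx[:, 0], idx[:, 1], idx[:, 2]] == 0).sum())
print(exposed)                                      # 6
# --- end editor's note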
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + outside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + + for i in range(8): + mask_inside = inside_corners_mask[:, i].unsqueeze(1) + mask_outside = outside_corners_mask[:, i].unsqueeze(1) + inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside + outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside + + inside_pos /= inside_counts + outside_pos /= outside_counts + gradients = inside_pos - outside_pos + + pos_dirs = torch.tensor([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ], device=device) + + cross_products = [ + torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float()) + for i in range(3) for j in range(i+1, 3) + ] + + faces = [] + all_keys = set(vertex_lookup.keys()) + + face_progress_mod = round(len(active_cells)/38*3) + + for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]): + dir_i = pos_dirs[i] + dir_j = pos_dirs[j] + cross_product = cross_products[pair_idx] + + ni_positions = active_cells + dir_i + nj_positions = active_cells + dir_j + diag_positions = active_cells + dir_i + dir_j + + alignments = torch.matmul(gradients, cross_product) + + valid_quads = [] + quad_indices = [] + + for idx, active_cell in enumerate(active_cells): + if not idx % face_progress_mod: + progress.update(1) + cell_key = tuple(active_cell.tolist()) + ni_key = tuple(ni_positions[idx].tolist()) + nj_key = tuple(nj_positions[idx].tolist()) + diag_key = tuple(diag_positions[idx].tolist()) + + if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys: + v0 = vertex_lookup[cell_key] + v1 = vertex_lookup[ni_key] + v2 = vertex_lookup[nj_key] + v3 = vertex_lookup[diag_key] + + valid_quads.append((v0, v1, v2, v3)) + quad_indices.append(idx) + + for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads): + cell_idx = quad_indices[q_idx] + if alignments[cell_idx] > 0: + faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long)) + else: + faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long)) + + if faces: + faces = torch.stack(faces) + else: + faces = torch.zeros((0, 3), dtype=torch.long, device=device) + + v_min = 0 + v_max = max(D, H, W) + + final_vertices = final_vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + final_vertices = final_vertices / scale + + final_vertices = torch.fliplr(final_vertices) + + return final_vertices, faces + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. 
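# --- editor's note (not part of the patch): the edge interpolation at the heart
# of voxel_to_mesh_surfnet above, isolated into a runnable sketch. The surface
# vertex sits at the linear zero-crossing of (value - threshold) along the edge;
# degenerate edges (equal corner values) fall back to the midpoint.
import torch

threshold = 0.5
v1 = torch.tensor([0.2, 0.5])        # corner values at one end of two edges
v2 = torch.tensor([0.9, 0.5])        # corner values at the other end

t = torch.zeros_like(v1)
denom = v2 - v1
valid = denom != 0
t[valid] = (threshold - v1[valid]) / denom[valid]
t[~valid] = 0.5
print(t)                             # tensor([0.4286, 0.5000])
# --- end editor's note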
+ + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack('
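# --- editor's note (not part of the patch): save_glb above is assembling a
# standard GLB container (glTF 2.0 spec): a 12-byte header, a 4-byte-aligned
# JSON chunk tagged 'JSON' (0x4E4F534A), then a binary chunk tagged 'BIN\0'
# (0x004E4942). A self-contained sketch of the same framing with toy payloads:
import struct

json_payload = b'{"asset":{"version":"2.0"}}'
json_payload += b' ' * ((4 - len(json_payload) % 4) % 4)   # GLB pads JSON with spaces
bin_payload = b'\x00' * 8                                  # stand-in vertex/index data

total = 12 + 8 + len(json_payload) + 8 + len(bin_payload)
header = struct.pack('<4sII', b'glTF', 2, total)
json_chunk = struct.pack('<II', len(json_payload), 0x4E4F534A) + json_payload
bin_chunk = struct.pack('<II', len(bin_payload), 0x004E4942) + bin_payload
glb = header + json_chunk + bin_chunk
assert len(glb) == total
# --- end editor's note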