diff --git a/comfy_extras/v3/nodes_lumina2.py b/comfy_extras/v3/nodes_lumina2.py
index 470537ccc..d31895bdd 100644
--- a/comfy_extras/v3/nodes_lumina2.py
+++ b/comfy_extras/v3/nodes_lumina2.py
@@ -5,50 +5,6 @@ import torch
 from comfy_api.latest import io
 
 
-class CLIPTextEncodeLumina2(io.ComfyNode):
-    SYSTEM_PROMPT = {
-        "superior": "You are an assistant designed to generate superior images with the superior "
-            "degree of image-text alignment based on textual prompts or user prompts.",
-        "alignment": "You are an assistant designed to generate high-quality images with the "
-            "highest degree of image-text alignment based on textual prompts."
-    }
-    SYSTEM_PROMPT_TIP = "Lumina2 provide two types of system prompts:" \
-        "Superior: You are an assistant designed to generate superior images with the superior "\
-        "degree of image-text alignment based on textual prompts or user prompts. "\
-        "Alignment: You are an assistant designed to generate high-quality images with the highest "\
-        "degree of image-text alignment based on textual prompts."
-
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="CLIPTextEncodeLumina2_V3",
-            display_name="CLIP Text Encode for Lumina2 _V3",
-            category="conditioning",
-            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
-                "that can be used to guide the diffusion model towards generating specific images.",
-            inputs=[
-                io.Combo.Input("system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
-                io.String.Input("user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
-                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
-            ],
-            outputs=[
-                io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, system_prompt, user_prompt, clip):
-        if clip is None:
-            raise RuntimeError(
-                "ERROR: clip input is invalid: None\n\n"
-                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
-            )
-        system_prompt = cls.SYSTEM_PROMPT[system_prompt]
-        prompt = f'{system_prompt} {user_prompt}'
-        tokens = clip.tokenize(prompt)
-        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
-
-
 class RenormCFG(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -110,7 +66,51 @@ class RenormCFG(io.ComfyNode):
         return io.NodeOutput(m)
 
 
-NODES_LIST = [
+class CLIPTextEncodeLumina2(io.ComfyNode):
+    SYSTEM_PROMPT = {
+        "superior": "You are an assistant designed to generate superior images with the superior "
+            "degree of image-text alignment based on textual prompts or user prompts.",
+        "alignment": "You are an assistant designed to generate high-quality images with the "
+            "highest degree of image-text alignment based on textual prompts."
+    }
+    SYSTEM_PROMPT_TIP = "Lumina2 provides two types of system prompts: " \
+        "Superior: You are an assistant designed to generate superior images with the superior " \
+        "degree of image-text alignment based on textual prompts or user prompts. " \
+        "Alignment: You are an assistant designed to generate high-quality images with the highest " \
+        "degree of image-text alignment based on textual prompts."
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeLumina2_V3",
+            display_name="CLIP Text Encode for Lumina2 _V3",
+            category="conditioning",
+            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
+                "that can be used to guide the diffusion model towards generating specific images.",
+            inputs=[
+                io.Combo.Input("system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
+                io.String.Input("user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
+                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
+            ],
+            outputs=[
+                io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, system_prompt, user_prompt, clip):
+        if clip is None:
+            raise RuntimeError(
+                "ERROR: clip input is invalid: None\n\n"
+                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
+            )
+        system_prompt = cls.SYSTEM_PROMPT[system_prompt]
+        prompt = f'{system_prompt} {user_prompt}'
+        tokens = clip.tokenize(prompt)
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     CLIPTextEncodeLumina2,
     RenormCFG,
 ]
diff --git a/comfy_extras/v3/nodes_model_advanced.py b/comfy_extras/v3/nodes_model_advanced.py
index 936adf8db..9b2855e61 100644
--- a/comfy_extras/v3/nodes_model_advanced.py
+++ b/comfy_extras/v3/nodes_model_advanced.py
@@ -57,15 +57,16 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
         return log_sigma.exp().to(timestep.device)
 
 
-class ModelComputeDtype(io.ComfyNode):
+class ModelSamplingDiscrete(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
-            node_id="ModelComputeDtype_V3",
-            category="advanced/debug/model",
+            node_id="ModelSamplingDiscrete_V3",
+            category="advanced/model",
             inputs=[
                 io.Model.Input("model"),
-                io.Combo.Input("dtype", options=["default", "fp32", "fp16", "bf16"]),
+                io.Combo.Input("sampling", options=["eps", "v_prediction", "lcm", "x0", "img_to_img"]),
+                io.Boolean.Input("zsnr", default=False),
             ],
             outputs=[
                 io.Model.Output(),
@@ -73,9 +74,150 @@ class ModelComputeDtype(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model, dtype):
+    def execute(cls, model, sampling, zsnr):
         m = model.clone()
-        m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
+
+        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+        elif sampling == "lcm":
+            sampling_type = LCM
+            sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "x0":
+            sampling_type = comfy.model_sampling.X0
+        elif sampling == "img_to_img":
+            sampling_type = comfy.model_sampling.IMG_TO_IMG
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
+
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingStableCascade(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelSamplingStableCascade_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=2.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.StableCascadeSampling
+        sampling_type = comfy.model_sampling.EPS
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingSD3(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelSamplingSD3_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=3.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift, multiplier: int | float = 1000):
+        m = model.clone()
+
+        sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift, multiplier=multiplier)
+        m.add_object_patch("model_sampling", model_sampling)
+        return io.NodeOutput(m)
+
+
+class ModelSamplingAuraFlow(ModelSamplingSD3):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelSamplingAuraFlow_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("shift", default=1.73, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, shift, multiplier: int | float = 1.0):
+        return super().execute(model, shift, multiplier)
+
+
+class ModelSamplingFlux(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelSamplingFlux_V3",
+            category="advanced/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("max_shift", default=1.15, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.5, min=0.0, max=100.0, step=0.01),
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, max_shift, base_shift, width, height):
+        m = model.clone()
+
+        x1 = 256
+        x2 = 4096
+        mm = (max_shift - base_shift) / (x2 - x1)
+        b = base_shift - mm * x1
+        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
+
+        sampling_base = comfy.model_sampling.ModelSamplingFlux
+        sampling_type = comfy.model_sampling.CONST
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced(model.model.model_config)
+        model_sampling.set_parameters(shift=shift)
+        m.add_object_patch("model_sampling", model_sampling)
         return io.NodeOutput(m)
 
 
@@ -165,170 +307,6 @@ class ModelSamplingContinuousV(io.ComfyNode):
         return io.NodeOutput(m)
 
 
-class ModelSamplingDiscrete(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSamplingDiscrete_V3",
-            category="advanced/model",
-            inputs=[
-                io.Model.Input("model"),
-                io.Combo.Input("sampling", options=["eps", "v_prediction", "lcm", "x0", "img_to_img"]),
-                io.Boolean.Input("zsnr", default=False),
-            ],
-            outputs=[
-                io.Model.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, model, sampling, zsnr):
-        m = model.clone()
-
-        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
-        if sampling == "eps":
-            sampling_type = comfy.model_sampling.EPS
-        elif sampling == "v_prediction":
-            sampling_type = comfy.model_sampling.V_PREDICTION
-        elif sampling == "lcm":
-            sampling_type = LCM
-            sampling_base = ModelSamplingDiscreteDistilled
-        elif sampling == "x0":
-            sampling_type = comfy.model_sampling.X0
-        elif sampling == "img_to_img":
-            sampling_type = comfy.model_sampling.IMG_TO_IMG
-
-        class ModelSamplingAdvanced(sampling_base, sampling_type):
-            pass
-
-        model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
-
-        m.add_object_patch("model_sampling", model_sampling)
-        return io.NodeOutput(m)
-
-
-class ModelSamplingFlux(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSamplingFlux_V3",
-            category="advanced/model",
-            inputs=[
-                io.Model.Input("model"),
-                io.Float.Input("max_shift", default=1.15, min=0.0, max=100.0, step=0.01),
-                io.Float.Input("base_shift", default=0.5, min=0.0, max=100.0, step=0.01),
-                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
-                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
-            ],
-            outputs=[
-                io.Model.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, model, max_shift, base_shift, width, height):
-        m = model.clone()
-
-        x1 = 256
-        x2 = 4096
-        mm = (max_shift - base_shift) / (x2 - x1)
-        b = base_shift - mm * x1
-        shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
-
-        sampling_base = comfy.model_sampling.ModelSamplingFlux
-        sampling_type = comfy.model_sampling.CONST
-
-        class ModelSamplingAdvanced(sampling_base, sampling_type):
-            pass
-
-        model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_parameters(shift=shift)
-        m.add_object_patch("model_sampling", model_sampling)
-        return io.NodeOutput(m)
-
-
-class ModelSamplingSD3(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSamplingSD3_V3",
-            category="advanced/model",
-            inputs=[
-                io.Model.Input("model"),
-                io.Float.Input("shift", default=3.0, min=0.0, max=100.0, step=0.01),
-            ],
-            outputs=[
-                io.Model.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, model, shift, multiplier: int | float = 1000):
-        m = model.clone()
-
-        sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
-        sampling_type = comfy.model_sampling.CONST
-
-        class ModelSamplingAdvanced(sampling_base, sampling_type):
-            pass
-
-        model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_parameters(shift=shift, multiplier=multiplier)
-        m.add_object_patch("model_sampling", model_sampling)
-        return io.NodeOutput(m)
-
-
-class ModelSamplingAuraFlow(ModelSamplingSD3):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSamplingAuraFlow_V3",
-            category="advanced/model",
-            inputs=[
-                io.Model.Input("model"),
-                io.Float.Input("shift", default=1.73, min=0.0, max=100.0, step=0.01),
-            ],
-            outputs=[
-                io.Model.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, model, shift, multiplier: int | float = 1.0):
-        return super().execute(model, shift, multiplier)
-
-
-class ModelSamplingStableCascade(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSamplingStableCascade_V3",
-            category="advanced/model",
-            inputs=[
-                io.Model.Input("model"),
-                io.Float.Input("shift", default=2.0, min=0.0, max=100.0, step=0.01),
-            ],
-            outputs=[
-                io.Model.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, model, shift):
-        m = model.clone()
-
-        sampling_base = comfy.model_sampling.StableCascadeSampling
-        sampling_type = comfy.model_sampling.EPS
-
-        class ModelSamplingAdvanced(sampling_base, sampling_type):
-            pass
-
-        model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_parameters(shift)
-        m.add_object_patch("model_sampling", model_sampling)
-        return io.NodeOutput(m)
-
-
 class RescaleCFG(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -374,7 +352,29 @@ class RescaleCFG(io.ComfyNode):
         return io.NodeOutput(m)
 
 
-NODES_LIST = [
+class ModelComputeDtype(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelComputeDtype_V3",
+            category="advanced/debug/model",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input("dtype", options=["default", "fp32", "fp16", "bf16"]),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, dtype):
+        m = model.clone()
+        m.set_model_compute_dtype(node_helpers.string_to_torch_dtype(dtype))
+        return io.NodeOutput(m)
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     ModelSamplingAuraFlow,
     ModelComputeDtype,
     ModelSamplingContinuousEDM,
diff --git a/comfy_extras/v3/nodes_model_merging.py b/comfy_extras/v3/nodes_model_merging.py
index 101629ad5..5cf62d869 100644
--- a/comfy_extras/v3/nodes_model_merging.py
+++ b/comfy_extras/v3/nodes_model_merging.py
@@ -75,52 +75,76 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
     comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
 
 
-class CheckpointSave(io.ComfyNode):
+class ModelMergeSimple(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
-            node_id="CheckpointSave_V3",
-            display_name="Save Checkpoint _V3",
-            category="advanced/model_merging",
-            is_output_node=True,
-            inputs=[
-                io.Model.Input("model"),
-                io.Clip.Input("clip"),
-                io.Vae.Input("vae"),
-                io.String.Input("filename_prefix", default="checkpoints/ComfyUI")
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
-        )
-
-    @classmethod
-    def execute(cls, model, clip, vae, filename_prefix):
-        save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
-        return io.NodeOutput()
-
-
-class CLIPAdd(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="CLIPMergeAdd_V3",
+            node_id="ModelMergeSimple_V3",
             category="advanced/model_merging",
             inputs=[
-                io.Clip.Input("clip1"),
-                io.Clip.Input("clip2")
+                io.Model.Input("model1"),
+                io.Model.Input("model2"),
+                io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01)
             ],
             outputs=[
-                io.Clip.Output()
+                io.Model.Output()
             ]
         )
 
     @classmethod
-    def execute(cls, clip1, clip2):
-        m = clip1.clone()
-        kp = clip2.get_key_patches()
+    def execute(cls, model1, model2, ratio):
+        m = model1.clone()
+        kp = model2.get_key_patches("diffusion_model.")
+        for k in kp:
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
+        return io.NodeOutput(m)
+
+
+class ModelSubtract(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelMergeSubtract_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Model.Input("model1"),
+                io.Model.Input("model2"),
+                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
+            ],
+            outputs=[
+                io.Model.Output()
+            ]
+        )
+
+    @classmethod
+    def execute(cls, model1, model2, multiplier):
+        m = model1.clone()
+        kp = model2.get_key_patches("diffusion_model.")
+        for k in kp:
+            m.add_patches({k: kp[k]}, - multiplier, multiplier)
+        return io.NodeOutput(m)
+
+
+class ModelAdd(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelMergeAdd_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Model.Input("model1"),
+                io.Model.Input("model2")
+            ],
+            outputs=[
+                io.Model.Output()
+            ]
+        )
+
+    @classmethod
+    def execute(cls, model1, model2):
+        m = model1.clone()
+        kp = model2.get_key_patches("diffusion_model.")
         for k in kp:
-            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
-                continue
             m.add_patches({k: kp[k]}, 1.0, 1.0)
         return io.NodeOutput(m)
 
@@ -152,6 +176,121 @@ class CLIPMergeSimple(io.ComfyNode):
         return io.NodeOutput(m)
 
 
+class CLIPSubtract(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPMergeSubtract_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Clip.Input("clip1"),
+                io.Clip.Input("clip2"),
+                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
+            ],
+            outputs=[
+                io.Clip.Output()
+            ]
+        )
+
+    @classmethod
+    def execute(cls, clip1, clip2, multiplier):
+        m = clip1.clone()
+        kp = clip2.get_key_patches()
+        for k in kp:
+            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
+                continue
+            m.add_patches({k: kp[k]}, - multiplier, multiplier)
+        return io.NodeOutput(m)
+
+
+class CLIPAdd(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPMergeAdd_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Clip.Input("clip1"),
+                io.Clip.Input("clip2")
+            ],
+            outputs=[
+                io.Clip.Output()
+            ]
+        )
+
+    @classmethod
+    def execute(cls, clip1, clip2):
+        m = clip1.clone()
+        kp = clip2.get_key_patches()
+        for k in kp:
+            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
+                continue
+            m.add_patches({k: kp[k]}, 1.0, 1.0)
+        return io.NodeOutput(m)
+
+
+class ModelMergeBlocks(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelMergeBlocks_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Model.Input("model1"),
+                io.Model.Input("model2"),
+                io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01)
+            ],
+            outputs=[
+                io.Model.Output()
+            ]
+        )
+
+    @classmethod
+    def execute(cls, model1, model2, **kwargs):
+        m = model1.clone()
+        kp = model2.get_key_patches("diffusion_model.")
+        default_ratio = next(iter(kwargs.values()))
+
+        for k in kp:
+            ratio = default_ratio
+            k_unet = k[len("diffusion_model."):]
+
+            last_arg_size = 0
+            for arg in kwargs:
+                if k_unet.startswith(arg) and last_arg_size < len(arg):
+                    ratio = kwargs[arg]
+                    last_arg_size = len(arg)
+
+            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
+        return io.NodeOutput(m)
+
+
+class CheckpointSave(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CheckpointSave_V3",
+            display_name="Save Checkpoint _V3",
+            category="advanced/model_merging",
+            is_output_node=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.Clip.Input("clip"),
+                io.Vae.Input("vae"),
+                io.String.Input("filename_prefix", default="checkpoints/ComfyUI")
+            ],
+            outputs=[],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
+        )
+
+    @classmethod
+    def execute(cls, model, clip, vae, filename_prefix):
+        save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
+        return io.NodeOutput()
+
+
 class CLIPSave(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -211,166 +350,6 @@ class CLIPSave(io.ComfyNode):
         return io.NodeOutput()
 
 
-class CLIPSubtract(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="CLIPMergeSubtract_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Clip.Input("clip1"),
-                io.Clip.Input("clip2"),
-                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
-            ],
-            outputs=[
-                io.Clip.Output()
-            ]
-        )
-
-    @classmethod
-    def execute(cls, clip1, clip2, multiplier):
-        m = clip1.clone()
-        kp = clip2.get_key_patches()
-        for k in kp:
-            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
-                continue
-            m.add_patches({k: kp[k]}, - multiplier, multiplier)
-        return io.NodeOutput(m)
-
-
-class ModelAdd(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelMergeAdd_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Model.Input("model1"),
-                io.Model.Input("model2")
-            ],
-            outputs=[
-                io.Model.Output()
-            ]
-        )
-
-    @classmethod
-    def execute(cls, model1, model2):
-        m = model1.clone()
-        kp = model2.get_key_patches("diffusion_model.")
-        for k in kp:
-            m.add_patches({k: kp[k]}, 1.0, 1.0)
-        return io.NodeOutput(m)
-
-
-class ModelMergeBlocks(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelMergeBlocks_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Model.Input("model1"),
-                io.Model.Input("model2"),
-                io.Float.Input("input", default=1.0, min=0.0, max=1.0, step=0.01),
-                io.Float.Input("middle", default=1.0, min=0.0, max=1.0, step=0.01),
-                io.Float.Input("out", default=1.0, min=0.0, max=1.0, step=0.01)
-            ],
-            outputs=[
-                io.Model.Output()
-            ]
-        )
-
-    @classmethod
-    def execute(cls, model1, model2, **kwargs):
-        m = model1.clone()
-        kp = model2.get_key_patches("diffusion_model.")
-        default_ratio = next(iter(kwargs.values()))
-
-        for k in kp:
-            ratio = default_ratio
-            k_unet = k[len("diffusion_model."):]
-
-            last_arg_size = 0
-            for arg in kwargs:
-                if k_unet.startswith(arg) and last_arg_size < len(arg):
-                    ratio = kwargs[arg]
-                    last_arg_size = len(arg)
-
-            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
-        return io.NodeOutput(m)
-
-
-class ModelMergeSimple(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelMergeSimple_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Model.Input("model1"),
-                io.Model.Input("model2"),
-                io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01)
-            ],
-            outputs=[
-                io.Model.Output()
-            ]
-        )
-
-    @classmethod
-    def execute(cls, model1, model2, ratio):
-        m = model1.clone()
-        kp = model2.get_key_patches("diffusion_model.")
-        for k in kp:
-            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
-        return io.NodeOutput(m)
-
-
-class ModelSave(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelSave_V3",
-            category="advanced/model_merging",
-            is_output_node=True,
-            inputs=[
-                io.Model.Input("model"),
-                io.String.Input("filename_prefix", default="diffusion_models/ComfyUI")
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
-        )
-
-    @classmethod
-    def execute(cls, model, filename_prefix):
-        save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
-        return io.NodeOutput()
-
-
-class ModelSubtract(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ModelMergeSubtract_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Model.Input("model1"),
-                io.Model.Input("model2"),
-                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01)
-            ],
-            outputs=[
-                io.Model.Output()
-            ]
-        )
-
-    @classmethod
-    def execute(cls, model1, model2, multiplier):
-        m = model1.clone()
-        kp = model2.get_key_patches("diffusion_model.")
-        for k in kp:
-            m.add_patches({k: kp[k]}, - multiplier, multiplier)
-        return io.NodeOutput(m)
-
-
 class VAESave(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -407,7 +386,28 @@ class VAESave(io.ComfyNode):
         return io.NodeOutput()
 
 
-NODES_LIST = [
+class ModelSave(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ModelSave_V3",
+            category="advanced/model_merging",
+            is_output_node=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("filename_prefix", default="diffusion_models/ComfyUI")
+            ],
+            outputs=[],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo]
+        )
+
+    @classmethod
+    def execute(cls, model, filename_prefix):
+        save_checkpoint(model, filename_prefix=filename_prefix, output_dir=folder_paths.get_output_directory(), prompt=cls.hidden.prompt, extra_pnginfo=cls.hidden.extra_pnginfo)
+        return io.NodeOutput()
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     CheckpointSave,
     CLIPAdd,
     CLIPMergeSimple,
diff --git a/comfy_extras/v3/nodes_model_merging_model_specific.py b/comfy_extras/v3/nodes_model_merging_model_specific.py
index 867f2a1c0..59069e0cd 100644
--- a/comfy_extras/v3/nodes_model_merging_model_specific.py
+++ b/comfy_extras/v3/nodes_model_merging_model_specific.py
@@ -4,237 +4,6 @@ from comfy_api.latest import io
 from comfy_extras.v3.nodes_model_merging import ModelMergeBlocks
 
 
-class ModelMergeAuraflow(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("init_x_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("positional_encoding", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("cond_seq_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("register_tokens", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(4):
-            inputs.append(io.Float.Input(f"double_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        for i in range(32):
-            inputs.append(io.Float.Input(f"single_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.extend([
-            io.Float.Input("modF.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("final_linear.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ])
-
-        return io.Schema(
-            node_id="ModelMergeAuraflow_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeCosmos14B(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(36):
-            inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeCosmos14B_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeCosmos7B(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(28):
-            inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeCosmos7B_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeCosmosPredict2_14B(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(36):
-            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeCosmosPredict2_14B_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeCosmosPredict2_2B(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(28):
-            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeCosmosPredict2_2B_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeFlux1(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("img_in.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("time_in.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("guidance_in", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("vector_in.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("txt_in.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(19):
-            inputs.append(io.Float.Input(f"double_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        for i in range(38):
-            inputs.append(io.Float.Input(f"single_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeFlux1_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeLTXV(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("patchify_proj.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("adaln_single.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("caption_projection.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(28):
-            inputs.append(io.Float.Input(f"transformer_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.extend([
-            io.Float.Input("scale_shift_table", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("proj_out.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ])
-
-        return io.Schema(
-            node_id="ModelMergeLTXV_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeMochiPreview(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_frequencies.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t5_y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t5_yproj.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(48):
-            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeMochiPreview_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
 class ModelMergeSD1(ModelMergeBlocks):
     @classmethod
     def define_schema(cls):
@@ -266,62 +35,6 @@ class ModelMergeSD1(ModelMergeBlocks):
         )
 
 
-class ModelMergeSD3_2B(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(24):
-            inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeSD3_2B_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
-class ModelMergeSD35_Large(ModelMergeBlocks):
-    @classmethod
-    def define_schema(cls):
-        inputs = [
-            io.Model.Input("model1"),
-            io.Model.Input("model2"),
-            io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
-            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
-        ]
-
-        for i in range(38):
-            inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
-
-        return io.Schema(
-            node_id="ModelMergeSD35_Large_V3",
-            category="advanced/model_merging/model_specific",
-            inputs=inputs,
-            outputs=[
-                io.Model.Output(),
-            ]
-        )
-
-
 class ModelMergeSDXL(ModelMergeBlocks):
     @classmethod
     def define_schema(cls):
@@ -353,6 +66,239 @@ class ModelMergeSDXL(ModelMergeBlocks):
         )
 
 
+class ModelMergeSD3_2B(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(24):
+            inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeSD3_2B_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeAuraflow(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("init_x_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("positional_encoding", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("cond_seq_linear.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("register_tokens", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(4):
+            inputs.append(io.Float.Input(f"double_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        for i in range(32):
+            inputs.append(io.Float.Input(f"single_layers.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.extend([
+            io.Float.Input("modF.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("final_linear.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ])
+
+        return io.Schema(
+            node_id="ModelMergeAuraflow_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeFlux1(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("img_in.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("time_in.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("guidance_in", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("vector_in.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("txt_in.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(19):
+            inputs.append(io.Float.Input(f"double_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        for i in range(38):
+            inputs.append(io.Float.Input(f"single_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeFlux1_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeSD35_Large(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embed.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("context_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(38):
+            inputs.append(io.Float.Input(f"joint_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeSD35_Large_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeMochiPreview(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_frequencies.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t5_y_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t5_yproj.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(48):
+            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeMochiPreview_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeLTXV(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("patchify_proj.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("adaln_single.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("caption_projection.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(28):
+            inputs.append(io.Float.Input(f"transformer_blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.extend([
+            io.Float.Input("scale_shift_table", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("proj_out.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ])
+
+        return io.Schema(
+            node_id="ModelMergeLTXV_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeCosmos7B(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(28):
+            inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeCosmos7B_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeCosmos14B(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("extra_pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("affline_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(36):
+            inputs.append(io.Float.Input(f"blocks.block{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeCosmos14B_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
 class ModelMergeWAN2_1(ModelMergeBlocks):
     @classmethod
     def define_schema(cls):
@@ -382,7 +328,61 @@ class ModelMergeWAN2_1(ModelMergeBlocks):
         )
 
 
-NODES_LIST = [
+class ModelMergeCosmosPredict2_2B(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(28):
+            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeCosmosPredict2_2B_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+class ModelMergeCosmosPredict2_14B(ModelMergeBlocks):
+    @classmethod
+    def define_schema(cls):
+        inputs = [
+            io.Model.Input("model1"),
+            io.Model.Input("model2"),
+            io.Float.Input("pos_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("x_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedder.", default=1.0, min=0.0, max=1.0, step=0.01),
+            io.Float.Input("t_embedding_norm.", default=1.0, min=0.0, max=1.0, step=0.01)
+        ]
+
+        for i in range(36):
+            inputs.append(io.Float.Input(f"blocks.{i}.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        inputs.append(io.Float.Input("final_layer.", default=1.0, min=0.0, max=1.0, step=0.01))
+
+        return io.Schema(
+            node_id="ModelMergeCosmosPredict2_14B_V3",
+            category="advanced/model_merging/model_specific",
+            inputs=inputs,
+            outputs=[
+                io.Model.Output(),
+            ]
+        )
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     ModelMergeAuraflow,
     ModelMergeCosmos14B,
     ModelMergeCosmos7B,
diff --git a/comfy_extras/v3/nodes_morphology.py b/comfy_extras/v3/nodes_morphology.py
index bb4e2543a..1f28951cd 100644
--- a/comfy_extras/v3/nodes_morphology.py
+++ b/comfy_extras/v3/nodes_morphology.py
@@ -16,6 +16,47 @@ import comfy.model_management
 from comfy_api.latest import io
 
 
+class Morphology(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Morphology_V3",
+            display_name="ImageMorphology _V3",
+            category="image/postprocessing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Combo.Input("operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
+                io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image, operation, kernel_size):
+        device = comfy.model_management.get_torch_device()
+        kernel = torch.ones(kernel_size, kernel_size, device=device)
+        image_k = image.to(device).movedim(-1, 1)
+        if operation == "erode":
+            output = erosion(image_k, kernel)
+        elif operation == "dilate":
+            output = dilation(image_k, kernel)
+        elif operation == "open":
+            output = opening(image_k, kernel)
+        elif operation == "close":
+            output = closing(image_k, kernel)
+        elif operation == "gradient":
+            output = gradient(image_k, kernel)
+        elif operation == "top_hat":
+            output = top_hat(image_k, kernel)
+        elif operation == "bottom_hat":
+            output = bottom_hat(image_k, kernel)
+        else:
+            raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'top_hat', 'bottom_hat'")
+        return io.NodeOutput(output.to(comfy.model_management.intermediate_device()).movedim(1, -1))
+
+
 class ImageRGBToYUV(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -60,48 +101,7 @@ class ImageYUVToRGB(io.ComfyNode):
         return io.NodeOutput(kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1))
 
 
-class Morphology(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="Morphology_V3",
-            display_name="ImageMorphology _V3",
-            category="image/postprocessing",
-            inputs=[
-                io.Image.Input("image"),
-                io.Combo.Input("operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
-                io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
-            ],
-            outputs=[
-                io.Image.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, image, operation, kernel_size):
-        device = comfy.model_management.get_torch_device()
-        kernel = torch.ones(kernel_size, kernel_size, device=device)
-        image_k = image.to(device).movedim(-1, 1)
-        if operation == "erode":
-            output = erosion(image_k, kernel)
-        elif operation == "dilate":
-            output = dilation(image_k, kernel)
-        elif operation == "open":
-            output = opening(image_k, kernel)
-        elif operation == "close":
-            output = closing(image_k, kernel)
-        elif operation == "gradient":
-            output = gradient(image_k, kernel)
-        elif operation == "top_hat":
-            output = top_hat(image_k, kernel)
-        elif operation == "bottom_hat":
-            output = bottom_hat(image_k, kernel)
-        else:
-            raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
-        return io.NodeOutput(output.to(comfy.model_management.intermediate_device()).movedim(1, -1))
-
-
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     ImageRGBToYUV,
     ImageYUVToRGB,
     Morphology,
diff --git a/comfy_extras/v3/nodes_photomaker.py b/comfy_extras/v3/nodes_photomaker.py
index 7b742cb98..cec88816e 100644
--- a/comfy_extras/v3/nodes_photomaker.py
+++ b/comfy_extras/v3/nodes_photomaker.py
@@ -121,6 +121,32 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
         return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
 
 
+class PhotoMakerLoader(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="PhotoMakerLoader_V3",
+            category="_for_testing/photomaker",
+            inputs=[
+                io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
+            ],
+            outputs=[
+                io.Photomaker.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, photomaker_model_name):
+        photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
+        photomaker_model = PhotoMakerIDEncoder()
+        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
+        if "id_encoder" in data:
+            data = data["id_encoder"]
+        photomaker_model.load_state_dict(data)
+        return io.NodeOutput(photomaker_model)
+
+
 class PhotoMakerEncode(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -173,33 +199,7 @@ class PhotoMakerEncode(io.ComfyNode):
         return io.NodeOutput([[out, {"pooled_output": pooled}]])
 
 
-class PhotoMakerLoader(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="PhotoMakerLoader_V3",
-            category="_for_testing/photomaker",
-            inputs=[
-                io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
-            ],
-            outputs=[
-                io.Photomaker.Output(),
-            ],
-            is_experimental=True,
-        )
-
-    @classmethod
-    def execute(cls, photomaker_model_name):
-        photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
-        photomaker_model = PhotoMakerIDEncoder()
-        data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
-        if "id_encoder" in data:
-            data = data["id_encoder"]
-        photomaker_model.load_state_dict(data)
-        return io.NodeOutput(photomaker_model)
-
-
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     PhotoMakerEncode,
     PhotoMakerLoader,
 ]
diff --git a/comfy_extras/v3/nodes_post_processing.py b/comfy_extras/v3/nodes_post_processing.py
index 1b715f33c..c09da6c83 100644
--- a/comfy_extras/v3/nodes_post_processing.py
+++ b/comfy_extras/v3/nodes_post_processing.py
@@ -13,13 +13,6 @@ import node_helpers
 from comfy_api.latest import io
 
 
-def gaussian_kernel(kernel_size: int, sigma: float, device=None):
-    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
-    d = torch.sqrt(x * x + y * y)
-    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
-    return g / g.sum()
-
-
 class Blend(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -109,36 +102,11 @@ class Blur(io.ComfyNode):
         return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
 
 
-class ImageScaleToTotalPixels(io.ComfyNode):
-    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
-    crop_methods = ["disabled", "center"]
-
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ImageScaleToTotalPixels_V3",
-            category="image/upscaling",
-            inputs=[
-                io.Image.Input("image"),
-                io.Combo.Input("upscale_method", options=cls.upscale_methods),
-                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
-            ],
-            outputs=[
-                io.Image.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, image, upscale_method, megapixels):
-        samples = image.movedim(-1,1)
-        total = int(megapixels * 1024 * 1024)
-
-        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-        width = round(samples.shape[3] * scale_by)
-        height = round(samples.shape[2] * scale_by)
-
-        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
-        return io.NodeOutput(s.movedim(1,-1))
+def gaussian_kernel(kernel_size: int, sigma: float, device=None):
+    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
+    d = torch.sqrt(x * x + y * y)
+    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+    return g / g.sum()
 
 
 class Quantize(io.ComfyNode):
@@ -246,7 +214,39 @@ class Sharpen(io.ComfyNode):
         return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
 
 
-NODES_LIST = [
+class ImageScaleToTotalPixels(io.ComfyNode):
+    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+    crop_methods = ["disabled", "center"]
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ImageScaleToTotalPixels_V3",
+            category="image/upscaling",
+            inputs=[
+                io.Image.Input("image"),
+                io.Combo.Input("upscale_method", options=cls.upscale_methods),
+                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image, upscale_method, megapixels):
+        samples = image.movedim(-1,1)
+        total = int(megapixels * 1024 * 1024)
+
+        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+        width = round(samples.shape[3] * scale_by)
+        height = round(samples.shape[2] * scale_by)
+
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        return io.NodeOutput(s.movedim(1,-1))
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     Blend,
     Blur,
     ImageScaleToTotalPixels,
diff --git a/comfy_extras/v3/nodes_rebatch.py b/comfy_extras/v3/nodes_rebatch.py
index 26fabde1b..7922de727 100644
--- a/comfy_extras/v3/nodes_rebatch.py
+++ b/comfy_extras/v3/nodes_rebatch.py
@@ -142,7 +142,7 @@ class LatentRebatch(io.ComfyNode):
         return io.NodeOutput(output_list)
 
 
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     ImageRebatch,
     LatentRebatch,
 ]
diff --git a/comfy_extras/v3/nodes_sd3.py b/comfy_extras/v3/nodes_sd3.py
index d7401aad0..582eecc3f 100644
--- a/comfy_extras/v3/nodes_sd3.py
+++ b/comfy_extras/v3/nodes_sd3.py
@@ -10,6 +10,59 @@ from comfy_api.latest import io
 from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT
 
 
+class TripleCLIPLoader(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TripleCLIPLoader_V3",
+            category="advanced/loaders",
+            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
+            inputs=[
+                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
+            ],
+            outputs=[
+                io.Clip.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
+        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
+        clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
+        clip = comfy.sd.load_clip(
+            ckpt_paths=[clip_path1, clip_path2, clip_path3],
+            embedding_directory=folder_paths.get_folder_paths("embeddings"),
+        )
+        return io.NodeOutput(clip)
+
+
+class EmptySD3LatentImage(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptySD3LatentImage_V3",
+            category="latent/sd3",
+            inputs=[
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width: int, height: int, batch_size=1):
+        latent = torch.zeros(
+            [batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
+        )
+        return io.NodeOutput({"samples":latent})
+
+
 class CLIPTextEncodeSD3(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -54,30 +107,6 @@ class CLIPTextEncodeSD3(io.ComfyNode):
         return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
 
 
-class EmptySD3LatentImage(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="EmptySD3LatentImage_V3",
-            category="latent/sd3",
-            inputs=[
-                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
-                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
-                io.Int.Input("batch_size", default=1, min=1, max=4096),
-            ],
-            outputs=[
-                io.Latent.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, width: int, height: int, batch_size=1):
-        latent = torch.zeros(
-            [batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
-        )
-        return io.NodeOutput({"samples":latent})
-
-
 class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
     """
     Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
@@ -108,36 +137,7 @@ class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
             model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers
         )
 
-
-class TripleCLIPLoader(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="TripleCLIPLoader_V3",
-            category="advanced/loaders",
-            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
-            inputs=[
-                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
-                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
-                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
-            ],
-            outputs=[
-                io.Clip.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
-        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
-        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
-        clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
-        clip = comfy.sd.load_clip(
-            ckpt_paths=[clip_path1, clip_path2, clip_path3],
-            embedding_directory=folder_paths.get_folder_paths("embeddings"),
-        )
-        return io.NodeOutput(clip)
-
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     CLIPTextEncodeSD3,
     EmptySD3LatentImage,
     SkipLayerGuidanceSD3,
diff --git a/comfy_extras/v3/nodes_slg.py b/comfy_extras/v3/nodes_slg.py
index d98c225c2..27a8b3f93 100644
--- a/comfy_extras/v3/nodes_slg.py
+++ b/comfy_extras/v3/nodes_slg.py
@@ -167,7 +167,7 @@ class SkipLayerGuidanceDiTSimple(io.ComfyNode):
         return io.NodeOutput(m)
 
 
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     SkipLayerGuidanceDiT,
     SkipLayerGuidanceDiTSimple,
 ]
diff --git a/comfy_extras/v3/nodes_stable3d.py b/comfy_extras/v3/nodes_stable3d.py
index 9993533fc..fb47da835 100644
--- a/comfy_extras/v3/nodes_stable3d.py
+++ b/comfy_extras/v3/nodes_stable3d.py
@@ -158,7 +158,7 @@ class SV3D_Conditioning(io.ComfyNode):
         return io.NodeOutput(positive, negative, {"samples":latent})
 
 
-NODES_LIST = [
+NODES_LIST: list[type[io.ComfyNode]] = [
     StableZero123_Conditioning,
     StableZero123_Conditioning_Batched,
     SV3D_Conditioning,
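# NOTE (illustration, not part of the patch): the change repeated across these files is
# annotating NODES_LIST as list[type[io.ComfyNode]], letting a type checker verify every
# entry is a ComfyNode subclass. A hypothetical consumer (the registry and loop below are
# assumptions, not ComfyUI API) would iterate it like this:
from comfy_api.latest import io

def register_nodes(nodes_list: list[type[io.ComfyNode]], registry: dict[str, type[io.ComfyNode]]) -> None:
    for node_cls in nodes_list:
        schema = node_cls.define_schema()   # classmethod on every node in this diff
        registry[schema.node_id] = node_cls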
diff --git a/comfy_extras/v3/nodes_train.py b/comfy_extras/v3/nodes_train.py
index 9afc9d93c..695982b80 100644
--- a/comfy_extras/v3/nodes_train.py
+++ b/comfy_extras/v3/nodes_train.py
@@ -162,57 +162,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None
     return torch.cat(output_images, dim=0)
 
 
-def draw_loss_graph(loss_map, steps):
-    width, height = 500, 300
-    img = Image.new("RGB", (width, height), "white")
-    draw = ImageDraw.Draw(img)
-
-    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
-    scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]
-
-    prev_point = (0, height - int(scaled_loss[0] * height))
-    for i, l_v in enumerate(scaled_loss[1:], start=1):
-        x = int(i / (steps - 1) * width)
-        y = height - int(l_v * height)
-        draw.line([prev_point, (x, y)], fill="blue", width=2)
-        prev_point = (x, y)
-
-    return img
-
-
-def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
-    if result is None:
-        result = []
-    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
-        result.append(model)
-        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
-        return result
-    name = name or "root"
-    for next_name, child in model.named_children():
-        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
-    return result
-
-
-def patch(m):
-    if not hasattr(m, "forward"):
-        return
-    org_forward = m.forward
-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
-    def checkpointing_fwd(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(
-            fwd, args, kwargs, use_reentrant=False
-        )
-    m.org_forward = org_forward
-    m.forward = checkpointing_fwd
-
-
-def unpatch(m):
-    if hasattr(m, "org_forward"):
-        m.forward = m.org_forward
-        del m.org_forward
-
-
 class LoadImageSetFromFolderNode(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -328,126 +277,55 @@ class LoadImageTextSetFromFolderNode(io.ComfyNode):
         return io.NodeOutput(output_tensor, conditions)
 
 
-class LoraModelLoader(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="LoraModelLoader_V3",
-            display_name="Load LoRA Model _V3",
-            category="loaders",
-            description="Load Trained LoRA weights from Train LoRA node.",
-            is_experimental=True,
-            inputs=[
-                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
-                io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
-                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
-            ],
-            outputs=[
-                io.Model.Output(tooltip="The modified diffusion model."),
-            ],
+def draw_loss_graph(loss_map, steps):
+    width, height = 500, 300
+    img = Image.new("RGB", (width, height), "white")
+    draw = ImageDraw.Draw(img)
+
+    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
+    scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]
+
+    prev_point = (0, height - int(scaled_loss[0] * height))
+    for i, l_v in enumerate(scaled_loss[1:], start=1):
+        x = int(i / (steps - 1) * width)
+        y = height - int(l_v * height)
+        draw.line([prev_point, (x, y)], fill="blue", width=2)
+        prev_point = (x, y)
+
+    return img
+
+
+def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
+    if result is None:
+        result = []
+    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
+        result.append(model)
+        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
+        return result
+    name = name or "root"
+    for next_name, child in model.named_children():
+        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
+    return result
+
+
+def patch(m):
+    if not hasattr(m, "forward"):
+        return
+    org_forward = m.forward
+    def fwd(args, kwargs):
+        return org_forward(*args, **kwargs)
+    def checkpointing_fwd(*args, **kwargs):
+        return torch.utils.checkpoint.checkpoint(
+            fwd, args, kwargs, use_reentrant=False
         )
-
-    @classmethod
-    def execute(cls, model, lora, strength_model):
-        if strength_model == 0:
-            return io.NodeOutput(model)
-
-        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
-        return io.NodeOutput(model_lora)
+    m.org_forward = org_forward
+    m.forward = checkpointing_fwd
 
 
-class LossGraphNode(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="LossGraphNode_V3",
-            display_name="Plot Loss Graph _V3",
-            category="training",
-            description="Plots the loss graph and saves it to the output directory.",
-            is_experimental=True,
-            is_output_node=True,
-            inputs=[
-                io.LossMap.Input("loss"),  # TODO: original V1 node has also `default={}` parameter
-                io.String.Input("filename_prefix", default="loss_graph"),
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
-        )
-
-    @classmethod
-    def execute(cls, loss, filename_prefix):
-        loss_values = loss["loss"]
-        width, height = 800, 480
-        margin = 40
-
-        img = Image.new(
-            "RGB", (width + margin, height + margin), "white"
-        )  # Extend canvas
-        draw = ImageDraw.Draw(img)
-
-        min_loss, max_loss = min(loss_values), max(loss_values)
-        scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]
-
-        steps = len(loss_values)
-
-        prev_point = (margin, height - int(scaled_loss[0] * height))
-        for i, l_v in enumerate(scaled_loss[1:], start=1):
-            x = margin + int(i / steps * width)  # Scale X properly
-            y = height - int(l_v * height)
-            draw.line([prev_point, (x, y)], fill="blue", width=2)
-            prev_point = (x, y)
-
-        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
-        draw.line(
-            [(margin, height), (width + margin, height)], fill="black", width=2
-        )  # X-axis
-
-        try:
-            font = ImageFont.truetype("arial.ttf", 12)
-        except IOError:
-            font = ImageFont.load_default()
-
-        # Add axis labels
-        draw.text((5, height // 2), "Loss", font=font, fill="black")
-        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
-
-        # Add min/max loss values
-        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
-        draw.text(
-            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
-        )
-        return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
-
-
-class SaveLoRA(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="SaveLoRA_V3",
-            display_name="Save LoRA Weights _V3",
-            category="loaders",
-            is_experimental=True,
-            is_output_node=True,
-            inputs=[
-                io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
-                io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
-                io.Int.Input("steps", tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", optional=True),
-            ],
-            outputs=[],
-        )
-
-    @classmethod
-    def execute(cls, lora, prefix, steps=None):
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
-            prefix, folder_paths.get_output_directory()
-        )
-        if steps is None:
-            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
-        else:
-            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
-        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
-        safetensors.torch.save_file(lora, output_checkpoint)
-        return io.NodeOutput()
+def unpatch(m):
+    if hasattr(m, "org_forward"):
+        m.forward = m.org_forward
+        del m.org_forward
 
 
 class TrainLoraNode(io.ComfyNode):
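# NOTE (illustration, not part of the patch): in patch() above, the inner fwd(args, kwargs)
# takes the argument tuple and the kwargs dict as two ordinary positionals, because
# torch.utils.checkpoint.checkpoint historically forwarded only positional arguments to the
# wrapped function; bundling kwargs keeps the wrapper safe either way. Standalone sketch of
# the same trick:
import torch

linear = torch.nn.Linear(4, 4)
org_forward = linear.forward

def fwd(args, kwargs):
    return org_forward(*args, **kwargs)

def checkpointing_fwd(*args, **kwargs):
    return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

x = torch.randn(2, 4, requires_grad=True)
checkpointing_fwd(x).sum().backward()   # activations are recomputed during backward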
@@ -656,7 +534,129 @@ class TrainLoraNode(io.ComfyNode):
         return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
 
 
-NODES_LIST = [
+class LoraModelLoader(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoraModelLoader_V3",
+            display_name="Load LoRA Model _V3",
+            category="loaders",
+            description="Load Trained LoRA weights from Train LoRA node.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
+                io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
+                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
+            ],
+            outputs=[
+                io.Model.Output(tooltip="The modified diffusion model."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, lora, strength_model):
+        if strength_model == 0:
+            return io.NodeOutput(model)
+
+        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
+        return io.NodeOutput(model_lora)
+
+
+class SaveLoRA(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SaveLoRA_V3",
+            display_name="Save LoRA Weights _V3",
+            category="loaders",
+            is_experimental=True,
+            is_output_node=True,
+            inputs=[
+                io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
+                io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
+                io.Int.Input("steps", tooltip="Optional: the number of steps the LoRA has been trained for, used to name the saved file.", optional=True),
+            ],
+            outputs=[],
+        )
+
+    @classmethod
+    def execute(cls, lora, prefix, steps=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
+            prefix, folder_paths.get_output_directory()
+        )
+        if steps is None:
+            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        else:
+            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+        safetensors.torch.save_file(lora, output_checkpoint)
+        return io.NodeOutput()
+
+
+class LossGraphNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LossGraphNode_V3",
+            display_name="Plot Loss Graph _V3",
+            category="training",
+            description="Plots the loss graph and saves it to the output directory.",
+            is_experimental=True,
+            is_output_node=True,
+            inputs=[
+                io.LossMap.Input("loss"),  # TODO: original V1 node also has a `default={}` parameter
+                io.String.Input("filename_prefix", default="loss_graph"),
+            ],
+            outputs=[],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
+        )
+
+    @classmethod
+    def execute(cls, loss, filename_prefix):
+        loss_values = loss["loss"]
+        width, height = 800, 480
+        margin = 40
+
+        img = Image.new(
+            "RGB", (width + margin, height + margin), "white"
+        )  # Extend canvas
+        draw = ImageDraw.Draw(img)
+
+        min_loss, max_loss = min(loss_values), max(loss_values)
+        scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]
+
+        steps = len(loss_values)
+
+        prev_point = (margin, height - int(scaled_loss[0] * height))
+        for i, l_v in enumerate(scaled_loss[1:], start=1):
+            x = margin + int(i / steps * width)  # Scale X properly
+            y = height - int(l_v * height)
+            draw.line([prev_point, (x, y)], fill="blue", width=2)
+            prev_point = (x, y)
+
+        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
+        draw.line(
+            [(margin, height), (width + margin, height)], fill="black", width=2
+        )  # X-axis
+
+        try:
+            font = ImageFont.truetype("arial.ttf", 12)
+        except IOError:
+            font = ImageFont.load_default()
+
+        # Add axis labels
+        draw.text((5, height // 2), "Loss", font=font, fill="black")
+        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
+
+        # Add min/max loss values
+        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
+        draw.text(
+            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
+        )
+        return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     LoadImageSetFromFolderNode,
     LoadImageTextSetFromFolderNode,
     LoraModelLoader,
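# NOTE (illustration, not part of the patch): both draw_loss_graph and LossGraphNode.execute
# min-max normalize the loss history before plotting, so the largest loss maps to the top of
# the canvas (y = 0) and the smallest to the bottom (y = height). Worked example of the
# scaling; note that a perfectly flat loss history would make (max_loss - min_loss) zero and
# raise ZeroDivisionError in either plotter:
loss_values = [0.9, 0.6, 0.3]
min_loss, max_loss = min(loss_values), max(loss_values)
scaled = [(v - min_loss) / (max_loss - min_loss) for v in loss_values]  # [1.0, 0.5, 0.0]
height = 300
ys = [height - int(s * height) for s in scaled]                         # [0, 150, 300]
assert ys == [0, 150, 300]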
"av1"]), + io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), + io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, codec, fps, filename_prefix, crf): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] + ) + + file = f"{filename}_{counter:05}_.webm" + container = av.open(os.path.join(full_output_folder, file), mode="w") + + if cls.hidden.prompt is not None: + container.metadata["prompt"] = json.dumps(cls.hidden.prompt) + + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + + codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} + stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) + stream.width = images.shape[-2] + stream.height = images.shape[-3] + stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" + stream.bit_rate = 0 + stream.options = {'crf': str(crf)} + if codec == "av1": + stream.options["preset"] = "6" + + for frame in images: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + container.mux(stream.encode()) + container.close() + + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + + +class SaveVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideo_V3", + display_name="Save Video _V3", + category="image/video", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + io.Video.Input("video", tooltip="The video to save."), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, video: VideoInput, filename_prefix, format, codec): + width, height = video.get_dimensions() + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, + folder_paths.get_output_directory(), + width, + height + ) + saved_metadata = None + if not args.disable_metadata: + metadata = {} + if cls.hidden.extra_pnginfo is not None: + metadata.update(cls.hidden.extra_pnginfo) + if cls.hidden.prompt is not None: + metadata["prompt"] = cls.hidden.prompt + if len(metadata) > 0: + saved_metadata = metadata + file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + video.save_to( + os.path.join(full_output_folder, file), + format=format, + codec=codec, + metadata=saved_metadata + ) + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + + class CreateVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -35,13 +137,9 @@ class CreateVideo(io.ComfyNode): @classmethod def execute(cls, images: ImageInput, fps: float, audio: AudioInput = None): - return io.NodeOutput(VideoFromComponents( - VideoComponents( - images=images, - audio=audio, - frame_rate=Fraction(fps), - ) - )) + return io.NodeOutput( + VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + ) class GetVideoComponents(io.ComfyNode): @@ -105,106 +203,10 @@ class LoadVideo(io.ComfyNode): return True -class SaveVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SaveVideo_V3", - display_name="Save Video _V3", - category="image/video", - description="Saves the input images to your ComfyUI output directory.", - inputs=[ - io.Video.Input("video", tooltip="The video to save."), - io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. 
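# NOTE (illustration, not part of the patch): SaveWEBM passes the stream rate as
# Fraction(round(fps * 1000), 1000) because PyAV expects an exact rational frame rate;
# rounding to millihertz represents NTSC-style fractional rates without float error:
from fractions import Fraction

fps = 23.976
rate = Fraction(round(fps * 1000), 1000)
assert rate == Fraction(2997, 125)   # 23976/1000 in lowest terms
assert float(rate) == 23.976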
@@ -105,106 +203,10 @@ class LoadVideo(io.ComfyNode):
         return True
 
 
-class SaveVideo(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="SaveVideo_V3",
-            display_name="Save Video _V3",
-            category="image/video",
-            description="Saves the input images to your ComfyUI output directory.",
-            inputs=[
-                io.Video.Input("video", tooltip="The video to save."),
-                io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
-                io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
-                io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
-            is_output_node=True,
-        )
-
-    @classmethod
-    def execute(cls, video: VideoInput, filename_prefix, format, codec):
-        width, height = video.get_dimensions()
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
-            filename_prefix,
-            folder_paths.get_output_directory(),
-            width,
-            height
-        )
-        saved_metadata = None
-        if not args.disable_metadata:
-            metadata = {}
-            if cls.hidden.extra_pnginfo is not None:
-                metadata.update(cls.hidden.extra_pnginfo)
-            if cls.hidden.prompt is not None:
-                metadata["prompt"] = cls.hidden.prompt
-            if len(metadata) > 0:
-                saved_metadata = metadata
-        file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
-        video.save_to(
-            os.path.join(full_output_folder, file),
-            format=format,
-            codec=codec,
-            metadata=saved_metadata
-        )
-        return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
-
-
-class SaveWEBM(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="SaveWEBM_V3",
-            category="image/video",
-            is_experimental=True,
-            inputs=[
-                io.Image.Input("images"),
-                io.String.Input("filename_prefix", default="ComfyUI"),
-                io.Combo.Input("codec", options=["vp9", "av1"]),
-                io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
-                io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
-            is_output_node=True,
-        )
-
-    @classmethod
-    def execute(cls, images, codec, fps, filename_prefix, crf):
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
-            filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
-        )
-
-        file = f"{filename}_{counter:05}_.webm"
-        container = av.open(os.path.join(full_output_folder, file), mode="w")
-
-        if cls.hidden.prompt is not None:
-            container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
-
-        if cls.hidden.extra_pnginfo is not None:
-            for x in cls.hidden.extra_pnginfo:
-                container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
-
-        codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
-        stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
-        stream.width = images.shape[-2]
-        stream.height = images.shape[-3]
-        stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
-        stream.bit_rate = 0
-        stream.options = {'crf': str(crf)}
-        if codec == "av1":
-            stream.options["preset"] = "6"
-
-        for frame in images:
-            frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
-            for packet in stream.encode(frame):
-                container.mux(packet)
-        container.mux(stream.encode())
-        container.close()
-
-        return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
-
-
-NODES_LIST = [CreateVideo, GetVideoComponents, LoadVideo, SaveVideo, SaveWEBM]
+NODES_LIST: list[type[io.ComfyNode]] = [
+    CreateVideo,
+    GetVideoComponents,
+    LoadVideo,
+    SaveVideo,
+    SaveWEBM,
+]
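# NOTE (illustration, not part of the patch): ConditioningSetAreaPercentageVideo (moved
# below) packs the area as ("percentage", temporal, height, width, z, y, x): sizes first,
# then offsets, each as a fraction of the full video volume. Illustrative arithmetic for a
# hypothetical latent of 16 frames at 64x64 (how ComfyUI itself resolves percentages is
# internal to the sampler):
temporal, height, width = 0.5, 1.0, 0.25   # half the frames, full height, quarter width
z, y, x = 0.25, 0.0, 0.75                  # fractional start offsets

latent_t, latent_h, latent_w = 16, 64, 64
size_cells = (round(temporal * latent_t), round(height * latent_h), round(width * latent_w))
offset_cells = (round(z * latent_t), round(y * latent_h), round(x * latent_w))
assert size_cells == (8, 64, 16) and offset_cells == (4, 0, 48)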
diff --git a/comfy_extras/v3/nodes_video_model.py b/comfy_extras/v3/nodes_video_model.py
index e0ee00d73..9ea4b3546 100644
--- a/comfy_extras/v3/nodes_video_model.py
+++ b/comfy_extras/v3/nodes_video_model.py
@@ -11,40 +11,6 @@ import nodes
 from comfy_api.latest import io
 
 
-class ConditioningSetAreaPercentageVideo(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ConditioningSetAreaPercentageVideo_V3",
-            category="conditioning",
-            inputs=[
-                io.Conditioning.Input("conditioning"),
-                io.Float.Input("width", default=1.0, min=0, max=1.0, step=0.01),
-                io.Float.Input("height", default=1.0, min=0, max=1.0, step=0.01),
-                io.Float.Input("temporal", default=1.0, min=0, max=1.0, step=0.01),
-                io.Float.Input("x", default=0, min=0, max=1.0, step=0.01),
-                io.Float.Input("y", default=0, min=0, max=1.0, step=0.01),
-                io.Float.Input("z", default=0, min=0, max=1.0, step=0.01),
-                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
-            ],
-            outputs=[
-                io.Conditioning.Output(),
-            ],
-        )
-
-    @classmethod
-    def execute(cls, conditioning, width, height, temporal, x, y, z, strength):
-        c = node_helpers.conditioning_set_values(
-            conditioning,
-            {
-                "area": ("percentage", temporal, height, width, z, y, x),
-                "strength": strength,
-                "set_area_to_bounds": False
-            ,}
-        )
-        return io.NodeOutput(c)
-
-
 class ImageOnlyCheckpointLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -75,37 +41,6 @@ class ImageOnlyCheckpointLoader(io.ComfyNode):
         return io.NodeOutput(out[0], out[3], out[2])
 
 
-class ImageOnlyCheckpointSave(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="ImageOnlyCheckpointSave_V3",
-            category="advanced/model_merging",
-            inputs=[
-                io.Model.Input("model"),
-                io.ClipVision.Input("clip_vision"),
-                io.Vae.Input("vae"),
-                io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
-            ],
-            outputs=[],
-            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
-        )
-
-    @classmethod
-    def execute(cls, model, clip_vision, vae, filename_prefix):
-        output_dir = folder_paths.get_output_directory()
-        comfy_extras.nodes_model_merging.save_checkpoint(
-            model,
-            clip_vision=clip_vision,
-            vae=vae,
-            filename_prefix=filename_prefix,
-            output_dir=output_dir,
-            prompt=cls.hidden.prompt,
-            extra_pnginfo=cls.hidden.extra_pnginfo,
-        )
-        return io.NodeOutput()
-
-
 class SVD_img2vid_Conditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -222,7 +157,72 @@ class VideoTriangleCFGGuidance(io.ComfyNode):
         return io.NodeOutput(m)
 
 
-NODES_LIST = [
+class ImageOnlyCheckpointSave(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ImageOnlyCheckpointSave_V3",
+            category="advanced/model_merging",
+            inputs=[
+                io.Model.Input("model"),
+                io.ClipVision.Input("clip_vision"),
+                io.Vae.Input("vae"),
+                io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
+            ],
+            outputs=[],
+            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
+        )
+
+    @classmethod
+    def execute(cls, model, clip_vision, vae, filename_prefix):
+        output_dir = folder_paths.get_output_directory()
+        comfy_extras.nodes_model_merging.save_checkpoint(
+            model,
+            clip_vision=clip_vision,
+            vae=vae,
+            filename_prefix=filename_prefix,
+            output_dir=output_dir,
+            prompt=cls.hidden.prompt,
+            extra_pnginfo=cls.hidden.extra_pnginfo,
+        )
+        return io.NodeOutput()
+
+
+class ConditioningSetAreaPercentageVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ConditioningSetAreaPercentageVideo_V3",
+            category="conditioning",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+                io.Float.Input("width", default=1.0, min=0, max=1.0, step=0.01),
+                io.Float.Input("height", default=1.0, min=0, max=1.0, step=0.01),
+                io.Float.Input("temporal", default=1.0, min=0, max=1.0, step=0.01),
+                io.Float.Input("x", default=0, min=0, max=1.0, step=0.01),
+                io.Float.Input("y", default=0, min=0, max=1.0, step=0.01),
+                io.Float.Input("z", default=0, min=0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, conditioning, width, height, temporal, x, y, z, strength):
+        c = node_helpers.conditioning_set_values(
+            conditioning,
+            {
+                "area": ("percentage", temporal, height, width, z, y, x),
+                "strength": strength,
+                "set_area_to_bounds": False,
+            }
+        )
+        return io.NodeOutput(c)
+
+
+NODES_LIST: list[type[io.ComfyNode]] = [
     ConditioningSetAreaPercentageVideo,
     ImageOnlyCheckpointLoader,
     ImageOnlyCheckpointSave,