ComfyUI/comfy_extras/v3/nodes_model_downscale.py
2025-07-24 18:23:29 -07:00

69 lines
2.7 KiB
Python

from __future__ import annotations
import comfy.utils
from comfy_api.latest import io
class PatchModelAddDownscale(io.ComfyNode):
    """Node that patches a model so one block's hidden state is downscaled.

    The patch resizes the hidden state at a chosen block while the current
    sigma lies inside a configurable window, and resizes it back in the
    output blocks so it matches the tensor it is paired with there.
    """

    # Interpolation modes offered for both the downscale and the upscale input.
    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def define_schema(cls):
        """Return the schema (id, display name, category, inputs, outputs)."""
        return io.Schema(
            node_id="PatchModelAddDownscale_V3",
            display_name="PatchModelAddDownscale (Kohya Deep Shrink) _V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
                io.Boolean.Input("downscale_after_skip", default=True),
                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(
        cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method
    ):
        """Clone ``model`` and attach the downscale/upscale block patches.

        Args:
            model: Model to patch; it is cloned and the clone is patched.
            block_number: Block index at which the hidden state is downscaled.
            downscale_factor: Divisor applied to the last two dimensions.
            start_percent: Converted to the upper sigma bound of the window.
            end_percent: Converted to the lower sigma bound of the window.
            downscale_after_skip: Selects which input-patch hook is used
                (``..._after_skip`` when true, the plain input hook otherwise).
            downscale_method: Interpolation mode for the downscale.
            upscale_method: Interpolation mode for the upscale.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        model_sampling = model.get_model_object("model_sampling")
        # NOTE(review): the window test below treats sigma_start as the upper
        # bound, i.e. it assumes sigma decreases as percent increases.
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)

        def input_block_patch(h, transformer_options):
            """Downscale ``h`` at the target block while sigma is in range."""
            if transformer_options["block"][1] == block_number:
                sigma = transformer_options["sigmas"][0].item()
                if sigma_end <= sigma <= sigma_start:
                    h = comfy.utils.common_upscale(
                        h,
                        round(h.shape[-1] * (1.0 / downscale_factor)),
                        round(h.shape[-2] * (1.0 / downscale_factor)),
                        downscale_method,
                        "disabled",
                    )
            return h

        def output_block_patch(h, hsp, transformer_options):
            """Resize ``h`` to the size of ``hsp`` when their sizes differ."""
            # The input patch rounds each dimension independently, so one
            # dimension can round back to its original size while the other
            # shrinks. Checking shape[2] alone would miss that case, so the
            # last dimension is compared as well.
            if h.shape[2] != hsp.shape[2] or h.shape[-1] != hsp.shape[-1]:
                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return io.NodeOutput(m)
# Node classes defined by this module.
NODES_LIST = [PatchModelAddDownscale]