from __future__ import annotations

import comfy.utils
from comfy_api.v3 import io


class PatchModelAddDownscale(io.ComfyNodeV3):
    """Node that patches a model to shrink hidden states at one input block.

    While the current sigma lies inside the window derived from
    ``start_percent`` / ``end_percent``, the hidden state reaching the
    selected input block is resized by ``1 / downscale_factor``. A matching
    output-block patch resizes the hidden state back to the size of the skip
    tensor whenever the two no longer agree.
    """

    # Resize modes offered for both the downscale and the upscale step.
    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def define_schema(cls):
        """Return the node schema: identifiers, inputs and the model output."""
        return io.Schema(
            node_id="PatchModelAddDownscale_V3",
            display_name="PatchModelAddDownscale (Kohya Deep Shrink) _V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
                io.Boolean.Input("downscale_after_skip", default=True),
                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        block_number,
        downscale_factor,
        start_percent,
        end_percent,
        downscale_after_skip,
        downscale_method,
        upscale_method,
    ):
        """Clone ``model`` and install the downscale/upscale block patches.

        Args:
            model: Model object to patch; it is cloned, not modified.
            block_number: Index compared against ``transformer_options["block"][1]``
                to select the input block whose hidden state is resized.
            downscale_factor: The last two axes of the hidden state are
                multiplied by ``1 / downscale_factor``.
            start_percent: Converted with ``percent_to_sigma`` into the upper
                sigma bound of the active window.
            end_percent: Converted with ``percent_to_sigma`` into the lower
                sigma bound of the active window.
            downscale_after_skip: When true the input patch is registered with
                ``set_model_input_block_patch_after_skip``, otherwise with
                ``set_model_input_block_patch``.
            downscale_method: Resize mode used when shrinking.
            upscale_method: Resize mode used when restoring the size.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        model_sampling = model.get_model_object("model_sampling")
        sigma_start = model_sampling.percent_to_sigma(start_percent)
        sigma_end = model_sampling.percent_to_sigma(end_percent)
        # Loop-invariant: same multiplier the patch applies on every call.
        scale = 1.0 / downscale_factor

        def input_block_patch(h, transformer_options):
            """Shrink ``h`` at the selected block while sigma is in the window."""
            # NOTE(review): "block" is presumably a (kind, index) pair - confirm.
            if transformer_options["block"][1] != block_number:
                return h
            sigma = transformer_options["sigmas"][0].item()
            if sigma_end <= sigma <= sigma_start:
                # Clamp to 1 so a tiny feature map combined with a large
                # factor can never round down to a zero-sized target.
                # NOTE(review): last axis is presumably width and the one
                # before it height, matching the argument order below.
                target_width = max(1, round(h.shape[-1] * scale))
                target_height = max(1, round(h.shape[-2] * scale))
                h = comfy.utils.common_upscale(
                    h,
                    target_width,
                    target_height,
                    downscale_method,
                    "disabled",
                )
            return h

        def output_block_patch(h, hsp, transformer_options):
            """Resize ``h`` to the size of ``hsp`` when they disagree."""
            # Checking axis 2 alone misses the case where rounding left that
            # axis unchanged but shrank the last one, so test the last axis too.
            if h.shape[2] != hsp.shape[2] or h.shape[-1] != hsp.shape[-1]:
                h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
            return h, hsp

        m = model.clone()
        if downscale_after_skip:
            m.set_model_input_block_patch_after_skip(input_block_patch)
        else:
            m.set_model_input_block_patch(input_block_patch)
        m.set_model_output_block_patch(output_block_patch)
        return io.NodeOutput(m)


NODES_LIST = [
    PatchModelAddDownscale,
]