"""Taken from: https://github.com/tfernd/HyperTile/""" from __future__ import annotations import math from einops import rearrange from torch import randint from comfy_api.latest import io def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) # All big divisors of value (inclusive) divisors = [i for i in range(min_value, value + 1) if value % i == 0] ns = [value // i for i in divisors[:max_options]] # has at least 1 element if len(ns) - 1 > 0: idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() else: idx = 0 return ns[idx] class HyperTile(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="HyperTile_V3", category="model_patches/unet", inputs=[ io.Model.Input("model"), io.Int.Input("tile_size", default=256, min=1, max=2048), io.Int.Input("swap_size", default=2, min=1, max=128), io.Int.Input("max_depth", default=0, min=0, max=10), io.Boolean.Input("scale_depth", default=False), ], outputs=[ io.Model.Output(), ], ) @classmethod def execute(cls, model, tile_size, swap_size, max_depth, scale_depth): latent_tile_size = max(32, tile_size) // 8 temp = None def hypertile_in(q, k, v, extra_options): nonlocal temp model_chans = q.shape[-2] orig_shape = extra_options['original_shape'] apply_to = [] for i in range(max_depth + 1): apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) if model_chans in apply_to: shape = extra_options["original_shape"] aspect_ratio = shape[-1] / shape[-2] hw = q.size(1) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 nh = random_divisor(h, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size) if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) temp = (nh, nw, h, w) return q, k, v return q, k, v def hypertile_out(out, extra_options): nonlocal temp if temp is not None: nh, nw, h, w = temp temp = None out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) return out m = model.clone() m.set_model_attn1_patch(hypertile_in) m.set_model_attn1_output_patch(hypertile_out) return io.NodeOutput(m) NODES_LIST = [ HyperTile, ]