import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops

ops = comfy.ops.disable_weight_init


class PixelShuffle2D(nn.Module):
    """2x spatial *down*-sampler: a 3x3 conv branch plus a parameter-free shortcut.

    Despite the name, this folds space into channels (the direction of
    ``torch.nn.PixelUnshuffle``): H and W are halved. Each 2x2 patch is folded
    into the channel axis with (row offset, column offset, channel) ordering.
    The shortcut folds the input the same way, then averages groups of
    ``ratio`` consecutive folded channels down to ``out_dim``.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels. The conv emits ``out_dim >> 2``
            channels and the fold multiplies that by 4, so ``out_dim`` is
            expected to be a multiple of 4. ``4 * in_dim`` is expected to be
            a multiple of ``out_dim`` so the shortcut can be averaged down.
        op: conv layer class used for the learned branch.

    Input H and W are expected to be even; otherwise the ``view`` in
    ``forward`` fails.
    """

    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
        # Folded shortcut channels averaged into each output channel.
        self.ratio = (in_dim << 2) // out_dim

    def forward(self, x):
        b, c, h, w = x.shape
        h2, w2 = h >> 1, w >> 1
        # (b, C, 2*h2, 2*w2) -> (b, C, h2, 2, w2, 2) -> (b, 2, 2, C, h2, w2) -> (b, 4C, h2, w2)
        y = (
            self.conv(x)
            .view(b, -1, h2, 2, w2, 2)
            .permute(0, 3, 5, 1, 2, 4)
            .reshape(b, -1, h2, w2)
        )
        r = (
            x.view(b, c, h2, 2, w2, 2)
            .permute(0, 3, 5, 1, 2, 4)
            .reshape(b, c << 2, h2, w2)
        )
        # Average consecutive groups of `ratio` folded channels to match y.
        return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)


class PixelUnshuffle2D(nn.Module):
    """2x spatial *up*-sampler: a 3x3 conv branch plus a parameter-free shortcut.

    Despite the name, this unfolds channels into space (the direction of
    ``torch.nn.PixelShuffle``): H and W are doubled. Channels are read with
    (row offset, column offset, channel) ordering, which inverts the fold
    used by ``PixelShuffle2D``. The shortcut repeats each input channel
    ``scale`` times (``repeat_interleave``) before the same unfold.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels. ``4 * out_dim`` is expected to be
            a multiple of ``in_dim`` so the repeated shortcut matches the
            conv branch.
        op: conv layer class used for the learned branch.
    """

    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
        # Repeat count per input channel so the shortcut has 4 * out_dim channels.
        self.scale = (out_dim << 2) // in_dim

    def forward(self, x):
        b, _, h, w = x.shape
        h2, w2 = h << 1, w << 1
        # (b, 4C, h, w) -> (b, 2, 2, C, h, w) -> (b, C, h, 2, w, 2) -> (b, C, 2h, 2w)
        y = (
            self.conv(x)
            .view(b, 2, 2, -1, h, w)
            .permute(0, 3, 4, 1, 5, 2)
            .reshape(b, -1, h2, w2)
        )
        r = (
            x.repeat_interleave(self.scale, 1)
            .view(b, 2, 2, -1, h, w)
            .permute(0, 3, 4, 1, 5, 2)
            .reshape(b, -1, h2, w2)
        )
        return y + r


class Encoder(nn.Module):
    """Image encoder built from ResnetBlock stages with PixelShuffle2D downsampling.

    Pipeline: ``conv_in`` -> one stage per entry of ``block_out_channels``
    (``num_res_blocks`` ResnetBlocks, plus a 2x ``downsample`` on the first
    ``depth`` stages) -> mid (ResnetBlock, AttnBlock, ResnetBlock) ->
    GroupNorm/SiLU/``conv_out``. ``conv_out`` emits ``2 * z_channels``
    channels. A parameter-free shortcut averages groups of consecutive mid
    channels down to that same count and is added to the output.
    NOTE(review): the two halves are presumably Gaussian mean/log-variance,
    split by the caller; this module only returns the concatenated tensor.

    Args:
        in_channels: channels of the input image.
        z_channels: latent channels (the output has twice this many).
        block_out_channels: per-stage channel widths, in encoder order.
        num_res_blocks: ResnetBlocks per stage.
        ffactor_spatial: total spatial downsampling factor, expected to be a
            power of two.
        downsample_match_channel: if True, each downsample switches to the
            next stage's width; otherwise it keeps the current width.
        **_: extra config keys, accepted and ignored.
    """

    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, downsample_match_channel=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        # Number of 2x downsamples: equals log2(ffactor_spatial) for powers of two.
        depth = (ffactor_spatial >> 1).bit_length()
        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([
                ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt,
                            temb_channels=0, conv_op=ops.Conv2d)
                for j in range(num_res_blocks)
            ])
            ch = tgt
            if i < depth:
                if i + 1 < len(block_out_channels) and downsample_match_channel:
                    nxt = block_out_channels[i + 1]
                else:
                    nxt = ch
                stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        # Use the same ops namespace as every other layer here (was nn.GroupNorm).
        self.norm_out = ops.GroupNorm(num_groups=32, num_channels=ch, eps=1e-6, affine=True)
        self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)

    def forward(self, x):
        x = self.conv_in(x)
        for stage in self.down:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'downsample'):
                x = stage.downsample(x)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        # Shortcut: average consecutive groups of `grp` channels -> 2 * z_channels.
        b, c, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, h, w).mean(2)
        return self.conv_out(F.silu(self.norm_out(x))) + skip


class Decoder(nn.Module):
    """Image decoder built from ResnetBlock stages with PixelUnshuffle2D upsampling.

    ``block_out_channels`` is given in encoder order and reversed here. The
    reversed list is what is stored on ``self.block_out_channels``.

    Pipeline: ``conv_in`` plus a shortcut that repeats each latent channel
    ``block_out_channels[0] // z_channels`` times (``block_out_channels[0]``
    of the reversed list) -> mid (ResnetBlock, AttnBlock, ResnetBlock) -> one
    stage per width (``num_res_blocks + 1`` ResnetBlocks, plus a 2x
    ``upsample`` on the first ``depth`` stages) -> GroupNorm/SiLU/``conv_out``.

    Args:
        z_channels: latent channels of the input.
        out_channels: channels of the decoded output.
        block_out_channels: per-stage channel widths, in encoder order.
        num_res_blocks: the decoder uses ``num_res_blocks + 1`` per stage.
        ffactor_spatial: total spatial upsampling factor, expected to be a
            power of two.
        upsample_match_channel: if True, each upsample switches to the next
            stage's width; otherwise it keeps the current width.
        **_: extra config keys, accepted and ignored.
    """

    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, upsample_match_channel=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        ch = block_out_channels[0]
        self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        self.up = nn.ModuleList()
        # Number of 2x upsamples: equals log2(ffactor_spatial) for powers of two.
        depth = (ffactor_spatial >> 1).bit_length()
        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([
                ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt,
                            temb_channels=0, conv_op=ops.Conv2d)
                for j in range(num_res_blocks + 1)
            ])
            ch = tgt
            if i < depth:
                if i + 1 < len(block_out_channels) and upsample_match_channel:
                    nxt = block_out_channels[i + 1]
                else:
                    nxt = ch
                stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.up.append(stage)

        # Use the same ops namespace as every other layer here (was nn.GroupNorm).
        self.norm_out = ops.GroupNorm(num_groups=32, num_channels=ch, eps=1e-6, affine=True)
        self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)

    def forward(self, z):
        # Shortcut repeats each latent channel so it matches conv_in's width;
        # assumes block_out_channels[0] is a multiple of z_channels.
        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
        for stage in self.up:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'upsample'):
                x = stage.upsample(x)
        return self.conv_out(F.silu(self.norm_out(x)))