mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
137 lines
5.6 KiB
Python
137 lines
5.6 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
|
import comfy.ops
|
|
# Layer namespace used for every Conv2d built in this file.
# NOTE(review): the name suggests comfy.ops.disable_weight_init provides layer classes
# that skip PyTorch's default weight initialisation (weights are expected to come from
# a checkpoint) — confirm in comfy/ops.py.
ops = comfy.ops.disable_weight_init
|
|
|
|
|
|
class PixelShuffle2D(nn.Module):
    """Halve spatial resolution with a learned branch plus a parameter-free shortcut.

    Unlike ``torch.nn.PixelShuffle`` (which upsamples), this module maps
    ``(B, in_dim, H, W)`` to ``(B, out_dim, H // 2, W // 2)``.

    Both branches fold every 2x2 spatial patch into the channel axis. The patch
    offset ``(dh, dw)`` becomes the *major* channel index, so the folded channel
    is ``(2 * dh + dw) * C + c``.

    * Learned branch: a 3x3 conv to ``out_dim // 4`` channels, then folded.
    * Shortcut: the input is folded to ``4 * in_dim`` channels. Consecutive
      groups of ``ratio`` channels are then averaged down to the learned
      branch's width.

    H and W must be even. ``out_dim`` should divide ``4 * in_dim`` and be a
    multiple of 4; otherwise the reshapes in ``forward`` fail.
    """

    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        # A quarter of the target width; the 2x2 fold multiplies channels by four.
        self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
        # How many folded input channels are averaged into one shortcut channel.
        self.ratio = (in_dim << 2) // out_dim

    @staticmethod
    def _fold(t, n_ch, half_h, half_w):
        """Move each 2x2 patch of ``t`` into channels (patch offset as major index).

        ``n_ch`` is passed to ``view``; use ``-1`` to infer it.
        """
        n = t.shape[0]
        patches = t.view(n, n_ch, half_h, 2, half_w, 2)
        # (n, dh, dw, c, i, j) -> channels ordered by (dh, dw, c)
        return patches.permute(0, 3, 5, 1, 2, 4).reshape(n, -1, half_h, half_w)

    def forward(self, x):
        batch, channels, height, width = x.shape
        half_h = height >> 1
        half_w = width >> 1

        learned = self._fold(self.conv(x), -1, half_h, half_w)

        # The channel count is pinned here so odd spatial sizes fail loudly.
        folded = self._fold(x, channels, half_h, half_w)
        grouped = folded.view(batch, learned.shape[1], self.ratio, half_h, half_w)
        shortcut = grouped.mean(2)

        return learned + shortcut
|
|
|
|
|
|
class PixelUnshuffle2D(nn.Module):
    """Double spatial resolution with a learned branch plus a channel-repeat shortcut.

    Unlike ``torch.nn.PixelUnshuffle`` (which downsamples), this module maps
    ``(B, in_dim, H, W)`` to ``(B, out_dim, 2 * H, 2 * W)``.

    Both branches hold ``4 * out_dim`` channels, which are unfolded depth-to-space.
    Channel ``(2 * dh + dw) * out_dim + c`` lands at output channel ``c``, pixel
    ``(2 * i + dh, 2 * j + dw)``.

    * Learned branch: a 3x3 conv from ``in_dim`` to ``4 * out_dim`` channels.
    * Shortcut: every input channel is repeated ``scale`` times (adjacent
      copies) to reach ``4 * out_dim`` channels.

    ``in_dim`` should divide ``4 * out_dim``; otherwise the branch shapes disagree.
    """

    def __init__(self, in_dim, out_dim, op=ops.Conv2d):
        super().__init__()
        self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
        # Repeat count that widens the input to the conv's channel count.
        self.scale = (out_dim << 2) // in_dim

    @staticmethod
    def _unfold(t, height, width):
        """Depth-to-space: spread groups of four channels over 2x2 pixel patches."""
        n = t.shape[0]
        grid = t.view(n, 2, 2, -1, height, width)
        # (n, dh, dw, c, i, j) -> (n, c, i, dh, j, dw) -> (n, c, 2i+dh, 2j+dw)
        return grid.permute(0, 3, 4, 1, 5, 2).reshape(n, -1, height << 1, width << 1)

    def forward(self, x):
        _, _, height, width = x.shape
        learned = self._unfold(self.conv(x), height, width)
        widened = x.repeat_interleave(self.scale, 1)
        shortcut = self._unfold(widened, height, width)
        return learned + shortcut
|
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder: ResNet stages with 2x space-to-depth downsamplers.

    Pipeline::

        conv_in -> stage_0 ... stage_{n-1} -> mid (ResNet, attention, ResNet)
                -> GroupNorm -> SiLU -> conv_out  (+ channel-averaging skip)

    Args:
        in_channels: Channels of the input tensor.
        z_channels: Latent width. The output has ``2 * z_channels`` channels.
            NOTE(review): presumably these are the mean / log-variance halves;
            confirm against the caller.
        block_out_channels: Width of each stage, in order.
        num_res_blocks: Number of ResnetBlocks per stage.
        ffactor_spatial: Spatial reduction factor. The first
            ``(ffactor_spatial >> 1).bit_length()`` stages end with a 2x
            downsample; that is ``log2(ffactor_spatial)`` when the factor is a
            power of two.
        downsample_match_channel: If True, each downsample switches to the next
            stage's width. Otherwise it keeps the current width.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, downsample_match_channel=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        # Number of stages followed by a 2x downsample.
        depth = (ffactor_spatial >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            # Only the first block changes width (ch -> tgt); the rest keep tgt.
            stage.block = nn.ModuleList([
                ResnetBlock(in_channels=ch if j == 0 else tgt,
                            out_channels=tgt,
                            temb_channels=0,
                            conv_op=ops.Conv2d)
                for j in range(num_res_blocks)
            ])
            ch = tgt
            if i < depth:
                has_next = i + 1 < len(block_out_channels)
                nxt = block_out_channels[i + 1] if has_next and downsample_match_channel else ch
                stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        # Build the norm from the same `ops` namespace as every other layer, so it
        # gets the same init/casting behaviour. The parameter names (weight/bias)
        # are unchanged, so checkpoints still load.
        self.norm_out = ops.GroupNorm(num_groups=32, num_channels=ch, eps=1e-6, affine=True)
        self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)

    def forward(self, x):
        """Encode ``x`` into a tensor with ``2 * z_channels`` channels.

        Each downsample halves H and W, so the input's spatial size must be
        divisible by ``2 ** depth``.
        """
        x = self.conv_in(x)

        for stage in self.down:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, "downsample"):
                x = stage.downsample(x)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        # Parameter-free skip: average groups of `grp` consecutive channels down
        # to 2 * z_channels. This assumes the final width is a multiple of
        # 2 * z_channels; otherwise the shapes below do not line up.
        b, c, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, h, w).mean(2)

        return self.conv_out(F.silu(self.norm_out(x))) + skip
|
|
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder: ResNet stages with 2x depth-to-space upsamplers.

    Pipeline::

        conv_in (+ channel-repeat skip of z) -> mid (ResNet, attention, ResNet)
                -> stage_0 ... stage_{n-1} -> GroupNorm -> SiLU -> conv_out

    Args:
        z_channels: Channels of the latent input.
        out_channels: Channels of the decoded output.
        block_out_channels: Stage widths in encoder order. They are reversed here,
            so decoding starts from the last entry. The reversed list is what is
            stored on ``self.block_out_channels``.
        num_res_blocks: Base ResnetBlock count; each decoder stage uses
            ``num_res_blocks + 1`` blocks.
        ffactor_spatial: Spatial upscaling factor. The first
            ``(ffactor_spatial >> 1).bit_length()`` stages end with a 2x upsample;
            that is ``log2(ffactor_spatial)`` when the factor is a power of two.
        upsample_match_channel: If True, each upsample switches to the next
            stage's width. Otherwise it keeps the current width.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, upsample_match_channel=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        ch = block_out_channels[0]
        self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

        self.up = nn.ModuleList()
        # Number of stages followed by a 2x upsample.
        depth = (ffactor_spatial >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            # Only the first block changes width (ch -> tgt); the rest keep tgt.
            stage.block = nn.ModuleList([
                ResnetBlock(in_channels=ch if j == 0 else tgt,
                            out_channels=tgt,
                            temb_channels=0,
                            conv_op=ops.Conv2d)
                for j in range(num_res_blocks + 1)
            ])
            ch = tgt
            if i < depth:
                has_next = i + 1 < len(block_out_channels)
                nxt = block_out_channels[i + 1] if has_next and upsample_match_channel else ch
                stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
                ch = nxt
            self.up.append(stage)

        # Build the norm from the same `ops` namespace as every other layer, so it
        # gets the same init/casting behaviour. The parameter names (weight/bias)
        # are unchanged, so checkpoints still load.
        self.norm_out = ops.GroupNorm(num_groups=32, num_channels=ch, eps=1e-6, affine=True)
        self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)

    def forward(self, z):
        """Decode latent ``z`` into a tensor with ``out_channels`` channels.

        Each upsample doubles H and W.
        """
        # Skip path: repeat each latent channel so z matches conv_in's width.
        # This assumes block_out_channels[0] (after reversal) is a multiple of
        # z_channels.
        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        for stage in self.up:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, "upsample"):
                x = stage.upsample(x)

        return self.conv_out(F.silu(self.norm_out(x)))
|