Small refactor of some vae code. (#9787)

This commit is contained in:
comfyanonymous
2025-09-09 15:09:56 -07:00
committed by GitHub
parent f73b176abd
commit b288fb0db8

View File

@@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512, conv_op=ops.Conv2d): dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels out_channels = in_channels if out_channels is None else out_channels
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1, stride=1,
padding=0) padding=0)
def forward(self, x, temb): def forward(self, x, temb=None):
h = x h = x
h = self.norm1(h) h = self.norm1(h)
h = self.swish(h) h = self.swish(h)