mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Small refactor of some vae code. (#9787)
This commit is contained in:
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
|||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
|
|||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
|
|
||||||
def forward(self, x, temb):
|
def forward(self, x, temb=None):
|
||||||
h = x
|
h = x
|
||||||
h = self.norm1(h)
|
h = self.norm1(h)
|
||||||
h = self.swish(h)
|
h = self.swish(h)
|
||||||
|
Reference in New Issue
Block a user