Remove useless code. (#9041)

This is only needed on old pytorch 2.0 and older.
This commit is contained in:
comfyanonymous 2025-07-25 01:58:28 -07:00 committed by GitHub
parent 4293e4da21
commit e6e5d33b35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,15 +52,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
@ -73,11 +64,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))