LTXV lowvram fixes.

This commit is contained in:
comfyanonymous
2024-11-22 17:17:11 -05:00
parent bc6be6c11e
commit e5c3f4b87f
4 changed files with 13 additions and 10 deletions

View File

@@ -4,7 +4,8 @@ import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
causal=False,
):
if dims == 2:
return torch.nn.Conv2d(
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
groups=groups,
bias=bias,
)
return torch.nn.Conv3d(
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
bias=True,
):
if dims == 2:
return torch.nn.Conv2d(
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return torch.nn.Conv3d(
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else: