ONNX tracing fixes.

This commit is contained in:
comfyanonymous
2024-08-04 15:45:43 -04:00
parent 0a6b008117
commit 3b71f84b50
5 changed files with 16 additions and 13 deletions

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.ldm.common_dit
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -407,10 +408,7 @@ class MMDiT(nn.Module):
def patchify(self, x):
B, C, H, W = x.size()
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
x = x.view(
B,
C,