ONNX tracing fixes.

This commit is contained in:
comfyanonymous
2024-08-04 15:45:43 -04:00
parent 0a6b008117
commit 3b71f84b50
5 changed files with 16 additions and 13 deletions

View File

@@ -15,6 +15,7 @@ from .layers import (
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
@@ -42,7 +43,7 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
@@ -125,10 +126,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
pad_h = (patch_size - h % 2) % patch_size
pad_w = (patch_size - w % 2) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)