mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
ONNX tracing fixes.
This commit is contained in:
@@ -15,6 +15,7 @@ from .layers import (
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
@@ -42,7 +43,7 @@ class Flux(nn.Module):
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
@@ -125,10 +126,7 @@ class Flux(nn.Module):
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
pad_h = (patch_size - h % 2) % patch_size
|
||||
pad_w = (patch_size - w % 2) % patch_size
|
||||
|
||||
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
|
Reference in New Issue
Block a user