Support new flux model variants.

This commit is contained in:
comfyanonymous
2024-11-21 08:38:23 -05:00
parent 41444b5236
commit 8f0009aad0
9 changed files with 147 additions and 19 deletions

View File

@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -165,7 +168,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)