mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Support new flux model variants.
This commit is contained in:
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
@@ -29,6 +30,7 @@ class FluxParams:
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
@@ -43,8 +45,9 @@ class Flux(nn.Module):
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels * 2 * 2
|
||||
self.out_channels = self.in_channels
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels * params.patch_size * params.patch_size
|
||||
self.out_channels = params.out_channels * params.patch_size * params.patch_size
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@@ -165,7 +168,7 @@ class Flux(nn.Module):
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
Reference in New Issue
Block a user