xlabs Flux controlnet implementation. (#4260)

* xlabs Flux controlnet.

* Fix not working on old python.

* Remove comment.
This commit is contained in:
comfyanonymous
2024-08-12 21:22:22 -04:00
committed by GitHub
parent b8ffb2937f
commit c032b11e07
3 changed files with 135 additions and 7 deletions

View File

@@ -38,7 +38,7 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,8 +114,15 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
if control is not None: #Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
@@ -123,7 +132,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -138,5 +147,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]