feat: add support for HunYuanDit ControlNet (#4245)

* add support for HunYuanDit ControlNet

* fix hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix code format style

* add control_weight support for HunyuanDit Controlnet

* use control_weights in HunyuanDit Controlnet

* fix typo
This commit is contained in:
来新璐
2024-08-09 14:59:24 +08:00
committed by GitHub
parent 413322645e
commit 06eb9fb426
4 changed files with 512 additions and 1 deletions

View File

@@ -91,6 +91,8 @@ class HunYuanDiTBlock(nn.Module):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
if cat.dtype != x.dtype:
cat = cat.to(x.dtype)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
@@ -362,6 +364,8 @@ class HunYuanDiT(nn.Module):
c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None
if control:
controls = control.get("output", None)
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):