Add remaining patch

Yoland Yan
2025-03-02 11:58:04 -08:00
parent 2cd3c8a2fb
commit f03ece18f2
7 changed files with 98 additions and 30 deletions


@@ -797,12 +797,15 @@ class GeneralDITTransformerBlock(nn.Module):
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         for block in self.blocks:
-            x = block(
-                x,
-                emb_B_D,
-                crossattn_emb,
-                crossattn_mask,
-                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
-                adaln_lora_B_3D=adaln_lora_B_3D,
-            )
+            if self.training:
+                x = torch.utils.checkpoint.checkpoint(block, x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, use_reentrant=False)
+            else:
+                x = block(
+                    x,
+                    emb_B_D,
+                    crossattn_emb,
+                    crossattn_mask,
+                    rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+                    adaln_lora_B_3D=adaln_lora_B_3D,
+                )
         return x
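
This hunk enables activation (gradient) checkpointing on the training path: instead of storing every block's intermediate activations for backward, PyTorch re-runs each block's forward during the backward pass, trading extra compute for lower peak memory. Note that the checkpointed call passes rope_emb_L_1_1_D and adaln_lora_B_3D positionally, since torch.utils.checkpoint.checkpoint forwards its extra arguments straight to the wrapped callable. A minimal self-contained sketch of the same pattern (TinyBlock and the tensor shapes are invented for illustration, not taken from this repo):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class TinyBlock(nn.Module):
    # Stand-in for a DiT block: a residual MLP with one conditioning input.
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x, emb):
        return x + self.ff(x + emb)

blocks = nn.ModuleList(TinyBlock(64) for _ in range(4))
x = torch.randn(2, 16, 64, requires_grad=True)
emb = torch.randn(2, 16, 64)
training = True

for block in blocks:
    if training:
        # Recompute this block's forward during backward instead of caching
        # its activations; use_reentrant=False is the recommended variant.
        x = torch.utils.checkpoint.checkpoint(block, x, emb, use_reentrant=False)
    else:
        x = block(x, emb)

x.sum().backward()  # gradients flow through the recomputed blocks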


@@ -750,7 +750,7 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)

-        x += n
+        x = n + x
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:
@@ -790,12 +790,12 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)

-        x += n
+        x = n + x
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x += x_skip
+            x = x_skip + x

         return x
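
Both hunks in this file make the same substitution: the in-place accumulations x += n and x += x_skip become out-of-place additions x = n + x and x = x_skip + x. In-place ops mutate tensors that autograd may have saved for the backward pass, so once a training path (like the checkpointing above) is enabled they can raise a RuntimeError; the out-of-place form allocates a fresh result and leaves the saved tensor intact. A small sketch of the failure mode this presumably avoids:

import torch

x = torch.randn(4, requires_grad=True)  # leaf tensor that requires grad
n = torch.randn(4)

try:
    x += n  # in-place update of a leaf that requires grad
except RuntimeError as err:
    print("in-place add failed:", err)

# The out-of-place form builds a new autograd node instead of mutating x,
# so backward still sees the original value.
y = n + x
y.sum().backward()
print(x.grad)  # d(sum(n + x))/dx is all ones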