Allow disabling pe in flux code for some other models.

This commit is contained in:
comfyanonymous
2025-03-18 05:09:25 -04:00
parent 50614f1b79
commit 3b19fc76e3
2 changed files with 10 additions and 6 deletions

View File

@@ -115,8 +115,11 @@ class Flux(nn.Module):
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):