Lower T5 memory usage by a few hundred MB.

This commit is contained in:
comfyanonymous
2024-07-31 00:52:34 -04:00
parent 82cae45d44
commit b85216a3c0
3 changed files with 33 additions and 17 deletions

View File

@@ -355,7 +355,7 @@ class HunYuanDiT(nn.Module):
if self.use_style_cond:
if style is None:
style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
style_embedding = self.style_embedder(style)
style_embedding = self.style_embedder(style, out_dtype=x.dtype)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors