Implement noise augmentation for SD 4X upscale model.

This commit is contained in:
comfyanonymous
2024-01-03 14:27:11 -05:00
parent ef4f6037cb
commit 8c6493578b
6 changed files with 33 additions and 14 deletions

View File

@@ -498,7 +498,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)