Allow having a different pooled output for each image in a batch.

This commit is contained in:
comfyanonymous
2023-09-21 01:14:42 -04:00
parent 0793eb9269
commit 492db2de8d
2 changed files with 4 additions and 3 deletions

View File

@@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel):
out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([aesthetic_score])))
flat = torch.flatten(torch.cat(out))[None, ]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
@@ -206,5 +206,5 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([target_height])))
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out))[None, ]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)