Disable autocast in unet for increased speed.

This commit is contained in:
comfyanonymous
2023-07-05 20:58:44 -04:00
parent 603f02d613
commit ddc6f12ad5
9 changed files with 84 additions and 79 deletions

View File

@@ -215,10 +215,12 @@ class PositionNet(nn.Module):
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
dtype = self.linears[0].weight.dtype
masks = masks.unsqueeze(-1).to(dtype)
positive_embeddings = positive_embeddings.to(dtype)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
@@ -252,7 +254,8 @@ class Gligen(nn.Module):
if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(key, x):
def func_lowvram(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)