Disable autocast in unet for increased speed.

This commit is contained in:
comfyanonymous
2023-07-05 20:58:44 -04:00
parent 603f02d613
commit ddc6f12ad5
9 changed files with 84 additions and 79 deletions

View File

@@ -278,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = model_management.get_free_memory(q.device)
@@ -314,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
first_op_done = True
s2 = s1.softmax(dim=-1)
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)