Disable autocast in unet for increased speed.

This commit is contained in:
comfyanonymous
2023-07-05 20:58:44 -04:00
parent 603f02d613
commit ddc6f12ad5
9 changed files with 84 additions and 79 deletions

View File

@@ -84,7 +84,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
torch.exp(attn_weights - max_score, out=attn_weights)
exp_weights = attn_weights
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
attn_scores /= summed
attn_probs = attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice
class ScannedChunk(NamedTuple):