Switch some more prints to logging.

comfyanonymous
2024-03-11 16:24:47 -04:00
parent 0ed72befe1
commit 2a813c3b09
10 changed files with 40 additions and 34 deletions


@@ -14,6 +14,7 @@ import torch
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 import math
+import logging
 try:
     from typing import Optional, NamedTuple, List, Protocol
@@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
     except model_management.OOM_EXCEPTION:
-        print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
+        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
         summed = torch.sum(attn_scores, dim=-1, keepdim=True)
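
Because the message now goes through the standard logging module instead of print, whether it appears (and where it is written) depends on how the embedding application configures logging. The following is a minimal sketch of that stdlib pattern only; it is not ComfyUI's actual logging setup, and the message text is just an illustration.

import logging

# Sketch: configure the root logger so warnings emitted by library code are visible.
logging.basicConfig(level=logging.WARNING, format="%(levelname)s: %(message)s")

# A warning emitted the same way as in the diff above is now routed through
# the logging system rather than printed to stdout.
logging.warning("ran out of memory while running softmax, trying slower in place softmax instead")

# Raising the threshold suppresses such warnings without touching the library code.
logging.getLogger().setLevel(logging.ERROR)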