Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-14 05:25:23 +00:00
Add debug options to force on and off attention upcasting.
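The diff reads two debug switches from comfy.cli_args: args.dont_upcast_attention and args.force_upcast_attention. As a rough, hypothetical sketch of how such a pair of flags could be declared with argparse (the actual definitions live in comfy/cli_args.py, and their exact spelling, grouping, and help text are assumptions here):

import argparse

parser = argparse.ArgumentParser()

# Hypothetical declarations of the two debug switches used by get_attn_precision();
# argparse maps --dont-upcast-attention to args.dont_upcast_attention, etc.
group = parser.add_mutually_exclusive_group()
group.add_argument("--dont-upcast-attention", action="store_true",
                   help="Never upcast attention internals (debug).")
group.add_argument("--force-upcast-attention", action="store_true",
                   help="Always upcast attention internals to float32 (debug).")

args = parser.parse_args()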
@@ -19,6 +19,14 @@ from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
+
+def get_attn_precision(attn_precision):
+    if args.dont_upcast_attention:
+        return None
+    if attn_precision is None and args.force_upcast_attention:
+        return torch.float32
+    return attn_precision
+
 
 def exists(val):
     return val is not None
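The new helper gives --dont-upcast-attention absolute priority (no upcast even if a caller asked for one), lets --force-upcast-attention fill in torch.float32 only when the caller did not request a specific precision, and otherwise passes the caller's attn_precision through unchanged. A self-contained sketch of how such a returned dtype is typically applied, upcasting the query/key similarity and its softmax before casting back (illustrative only, not the exact ComfyUI code path):

import torch

def upcast_qk_softmax(q, k, scale, attn_precision=None):
    # Compute the similarity and its softmax in the requested precision,
    # then cast the attention weights back to the input dtype.
    if attn_precision is not None:
        sim = q.to(attn_precision) @ k.transpose(-2, -1).to(attn_precision) * scale
    else:
        sim = q @ k.transpose(-2, -1) * scale
    return sim.softmax(dim=-1).to(q.dtype)

q = torch.randn(2, 16, 64, dtype=torch.float16)
k = torch.randn(2, 16, 64, dtype=torch.float16)
weights = upcast_qk_softmax(q, k, scale=64 ** -0.5, attn_precision=torch.float32)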
@@ -78,6 +86,8 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
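Each attention backend now resolves its effective precision through get_attn_precision() before doing any work, so the debug flags apply uniformly. The surrounding context is the head bookkeeping: inputs arrive as (batch, tokens, heads * dim_head), dim_head is recovered by dividing by heads, and the usual 1/sqrt(dim_head) scale follows. A condensed sketch of what a "basic" full-matrix attention kernel built on those pieces looks like; the real attention_basic additionally handles masks and several dtype cases:

import torch

def basic_attention_sketch(q, k, v, heads, attn_precision=None):
    b, _, inner = q.shape
    dim_head = inner // heads
    scale = dim_head ** -0.5

    # (b, tokens, heads*dim_head) -> (b*heads, tokens, dim_head)
    def split_heads(t):
        return t.reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)

    q, k, v = map(split_heads, (q, k, v))

    if attn_precision == torch.float32:
        sim = torch.einsum("b i d, b j d -> b i j", q.float(), k.float()) * scale
    else:
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale

    attn = sim.softmax(dim=-1).to(v.dtype)
    out = torch.einsum("b i j, b j d -> b i d", attn, v)

    # (b*heads, tokens, dim_head) -> (b, tokens, heads*dim_head)
    return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)

out = basic_attention_sketch(torch.randn(1, 16, 256), torch.randn(1, 16, 256), torch.randn(1, 16, 256), heads=8)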
@@ -128,6 +138,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
 
 
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = query.shape
     dim_head //= heads
 
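attention_sub_quad trades the full (tokens x tokens) similarity matrix for chunked computation to bound peak memory; this hunk only shows it picking up the same precision override. As a generic illustration of the chunking idea (not ComfyUI's actual sub-quadratic implementation, which also chunks keys/values with a numerically stable streaming softmax), here is query-chunked attention, where the largest temporary is (chunk, tokens) rather than (tokens, tokens):

import torch

def query_chunked_attention(q, k, v, scale, query_chunk_size=64):
    # q, k, v: (batch*heads, tokens, dim_head)
    outs = []
    for i in range(0, q.shape[1], query_chunk_size):
        q_chunk = q[:, i:i + query_chunk_size]
        sim = torch.einsum("b i d, b j d -> b i j", q_chunk, k) * scale
        outs.append(torch.einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v))
    return torch.cat(outs, dim=1)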
@@ -188,6 +200,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
     return hidden_states
 
 def attention_split(q, k, v, heads, mask=None, attn_precision=None):
+    attn_precision = get_attn_precision(attn_precision)
+
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5