Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-13 04:55:53 +00:00)
Cleaner CLIP text encoder implementation.
Use a simple CLIP model implementation instead of the one from transformers. This will allow some interesting things that would be too hackish to implement using the transformers implementation.
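For context, "a simple CLIP model implementation" here means implementing the CLIP text transformer directly in PyTorch rather than wrapping transformers.CLIPTextModel. The sketch below is only a hypothetical illustration of that shape (the class name, sizes, and use of nn.TransformerEncoder are assumptions, not the code added by this commit): token plus learned position embeddings, a causally masked pre-norm transformer stack, and a final layer norm.

import torch
import torch.nn as nn

class SimpleCLIPTextEncoder(nn.Module):
    # Hypothetical minimal CLIP-style text encoder. Sizes assume the SD1.x
    # text model (vocab 49408, 77 positions, width 768, 12 layers, 12 heads).
    def __init__(self, vocab_size=49408, max_len=77, dim=768, layers=12, heads=12):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Parameter(torch.zeros(max_len, dim))
        block = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim * 4,
            activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(block, num_layers=layers)
        self.final_layer_norm = nn.LayerNorm(dim)

    def forward(self, tokens):
        t = tokens.shape[1]
        x = self.token_embedding(tokens) + self.position_embedding[:t]
        # CLIP's text transformer is causally masked: position i may only
        # attend to positions <= i (additive -inf mask above the diagonal).
        causal = torch.full((t, t), float("-inf"), device=tokens.device).triu(1)
        x = self.blocks(x, mask=causal)
        return self.final_layer_norm(x)

tokens = torch.randint(0, 49408, (1, 77))
hidden = SimpleCLIPTextEncoder()(tokens)  # shape (1, 77, 768)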
@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
     del q, k
 
     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask
 
     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
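The rewritten branch accepts either a boolean mask (True = attend; dropped positions are filled with the most negative finite value) or an additive float mask that is simply added to the scores. A small self-contained sketch of the two conventions (shapes are illustrative assumptions), showing they produce the same attention weights after softmax:

import torch

sim = torch.randn(2, 4, 4)  # (batch*heads, q_len, k_len) attention scores
bool_mask = torch.tensor([True, True, True, False]).expand(2, 1, 4)

# Boolean convention: False positions are filled with the most negative
# representable value so softmax drives their weight to ~0.
masked_bool = sim.masked_fill(~bool_mask, -torch.finfo(sim.dtype).max)

# Additive convention: 0 where attention is allowed, a huge negative value where not.
add_mask = torch.where(bool_mask, 0.0, -torch.finfo(sim.dtype).max)
masked_add = sim + add_mask

# Both conventions yield the same attention distribution.
assert torch.allclose(masked_bool.softmax(dim=-1), masked_add.softmax(dim=-1))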
@@ -340,6 +343,18 @@ else:
 if model_management.pytorch_attention_enabled():
     optimized_attention_masked = attention_pytorch
 
+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
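A hypothetical caller of the new helper (not part of this diff; the import path is assumed from the file being patched) would resolve the attention function once from the device and then invoke it with the (q, k, v, heads, mask) signature shared by attention_basic and attention_pytorch:

import torch
# Assumed import path for the module patched above.
from comfy.ldm.modules.attention import optimized_attention_for_device

def encode(q, k, v, heads, mask=None):
    # Pick the attention kernel for this device, requesting the masked
    # variant only when a mask will actually be passed.
    attn = optimized_attention_for_device(q.device, mask=mask is not None)
    return attn(q, k, v, heads, mask=mask)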