Cleaner CLIP text encoder implementation.

Use a simple CLIP model implementation instead of the one from
transformers.

This will allow some interesting things that would too hackish to implement
using the transformers implementation.
This commit is contained in:
comfyanonymous
2023-12-06 15:55:09 -05:00
parent 2db86b4676
commit fbdb14d4c4
5 changed files with 172 additions and 49 deletions

View File

@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
sim += mask
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@@ -340,6 +343,18 @@ else:
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False):
if device == torch.device("cpu"): #TODO
if model_management.pytorch_attention_enabled():
return attention_pytorch
else:
return attention_basic
if mask:
return optimized_attention_masked
return optimized_attention
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()