diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f78a1a6c..e97badd0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -479,23 +479,19 @@ class CrossAttentionPytorch(nn.Module):
         return self.to_out(out)
 
 import sys
-if model_management.xformers_enabled() == False:
+if model_management.xformers_enabled():
+    print("Using xformers cross attention")
+    CrossAttention = MemoryEfficientCrossAttention
+elif model_management.pytorch_attention_enabled():
+    print("Using pytorch cross attention")
+    CrossAttention = CrossAttentionPytorch
+else:
     if "--use-split-cross-attention" in sys.argv:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
-        if "--use-pytorch-cross-attention" in sys.argv:
-            print("Using pytorch cross attention")
-            torch.backends.cuda.enable_math_sdp(False)
-            torch.backends.cuda.enable_flash_sdp(True)
-            torch.backends.cuda.enable_mem_efficient_sdp(True)
-            CrossAttention = CrossAttentionPytorch
-        else:
-            print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-            CrossAttention = CrossAttentionBirchSan
-else:
-    print("Using xformers cross attention")
-    CrossAttention = MemoryEfficientCrossAttention
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        CrossAttention = CrossAttentionBirchSan
 
 
 class BasicTransformerBlock(nn.Module):
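The attention.py hunk above collapses backend selection into a single if/elif chain: xformers when available, then PyTorch's fused attention when --use-pytorch-cross-attention is passed (queried via the new model_management.pytorch_attention_enabled(), added below), then the split or sub-quadratic fallbacks. The torch.backends.cuda toggles move out of this module into model_management.py. As a minimal sketch of what the PyTorch path computes, assuming PyTorch >= 2.0; the tensor names and shapes are illustrative, not taken from the patch:

    import math
    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, tokens, head_dim)
    q = torch.randn(2, 8, 64, 40)
    k = torch.randn(2, 8, 64, 40)
    v = torch.randn(2, 8, 64, 40)

    # Manual reference: softmax(q @ k^T / sqrt(d)) @ v. The split and
    # sub-quadratic implementations compute this in chunks to bound memory.
    scale = 1.0 / math.sqrt(q.shape[-1])
    manual = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

    # Fused kernel: PyTorch dispatches to a flash, memory-efficient, or plain
    # math backend depending on device, dtype, and which backends are enabled.
    fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                           dropout_p=0.0, is_causal=False)

    assert torch.allclose(manual, fused, atol=1e-5)

On supported backends the fused call avoids materializing the full tokens-by-tokens attention matrix, which is the same memory problem the split and sub-quadratic paths work around by chunking.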
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index fcbee29f..129b86a7 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -299,6 +299,64 @@ class MemoryEfficientAttnBlock(nn.Module):
         out = self.proj_out(out)
         return x+out
 
+class MemoryEfficientAttnBlockPytorch(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
 
 class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
     def forward(self, x, context=None, mask=None):
@@ -313,6 +371,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
+    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
+        attn_type = "vanilla-pytorch"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
@@ -320,6 +380,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     elif attn_type == "vanilla-xformers":
         print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "vanilla-pytorch":
+        return MemoryEfficientAttnBlockPytorch(in_channels)
-    elif type == "memory-efficient-cross-attn":
+    elif attn_type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
         return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7365beef..482b1add 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -41,6 +41,14 @@ else:
     except:
         XFORMERS_IS_AVAILBLE = False
 
+ENABLE_PYTORCH_ATTENTION = False
+if "--use-pytorch-cross-attention" in sys.argv:
+    torch.backends.cuda.enable_math_sdp(True)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILBLE = False
+
 if "--cpu" in sys.argv:
     vram_state = CPU
 
@@ -175,6 +183,9 @@ def xformers_enabled():
         return False
     return XFORMERS_IS_AVAILBLE
 
+def pytorch_attention_enabled():
+    return ENABLE_PYTORCH_ATTENTION
+
 def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
diff --git a/main.py b/main.py
index fc37781c..b2b3f1c4 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,7 @@ if __name__ == "__main__":
         print("\t--port 8188\t\t\tSet the listen port.")
         print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
         print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
+        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
         print("\t--disable-xformers\t\tdisables xformers")
         print()
         print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
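Two details in the model_management.py hunk are worth noting: the torch.backends.cuda toggles now run once at startup rather than inside the attention module, and enable_math_sdp flips from False to True, so the plain-math fallback stays available when the fused kernels reject a particular shape, dtype, or device. Forcing XFORMERS_IS_AVAILBLE = False makes the new flag take precedence over xformers in the selection chain. A minimal sketch of these process-wide toggles, assuming PyTorch >= 2.0 built with CUDA; the snippet is illustrative and not part of the patch:

    import torch

    # Process-wide switches consulted by scaled_dot_product_attention when it
    # picks a kernel for CUDA inputs.
    torch.backends.cuda.enable_math_sdp(True)            # plain softmax fallback
    torch.backends.cuda.enable_flash_sdp(True)           # FlashAttention kernel
    torch.backends.cuda.enable_mem_efficient_sdp(True)   # memory-efficient kernel

    print(torch.backends.cuda.math_sdp_enabled())            # True
    print(torch.backends.cuda.flash_sdp_enabled())           # True
    print(torch.backends.cuda.mem_efficient_sdp_enabled())   # True

With the patch applied, the new path is opted into from the command line, e.g.:

    python main.py --use-pytorch-cross-attention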