diff --git a/README.md b/README.md
index 4747b86df..2b1bccc58 100644
--- a/README.md
+++ b/README.md
@@ -135,7 +135,7 @@ You can also set this command line setting to disable the upcasting to fp32 in s
 
 ## Support and dev channel
 
-[Matrix room: #comfyui:matrix.org](https://app.element.io/#/room/%23comfyui%3Amatrix.org) (it's like discord but open source).
+[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
 
 # QA
 
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 59683f645..692952f32 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -489,6 +489,8 @@ if XFORMERS_IS_AVAILBLE == False or "--disable-xformers" in sys.argv:
     if "--use-pytorch-cross-attention" in sys.argv:
         print("Using pytorch cross attention")
         torch.backends.cuda.enable_math_sdp(False)
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
         CrossAttention = CrossAttentionPytorch
     else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
@@ -497,6 +499,7 @@ else:
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
 
+
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 6f0b41dce..01ab2ede9 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -7,6 +7,7 @@ from einops import rearrange
 from typing import Optional, Any
 
 from ldm.modules.attention import MemoryEfficientCrossAttention
+import model_management
 
 try:
     import xformers
@@ -199,12 +200,7 @@ class AttnBlock(nn.Module):
 
         r1 = torch.zeros_like(k, device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = model_management.get_free_memory(q.device)
 
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
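A note on the two runtime changes above: under --use-pytorch-cross-attention, ComfyUI now not only disables PyTorch's slow math fallback for scaled dot-product attention but explicitly enables the flash and memory-efficient kernels, and model.py stops duplicating CUDA memory arithmetic in favor of the shared model_management.get_free_memory() helper (which, per the removed lines, sums the driver-reported free memory and torch's reserved-but-inactive pool). A minimal standalone sketch of the SDP flags, assuming a CUDA build of PyTorch 2.0+; the tensors and shapes below are illustrative, not part of this patch:

    import torch
    import torch.nn.functional as F

    # Mirror the backend flags set under --use-pytorch-cross-attention.
    torch.backends.cuda.enable_math_sdp(False)          # forbid the slow fallback
    torch.backends.cuda.enable_flash_sdp(True)          # allow the flash kernel
    torch.backends.cuda.enable_mem_efficient_sdp(True)  # allow the mem-efficient kernel

    q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # With math disabled, this call must dispatch to one of the two fused kernels.
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 1024, 64])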
diff --git a/comfy/sd.py b/comfy/sd.py
index 50d81f779..19722113a 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -612,8 +612,17 @@ class T2IAdapter:
 
 def load_t2i_adapter(ckpt_path, model=None):
     t2i_data = load_torch_file(ckpt_path)
-    cin = t2i_data['conv_in.weight'].shape[1]
-    model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
+    keys = t2i_data.keys()
+    if "style_embedding" in keys:
+        pass
+        # TODO
+        # model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+    elif "body.0.in_conv.weight" in keys:
+        cin = t2i_data['body.0.in_conv.weight'].shape[1]
+        model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
+    else:
+        cin = t2i_data['conv_in.weight'].shape[1]
+        model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
     model_ad.load_state_dict(t2i_data)
     return T2IAdapter(model_ad, cin // 64)
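load_t2i_adapter() now sniffs the checkpoint's state dict instead of assuming one architecture: a "style_embedding" key marks a style adapter (still a TODO, so that branch builds nothing), "body.0.in_conv.weight" marks the new light adapter, and anything else falls through to the original full Adapter. One caveat worth flagging: because the style branch only executes pass, model_ad and cin are never bound there, so the subsequent load_state_dict() call would raise; style checkpoints are effectively unsupported until the TODO is resolved. A hedged usage sketch — the file names are placeholders, and the import assumes the comfy/ directory is on sys.path:

    import sd  # comfy/sd.py

    # Dispatch is driven purely by which keys the checkpoint contains.
    full = sd.load_t2i_adapter("models/t2i_adapter/t2iadapter_depth_sd14v1.pth")
    light = sd.load_t2i_adapter("models/t2i_adapter/t2iadapter_color_sd14v1.pth")
    # A checkpoint containing a "style_embedding" key currently hits the TODO
    # branch and fails at load_state_dict().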
diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py
index d059ba913..0221fff83 100644
--- a/comfy/t2i_adapter/adapter.py
+++ b/comfy/t2i_adapter/adapter.py
@@ -1,9 +1,8 @@
 #taken from https://github.com/TencentARC/T2I-Adapter
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
+from collections import OrderedDict
+
 
 def conv_nd(dims, *args, **kwargs):
     """
@@ -17,6 +16,7 @@ def conv_nd(dims, *args, **kwargs):
         return nn.Conv3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
 
+
 def avg_pool_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D average pooling module.
@@ -29,6 +29,7 @@ def avg_pool_nd(dims, *args, **kwargs):
         return nn.AvgPool3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
 
+
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
@@ -38,7 +39,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -61,8 +62,8 @@ class Downsample(nn.Module):
 class ResnetBlock(nn.Module):
     def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
         super().__init__()
-        ps = ksize//2
-        if in_c != out_c or sk==False:
+        ps = ksize // 2
+        if in_c != out_c or sk == False:
             self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
         else:
             # print('n_in')
@@ -70,7 +71,7 @@ class ResnetBlock(nn.Module):
         self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
         self.act = nn.ReLU()
         self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
-        if sk==False:
+        if sk == False:
             self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
         else:
             self.skep = None
@@ -82,7 +83,7 @@ class ResnetBlock(nn.Module):
     def forward(self, x):
         if self.down == True:
             x = self.down_opt(x)
-        if self.in_conv is not None: # edit
+        if self.in_conv is not None:  # edit
             x = self.in_conv(x)
 
         h = self.block1(x)
@@ -103,12 +104,14 @@ class Adapter(nn.Module):
         self.body = []
         for i in range(len(channels)):
             for j in range(nums_rb):
-                if (i!=0) and (j==0):
-                    self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                if (i != 0) and (j == 0):
+                    self.body.append(
+                        ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
                 else:
-                    self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
+                    self.body.append(
+                        ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
         self.body = nn.ModuleList(self.body)
-        self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1)
+        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
 
     def forward(self, x):
         # unshuffle
@@ -118,8 +121,139 @@ class Adapter(nn.Module):
         x = self.conv_in(x)
         for i in range(len(self.channels)):
             for j in range(self.nums_rb):
-                idx = i*self.nums_rb +j
+                idx = i * self.nums_rb + j
                 x = self.body[idx](x)
             features.append(x)
 
         return features
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
+                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class StyleAdapter(nn.Module):
+
+    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
+        super().__init__()
+
+        scale = width ** -0.5
+        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
+        self.num_token = num_token
+        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
+        self.ln_post = LayerNorm(width)
+        self.ln_pre = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
+
+    def forward(self, x):
+        # x shape [N, HW+1, C]
+        style_embedding = self.style_embedding + torch.zeros(
+            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
+        x = torch.cat([x, style_embedding], dim=1)
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer_layes(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, -self.num_token:, :])
+        x = x @ self.proj
+
+        return x
+
+
+class ResnetBlock_light(nn.Module):
+    def __init__(self, in_c):
+        super().__init__()
+        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+
+    def forward(self, x):
+        h = self.block1(x)
+        h = self.act(h)
+        h = self.block2(h)
+
+        return h + x
+
+
+class extractor(nn.Module):
+    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
+        super().__init__()
+        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
+        self.body = []
+        for _ in range(nums_rb):
+            self.body.append(ResnetBlock_light(inter_c))
+        self.body = nn.Sequential(*self.body)
+        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
+        self.down = down
+        if self.down == True:
+            self.down_opt = Downsample(in_c, use_conv=False)
+
+    def forward(self, x):
+        if self.down == True:
+            x = self.down_opt(x)
+        x = self.in_conv(x)
+        x = self.body(x)
+        x = self.out_conv(x)
+
+        return x
+
+
+class Adapter_light(nn.Module):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
+        super(Adapter_light, self).__init__()
+        self.unshuffle = nn.PixelUnshuffle(8)
+        self.channels = channels
+        self.nums_rb = nums_rb
+        self.body = []
+        for i in range(len(channels)):
+            if i == 0:
+                self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
+            else:
+                self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
+        self.body = nn.ModuleList(self.body)
+
+    def forward(self, x):
+        # unshuffle
+        x = self.unshuffle(x)
+        # extract features
+        features = []
+        for i in range(len(self.channels)):
+            x = self.body[i](x)
+            features.append(x)
+
+        return features
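The shape bookkeeping that makes cin // 64 in sd.py work: both adapter variants open with nn.PixelUnshuffle(8), which folds each 8x8 spatial block into the channel dimension, multiplying channels by 64 (an RGB condition map becomes 192 channels, a single-channel sketch becomes 64). Adapter_light then keeps each stage cheap by bottlenecking through 1x1 convolutions at channels[i] // 4 before projecting back out. A quick shape check, assuming the classes above are importable (the import path is an assumption about how comfy/t2i_adapter/adapter.py is reached):

    import torch
    import torch.nn as nn
    from t2i_adapter.adapter import Adapter_light  # path assumption

    x = torch.randn(1, 3, 512, 512)       # RGB condition image
    print(nn.PixelUnshuffle(8)(x).shape)  # torch.Size([1, 192, 64, 64]); 3 * 8 * 8 = 192

    # Matches the Adapter_light(...) call added to sd.py: cin = 3 * 64 = 192.
    model = Adapter_light(channels=[320, 640, 1280, 1280], nums_rb=4, cin=192)
    feats = model(x)
    print([tuple(f.shape) for f in feats])
    # [(1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)]

Each stage after the first downsamples by 2, so the four feature maps line up with the UNet resolutions the adapter residuals are injected into.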
"cell_type": "markdown", "source": [ - "Download some models/checkpoints/vae (uncomment the wget commands for the ones you want)" + "Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)" ], "metadata": { "id": "cccccccccc" @@ -54,43 +54,52 @@ "# Checkpoints\n", "\n", "# SD1.5\n", - "!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", + "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", "\n", "# SD2\n", - "#!wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", - "#!wget https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n", "\n", "# Some SD1.5 anime style\n", - "#!wget https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n", - "#!wget https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1.safetensors -P ./models/checkpoints/\n", - "#!wget https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3.safetensors -P ./models/checkpoints/\n", - "#!wget https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n", "\n", "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n", - "#!wget https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n", "\n", "\n", "# VAE\n", - "!wget https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n", - "#!wget https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n", + "!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n", + "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n", "\n", "\n", "# Loras\n", - "#!wget --content-disposition https://civitai.com/api/download/models/10350 -P ./models/loras/ #theovercomer8sContrastFix SD2.x 768-v\n", - "#!wget --content-disposition https://civitai.com/api/download/models/10638 -P ./models/loras/ 
diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py
index e01f0b913..a0e22878b 100644
--- a/script_examples/basic_api_example.py
+++ b/script_examples/basic_api_example.py
@@ -37,10 +37,9 @@ prompt_text = """
         }
     },
     "4": {
-        "class_type": "CheckpointLoader",
+        "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt",
-            "config_name": "v1-inference.yaml"
+            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
         }
     },
     "5": {
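With node "4" switched to CheckpointLoaderSimple, API callers drop the config_name input entirely; the checkpoint file name is the only required input. A minimal sketch of queueing such a prompt against a locally running instance, assuming ComfyUI's default listen address (127.0.0.1:8188) and the /prompt endpoint this example script targets; only the changed node is spelled out:

    import json
    from urllib import request

    prompt = {
        # ... the remaining nodes from the prompt above go here unchanged ...
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": "v1-5-pruned-emaonly.ckpt"},
        },
    }

    data = json.dumps({"prompt": prompt}).encode("utf-8")
    request.urlopen(request.Request("http://127.0.0.1:8188/prompt", data=data))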
diff --git a/web/scripts/defaultGraph.js b/web/scripts/defaultGraph.js
index 865f1ca0e..967377ad6 100644
--- a/web/scripts/defaultGraph.js
+++ b/web/scripts/defaultGraph.js
@@ -86,9 +86,9 @@ export const defaultGraph = {
 		},
 		{
 			id: 4,
-			type: "CheckpointLoader",
+			type: "CheckpointLoaderSimple",
 			pos: [26, 474],
-			size: { 0: 315, 1: 122 },
+			size: { 0: 315, 1: 98 },
 			flags: {},
 			order: 0,
 			mode: 0,
@@ -98,7 +98,7 @@ export const defaultGraph = {
 				{ name: "VAE", type: "VAE", links: [8], slot_index: 2 },
 			],
 			properties: {},
-			widgets_values: ["v1-inference.yaml", "v1-5-pruned-emaonly.ckpt"],
+			widgets_values: ["v1-5-pruned-emaonly.ckpt"],
 		},
 	],
 	links: [