From f4231a80b1b904b45ade0def9b37320c4adfe71b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 12 Aug 2025 00:15:14 +0300 Subject: [PATCH 01/13] fix(Kling Image API Node): do not pass "image_type" when no image (#9271) * fix(Kling Image API Node): do not pass "image_type" when no image * fix(Kling Image API Node): raise client-side error when kling_v1 is used with reference image --- comfy_api_nodes/nodes_kling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9d9eb5628..9d483bb0e 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase): ): self.validate_prompt(prompt, negative_prompt) - if image is not None: + if image is None: + image_type = None + elif model_name == KlingImageGenModelName.kling_v1: + raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.") + else: image = tensor_to_base64_string(image) initial_operation = SynchronousOperation( From 1e3ae1eed8b925430e3b114ea6b7d08ea698e305 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 05:14:27 +0800 Subject: [PATCH 02/13] Update template to 0.1.58 (#9302) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2fb38ef27..82af5690b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.53 +comfyui-workflow-templates==0.1.58 comfyui-embedded-docs==0.2.6 torch torchsde From e1d4f36d8df7446ebb1a5f2bf9c708c38a159f22 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:13:04 -0700 Subject: [PATCH 03/13] Update test release package workflow with python 3.13 cu129. (#9306) --- .github/workflows/windows_release_package.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3926a65f3..b51746285 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,19 +7,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master @@ -64,6 +64,8 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From 560d38f34c5bd532f89f2178f01ee819cf145820 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 20:26:33 -0700 Subject: [PATCH 04/13] Wan2.2 fun control support. (#9292) --- comfy/ldm/wan/model.py | 19 +++++++++++++ comfy/model_base.py | 10 ++++++- comfy/model_detection.py | 5 ++++ comfy_extras/nodes_wan.py | 58 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 86d0795e9..4e2d99566 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -391,6 +391,7 @@ class WanModel(torch.nn.Module): cross_attn_norm=True, eps=1e-6, flf_pos_embed_token_number=None, + in_dim_ref_conv=None, image_model=None, device=None, dtype=None, @@ -484,6 +485,11 @@ class WanModel(torch.nn.Module): else: self.img_emb = None + if in_dim_ref_conv is not None: + self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + else: + self.ref_conv = None + def forward_orig( self, x, @@ -526,6 +532,13 @@ class WanModel(torch.nn.Module): e = e.reshape(t.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + # context context = self.text_embedding(context) @@ -552,6 +565,9 @@ class WanModel(torch.nn.Module): # head x = self.head(x, e) + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + # unpatchify x = self.unpatchify(x, grid_sizes) return x @@ -570,6 +586,9 @@ class WanModel(torch.nn.Module): x = torch.cat([x, time_dim_concat], dim=2) t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + if self.ref_conv is not None and "reference_latent" in kwargs: + t_len += 1 + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a2d9cbe6..cde61df7c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1124,7 +1124,11 @@ class WAN21(BaseModel): mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) - return torch.cat((mask, image), dim=1) + concat_mask_index = kwargs.get("concat_mask_index", 0) + if concat_mask_index != 0: + return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) + else: + return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1140,6 +1144,10 @@ class WAN21(BaseModel): if time_dim_concat is not None: out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8b57ebd2f..8acc51e20 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -373,6 +373,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + + ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) + if ref_conv_weight is not None: + dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0067d054d..f80c83ba6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -103,6 +103,63 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class Wan22FunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"ref_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + # "start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + ref_latent = None + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + class WanFirstLastFrameToVideo: @classmethod def INPUT_TYPES(s): @@ -733,6 +790,7 @@ NODE_CLASS_MAPPINGS = { "WanTrackToVideo": WanTrackToVideo, "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, + "Wan22FunControlToVideo": Wan22FunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, From 898d88e10e45f38500ca6044280bab4ca2f2d273 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Aug 2025 20:34:58 -0700 Subject: [PATCH 05/13] Make torchaudio exception catching less specific (#9309) --- comfy_api/latest/_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 61597038f..26a55615f 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -12,7 +12,7 @@ import torch try: import torchaudio TORCH_AUDIO_AVAILABLE = True -except ImportError: +except: TORCH_AUDIO_AVAILABLE = False from PIL import Image as PILImage from PIL.PngImagePlugin import PngInfo From 3294782d19c3af0c6166aafe0465fe6b59571d17 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 14:50:50 +0800 Subject: [PATCH 06/13] Update template to 0.1.59 (#9313) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 82af5690b..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.58 +comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch torchsde From 5ca8e2fac3b6826261c5991b0663b69eff60b3a1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:01:12 -0700 Subject: [PATCH 07/13] Update release workflow to python3.13 pytorch cu129 (#9315) * Try to reduce size of portable even more. * Update stable release workflow to python 3.13 cu129 * Update dependencies workflow to python3.13 cu129 --- .github/workflows/stable-release.yml | 15 ++++++++++----- .../workflows/windows_release_dependencies.yml | 6 +++--- .github/workflows/windows_release_package.yml | 2 ++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 61105abe4..a5a1ed2d0 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -12,17 +12,17 @@ on: description: 'CUDA version' required: true type: string - default: "128" + default: "129" python_minor: description: 'Python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'Python patch version' required: true type: string - default: "10" + default: "6" jobs: @@ -66,8 +66,13 @@ jobs: curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* - sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - cd .. + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + + cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index dfdb96d50..7761cc1ed 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,19 +17,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index b51746285..3334e6839 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -66,6 +66,8 @@ jobs: sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From e400f26c8fc9867248394616a4b58ecc4c53fbfd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:44:54 -0700 Subject: [PATCH 08/13] Downgrade frontend for release. (#9316) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bfb31a73f..56ed85e01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.24.4 +comfyui-frontend-package==1.23.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From d5c1954d5cd4a789bbf84d2b75a955a5a3f93de8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Aug 2025 03:46:38 -0400 Subject: [PATCH 09/13] ComfyUI version 0.3.50 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5e2d09c81..29ec07ca6 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.49" +__version__ = "0.3.50" diff --git a/pyproject.toml b/pyproject.toml index 3c530ba85..659b5730a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.49" +version = "0.3.50" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 615eb52049df98cebdd67bc672b66dc059171d7c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:48:06 -0700 Subject: [PATCH 10/13] Put back frontend version. (#9317) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 56ed85e01..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.23.4 +comfyui-frontend-package==1.24.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From afa0a45206832b0e64e38454b7841d1da7ca56e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:42:08 -0700 Subject: [PATCH 11/13] Reduce portable size again. (#9323) * compress more * test * not needed --- .github/workflows/stable-release.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index a5a1ed2d0..2bc8e5905 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -90,7 +90,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z cd ComfyUI_windows_portable diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3334e6839..46375698e 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -86,7 +86,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z cd ComfyUI_windows_portable From 3da5a07510794c37d437cbea1d94065bb0aa8ebc Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 13 Aug 2025 20:53:27 +0200 Subject: [PATCH 12/13] SDPA backend priority (#9299) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 ++-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 13 +++++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 5eb2c6548..bea6090a2 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = F.scaled_dot_product_attention(q, k, v) + out = ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..19c3c7af1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5c0373b74..79160412f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 2cc9bbc27..8b7b662b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,9 +23,18 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +from torch.nn.attention import SDPBackend, sdpa_kernel cast_to = comfy.model_management.cast_to #TODO: remove once no more references +SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, +] +if torch.cuda.is_available(): + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -249,6 +258,10 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") + @staticmethod + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): From 9df8792d4b894a8ea8034414ef63f70deee4b1af Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:12:41 -0700 Subject: [PATCH 13/13] Make last PR not crash comfy on old pytorch. (#9324) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 36 +++++++++++++-------- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index bea6090a2..6e8cbf1d9 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = ops.scaled_dot_product_attention(q, k, v) + out = comfy.ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 19c3c7af1..043df28df 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 79160412f..1fd12b35a 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 8b7b662b6..be312d714 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,18 +23,32 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib -from torch.nn.attention import SDPBackend, sdpa_kernel + + +def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + + +try: + if torch.cuda.is_available(): + from torch.nn.attention import SDPBackend, sdpa_kernel + + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) +except (ModuleNotFoundError, TypeError): + logging.warning("Could not set sdpa backend priority.") cast_to = comfy.model_management.cast_to #TODO: remove once no more references -SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, -] -if torch.cuda.is_available(): - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -258,10 +272,6 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") - @staticmethod - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear):