From 8ab15c863c91bce1f9c3a32f947cb4ec659fd7fb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 May 2025 01:52:47 -0700 Subject: [PATCH] Add --mmap-torch-files to enable use of mmap when loading ckpt/pt (#8021) --- comfy/cli_args.py | 2 ++ comfy/utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 97b348f0..de292d9b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -142,6 +142,8 @@ class PerformanceFeature(enum.Enum): parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") +parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") + parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") diff --git a/comfy/utils.py b/comfy/utils.py index a826e41b..561e1b85 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,6 +28,9 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange +from comfy.cli_args import args + +MMAP_TORCH_FILES = args.mmap_torch_files ALWAYS_SAFE_LOAD = False if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated @@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)) raise e else: + torch_args = {} + if MMAP_TORCH_FILES: + torch_args["mmap"] = True + if safe_load or ALWAYS_SAFE_LOAD: - pl_sd = torch.load(ckpt, map_location=device, weights_only=True) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: