Created Add Memory to Reserve node

This commit is contained in:
Jedrzej Kosinski
2025-08-18 14:45:21 -07:00
parent bd2ab73976
commit 34b1f51f4a
4 changed files with 60 additions and 3 deletions

View File

@@ -582,16 +582,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def get_models_memory_reserve(models):
total_reserve = 0
for model in models:
total_reserve += model.get_model_memory_reserve(convert_to_bytes=True)
return total_reserve
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
models_memory_reserve = get_models_memory_reserve(models)
extra_mem = max(inference_memory, memory_required + extra_reserved_memory() + models_memory_reserve)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory() + models_memory_reserve)
models = set(models)

View File

@@ -24,7 +24,7 @@ import inspect
import logging
import math
import uuid
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
@@ -84,6 +84,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True
return model_options
def add_model_options_memory_reserve(model_options, memory_reserve_gb: float):
if "model_memory_reserve" not in model_options:
model_options["model_memory_reserve"] = []
model_options["model_memory_reserve"].append(memory_reserve_gb)
return model_options
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
@@ -439,6 +445,17 @@ class ModelPatcher:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
def add_model_memory_reserve(self, memory_reserve_gb: float):
"""Adds additional expected memory usage for the model, in gigabytes."""
self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb)
def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]:
"""Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True."""
total_reserve = sum(self.model_options.get("model_memory_reserve", []))
if convert_to_bytes:
return total_reserve * 1024 * 1024 * 1024
return total_reserve
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
self.patches_uuid = uuid.uuid4()

View File

@@ -0,0 +1,32 @@
from comfy_api.latest import io, ComfyExtension
class MemoryReserveNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AddMemoryToReserve",
display_name="Add Memory to Reserve",
description="Adds additional expected memory usage for the model, in gigabytes.",
category="advanced/debug/model",
inputs=[
io.Model.Input("model", tooltip="The model to add memory reserve to."),
io.Float.Input("memory_reserve_gb", min=0.0, default=0.0, max=2048.0, step=0.1, tooltip="The additional expected memory usage for the model, in gigabytes."),
],
outputs=[
io.Model.Output(tooltip="The model with the additional memory reserve."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, memory_reserve_gb: float) -> io.NodeOutput:
model.add_model_memory_reserve(memory_reserve_gb)
return io.NodeOutput(model)
class MemoryReserveExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
MemoryReserveNode,
]
def comfy_entrypoint():
return MemoryReserveExtension()

View File

@@ -2321,6 +2321,7 @@ async def init_builtin_extra_nodes():
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_memory_reserve.py",
]
import_failed = []