mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Created Add Memory to Reserve node
This commit is contained in:
@@ -24,7 +24,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -84,6 +84,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
def add_model_options_memory_reserve(model_options, memory_reserve_gb: float):
|
||||
if "model_memory_reserve" not in model_options:
|
||||
model_options["model_memory_reserve"] = []
|
||||
model_options["model_memory_reserve"].append(memory_reserve_gb)
|
||||
return model_options
|
||||
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
@@ -439,6 +445,17 @@ class ModelPatcher:
|
||||
self.force_cast_weights = True
|
||||
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
||||
|
||||
def add_model_memory_reserve(self, memory_reserve_gb: float):
|
||||
"""Adds additional expected memory usage for the model, in gigabytes."""
|
||||
self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb)
|
||||
|
||||
def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]:
|
||||
"""Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True."""
|
||||
total_reserve = sum(self.model_options.get("model_memory_reserve", []))
|
||||
if convert_to_bytes:
|
||||
return total_reserve * 1024 * 1024 * 1024
|
||||
return total_reserve
|
||||
|
||||
def add_weight_wrapper(self, name, function):
|
||||
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
|
Reference in New Issue
Block a user