diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index 268f131b8..3d36793a4 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -4,6 +4,9 @@ from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict from collections import Counter +import comfy.utils +import folder_paths +import logging # used for type hinting import torch from spandrel import ImageModelDescriptor @@ -281,6 +284,56 @@ class NodeStateLocal(NodeState): def __delitem__(self, key: str): del self.local_state[key] + +class ResourceKey(ABC): + def __init__(self): + ... + +class ResourceKeyFolderFilename(ResourceKey): + def __init__(self, folder_name: str, file_name: str): + self.folder_name = folder_name + self.file_name = file_name + + def __hash__(self): + return hash((self.folder_name, self.file_name)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResourceKeyFolderFilename): + return False + return self.folder_name == other.folder_name and self.file_name == other.file_name + + def __str__(self): + return f"{self.folder_name} -> {self.file_name}" + +class Resources(ABC): + def __init__(self): + ... + + @abstractmethod + def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]: + pass + +class ResourcesLocal(Resources): + def __init__(self): + super().__init__() + self.local_resources: dict[ResourceKey, dict[str, torch.Tensor]] = {} + + def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]: + cached = self.local_resources.get(key, None) + if cached is not None: + logging.info(f"Using cached resource '{key}'") + return cached + logging.info(f"Loading resource '{key}'") + to_return = None + if isinstance(key, ResourceKeyFolderFilename): + to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) + + if to_return is not None: + self.local_resources[key] = to_return + return to_return + raise Exception(f"Unsupported resource key type: {type(key)}") + + @comfytype(io_type="BOOLEAN") class Boolean: Type = bool @@ -966,6 +1019,7 @@ class ComfyNodeV3: # filled in during execution state: NodeState = None + resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -995,6 +1049,7 @@ class ComfyNodeV3: def __init__(self): self.local_state: NodeStateLocal = None + self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 2f8a062b5..e84fdaa87 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -100,10 +100,6 @@ class V3TestNode(io.ComfyNodeV3): class V3LoraLoader(io.ComfyNodeV3): - class State(io.NodeState): - loaded_lora: tuple[str, Any] | None = None - state: State - @classmethod def DEFINE_SCHEMA(cls): return io.SchemaV3( @@ -147,17 +143,7 @@ class V3LoraLoader(io.ComfyNodeV3): if strength_model == 0 and strength_clip == 0: return io.NodeOutput(model, clip) - lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) - lora = None - if cls.state.loaded_lora is not None: - if cls.state.loaded_lora[0] == lora_path: - lora = cls.state.loaded_lora[1] - else: - cls.state.loaded_lora = None - - if lora is None: - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - cls.state.loaded_lora = (lora_path, lora) + lora = cls.resources.get_torch_dict(io.ResourceKeyFolderFilename("loras", lora_name)) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return io.NodeOutput(model_lora, clip_lora) diff --git a/execution.py b/execution.py index 8d5f708e3..5ab7f1fe0 100644 --- a/execution.py +++ b/execution.py @@ -28,7 +28,7 @@ from comfy_execution.graph import ( ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input -from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal +from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal class ExecutionResult(Enum): @@ -224,6 +224,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut if obj.local_state is None: obj.local_state = NodeStateLocal(class_clone.hidden.unique_id) class_clone.state = obj.local_state + # NOTE: this is a mock of resource management; for local, just stores ResourcesLocal on node instance + if hasattr(obj, "local_resources"): + if obj.local_resources is None: + obj.local_resources = ResourcesLocal() + class_clone.resources = obj.local_resources results.append(getattr(type(obj), func).__func__(class_clone, **inputs)) # V1 else: