diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index 3d36793a4..4cc758562 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -4,9 +4,7 @@ from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict from collections import Counter -import comfy.utils -import folder_paths -import logging +from comfy_api.v3.resources import Resources, ResourcesLocal # used for type hinting import torch from spandrel import ImageModelDescriptor @@ -285,55 +283,6 @@ class NodeStateLocal(NodeState): del self.local_state[key] -class ResourceKey(ABC): - def __init__(self): - ... - -class ResourceKeyFolderFilename(ResourceKey): - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ResourceKeyFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, dict[str, torch.Tensor]] = {} - - def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, ResourceKeyFolderFilename): - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - raise Exception(f"Unsupported resource key type: {type(key)}") - - @comfytype(io_type="BOOLEAN") class Boolean: Type = bool diff --git a/comfy_api/v3/resources.py b/comfy_api/v3/resources.py new file mode 100644 index 000000000..6ff59d6ae --- /dev/null +++ b/comfy_api/v3/resources.py @@ -0,0 +1,65 @@ +from __future__ import annotations +import comfy.utils +import folder_paths +import logging +from abc import ABC, abstractmethod +from typing import Any +import torch + +class ResourceKey(ABC): + Type = Any + def __init__(self): + ... + +class TorchDictFolderFilename(ResourceKey): + '''Key for requesting a torch file via file_name from a folder category.''' + Type = dict[str, torch.Tensor] + def __init__(self, folder_name: str, file_name: str): + self.folder_name = folder_name + self.file_name = file_name + + def __hash__(self): + return hash((self.folder_name, self.file_name)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TorchDictFolderFilename): + return False + return self.folder_name == other.folder_name and self.file_name == other.file_name + + def __str__(self): + return f"{self.folder_name} -> {self.file_name}" + +class Resources(ABC): + def __init__(self): + ... + + @abstractmethod + def get(self, key: ResourceKey, default: Any=...) -> Any: + pass + +class ResourcesLocal(Resources): + def __init__(self): + super().__init__() + self.local_resources: dict[ResourceKey, Any] = {} + + def get(self, key: ResourceKey, default: Any=...) -> Any: + cached = self.local_resources.get(key, None) + if cached is not None: + logging.info(f"Using cached resource '{key}'") + return cached + logging.info(f"Loading resource '{key}'") + to_return = None + if isinstance(key, TorchDictFolderFilename): + if default is ...: + to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) + else: + full_path = folder_paths.get_full_path(key.folder_name, key.file_name) + if full_path is not None: + to_return = comfy.utils.load_torch_file(full_path, safe_load=True) + + if to_return is not None: + self.local_resources[key] = to_return + return to_return + if default is not ...: + return default + raise Exception(f"Unsupported resource key type: {type(key)}") diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index e84fdaa87..9120d8b8c 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -1,5 +1,5 @@ import torch -from comfy_api.v3 import io, ui +from comfy_api.v3 import io, ui, resources import logging import folder_paths import comfy.utils @@ -143,7 +143,7 @@ class V3LoraLoader(io.ComfyNodeV3): if strength_model == 0 and strength_clip == 0: return io.NodeOutput(model, clip) - lora = cls.resources.get_torch_dict(io.ResourceKeyFolderFilename("loras", lora_name)) + lora = cls.resources.get(resources.TorchDictFolderFilename("loras", lora_name)) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return io.NodeOutput(model_lora, clip_lora)