mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Introduced Resources to ComfyNodeV3
This commit is contained in:
parent
2999212480
commit
0e7ff98e1d
@ -4,6 +4,9 @@ from enum import Enum
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
# used for type hinting
|
# used for type hinting
|
||||||
import torch
|
import torch
|
||||||
from spandrel import ImageModelDescriptor
|
from spandrel import ImageModelDescriptor
|
||||||
@ -281,6 +284,56 @@ class NodeStateLocal(NodeState):
|
|||||||
def __delitem__(self, key: str):
|
def __delitem__(self, key: str):
|
||||||
del self.local_state[key]
|
del self.local_state[key]
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceKey(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
class ResourceKeyFolderFilename(ResourceKey):
|
||||||
|
def __init__(self, folder_name: str, file_name: str):
|
||||||
|
self.folder_name = folder_name
|
||||||
|
self.file_name = file_name
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.folder_name, self.file_name))
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, ResourceKeyFolderFilename):
|
||||||
|
return False
|
||||||
|
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.folder_name} -> {self.file_name}"
|
||||||
|
|
||||||
|
class Resources(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ResourcesLocal(Resources):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.local_resources: dict[ResourceKey, dict[str, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
def get_torch_dict(self, key: ResourceKey) -> dict[str, torch.Tensor]:
|
||||||
|
cached = self.local_resources.get(key, None)
|
||||||
|
if cached is not None:
|
||||||
|
logging.info(f"Using cached resource '{key}'")
|
||||||
|
return cached
|
||||||
|
logging.info(f"Loading resource '{key}'")
|
||||||
|
to_return = None
|
||||||
|
if isinstance(key, ResourceKeyFolderFilename):
|
||||||
|
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||||
|
|
||||||
|
if to_return is not None:
|
||||||
|
self.local_resources[key] = to_return
|
||||||
|
return to_return
|
||||||
|
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="BOOLEAN")
|
@comfytype(io_type="BOOLEAN")
|
||||||
class Boolean:
|
class Boolean:
|
||||||
Type = bool
|
Type = bool
|
||||||
@ -966,6 +1019,7 @@ class ComfyNodeV3:
|
|||||||
|
|
||||||
# filled in during execution
|
# filled in during execution
|
||||||
state: NodeState = None
|
state: NodeState = None
|
||||||
|
resources: Resources = None
|
||||||
hidden: HiddenHolder = None
|
hidden: HiddenHolder = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -995,6 +1049,7 @@ class ComfyNodeV3:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.local_state: NodeStateLocal = None
|
self.local_state: NodeStateLocal = None
|
||||||
|
self.local_resources: ResourcesLocal = None
|
||||||
self.__class__.VALIDATE_CLASS()
|
self.__class__.VALIDATE_CLASS()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -100,10 +100,6 @@ class V3TestNode(io.ComfyNodeV3):
|
|||||||
|
|
||||||
|
|
||||||
class V3LoraLoader(io.ComfyNodeV3):
|
class V3LoraLoader(io.ComfyNodeV3):
|
||||||
class State(io.NodeState):
|
|
||||||
loaded_lora: tuple[str, Any] | None = None
|
|
||||||
state: State
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def DEFINE_SCHEMA(cls):
|
def DEFINE_SCHEMA(cls):
|
||||||
return io.SchemaV3(
|
return io.SchemaV3(
|
||||||
@ -147,17 +143,7 @@ class V3LoraLoader(io.ComfyNodeV3):
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return io.NodeOutput(model, clip)
|
return io.NodeOutput(model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
lora = cls.resources.get_torch_dict(io.ResourceKeyFolderFilename("loras", lora_name))
|
||||||
lora = None
|
|
||||||
if cls.state.loaded_lora is not None:
|
|
||||||
if cls.state.loaded_lora[0] == lora_path:
|
|
||||||
lora = cls.state.loaded_lora[1]
|
|
||||||
else:
|
|
||||||
cls.state.loaded_lora = None
|
|
||||||
|
|
||||||
if lora is None:
|
|
||||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
|
||||||
cls.state.loaded_lora = (lora_path, lora)
|
|
||||||
|
|
||||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||||
return io.NodeOutput(model_lora, clip_lora)
|
return io.NodeOutput(model_lora, clip_lora)
|
||||||
|
@ -28,7 +28,7 @@ from comfy_execution.graph import (
|
|||||||
)
|
)
|
||||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal
|
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -224,6 +224,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
|||||||
if obj.local_state is None:
|
if obj.local_state is None:
|
||||||
obj.local_state = NodeStateLocal(class_clone.hidden.unique_id)
|
obj.local_state = NodeStateLocal(class_clone.hidden.unique_id)
|
||||||
class_clone.state = obj.local_state
|
class_clone.state = obj.local_state
|
||||||
|
# NOTE: this is a mock of resource management; for local, just stores ResourcesLocal on node instance
|
||||||
|
if hasattr(obj, "local_resources"):
|
||||||
|
if obj.local_resources is None:
|
||||||
|
obj.local_resources = ResourcesLocal()
|
||||||
|
class_clone.resources = obj.local_resources
|
||||||
results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
|
results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
|
||||||
# V1
|
# V1
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user