Merge pull request #8724 from comfyanonymous/v3-definition-wip

V3 definition update - Resource management + Preview helper
This commit is contained in:
Jedrzej Kosinski 2025-06-28 16:50:44 -07:00 committed by GitHub
commit aff5271291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 223 additions and 89 deletions

View File

@ -4,6 +4,7 @@ from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from collections import Counter from collections import Counter
from comfy_api.v3.resources import Resources, ResourcesLocal
# used for type hinting # used for type hinting
import torch import torch
from spandrel import ImageModelDescriptor from spandrel import ImageModelDescriptor
@ -281,6 +282,7 @@ class NodeStateLocal(NodeState):
def __delitem__(self, key: str): def __delitem__(self, key: str):
del self.local_state[key] del self.local_state[key]
@comfytype(io_type="BOOLEAN") @comfytype(io_type="BOOLEAN")
class Boolean: class Boolean:
Type = bool Type = bool
@ -966,6 +968,7 @@ class ComfyNodeV3:
# filled in during execution # filled in during execution
state: NodeState = None state: NodeState = None
resources: Resources = None
hidden: HiddenHolder = None hidden: HiddenHolder = None
@classmethod @classmethod
@ -995,6 +998,7 @@ class ComfyNodeV3:
def __init__(self): def __init__(self):
self.local_state: NodeStateLocal = None self.local_state: NodeStateLocal = None
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS() self.__class__.VALIDATE_CLASS()
@classmethod @classmethod
@ -1207,7 +1211,7 @@ class NodeOutput:
''' '''
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg. Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
''' '''
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs): def __init__(self, *args: Any, ui: _UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
self.args = args self.args = args
self.ui = ui self.ui = ui
self.expand = expand self.expand = expand
@ -1219,21 +1223,7 @@ class NodeOutput:
# TODO: use kwargs to refer to outputs by id + organize in proper order # TODO: use kwargs to refer to outputs by id + organize in proper order
return self.args if len(self.args) > 0 else None return self.args if len(self.args) > 0 else None
class _UIOutput(ABC):
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: FolderType):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type.value
}
class UIOutput(ABC):
def __init__(self): def __init__(self):
pass pass
@ -1241,61 +1231,6 @@ class UIOutput(ABC):
def as_dict(self) -> dict: def as_dict(self) -> dict:
... # TODO: finish ... # TODO: finish
class UIImages(UIOutput):
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
self.values = values
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class UILatents(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"latents": values,
}
class UIAudio(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class UI3D(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class UIText(UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
def create_image_preview(image: Image.Type) -> UIImages:
# TODO: finish, right now is just Cursor's hallucination
return UIImages([SavedResult("preview.png", "comfy_org", FolderType.output)])
class TestNode(ComfyNodeV3): class TestNode(ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):

65
comfy_api/v3/resources.py Normal file
View File

@ -0,0 +1,65 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")

143
comfy_api/v3/ui.py Normal file
View File

@ -0,0 +1,143 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from comfy_api.v3.io import Image, Mask, FolderType, _UIOutput
# used for image preview
import folder_paths
import random
from PIL import Image as PILImage
import os
import numpy as np
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: FolderType):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type
}
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool=False, **kwargs):
output_dir = folder_paths.get_temp_directory()
type = "temp"
prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
compress_level = 1
filename_prefix = "ComfyUI"
filename_prefix += prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, image[0].shape[1], image[0].shape[0])
results = list()
for (batch_number, image) in enumerate(image):
i = 255. * image.cpu().numpy()
img = PILImage.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
# if not args.disable_metadata:
# metadata = PngInfo()
# if prompt is not None:
# metadata.add_text("prompt", json.dumps(prompt))
# if extra_pnginfo is not None:
# for x in extra_pnginfo:
# metadata.add_text(x, json.dumps(extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, type))
counter += 1
self.values = results
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, **kwargs)
# class UILatent(_UIOutput):
# def __init__(self, values: list[SavedResult | dict], **kwargs):
# output_dir = folder_paths.get_temp_directory()
# type = "temp"
# prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
# compress_level = 1
# filename_prefix = "ComfyUI"
# full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# # support save metadata for latent sharing
# prompt_info = ""
# if prompt is not None:
# prompt_info = json.dumps(prompt)
# metadata = None
# if not args.disable_metadata:
# metadata = {"prompt": prompt_info}
# if extra_pnginfo is not None:
# for x in extra_pnginfo:
# metadata[x] = json.dumps(extra_pnginfo[x])
# file = f"{filename}_{counter:05}_.latent"
# results: list[FileLocator] = []
# results.append({
# "filename": file,
# "subfolder": subfolder,
# "type": "output"
# })
# file = os.path.join(full_output_folder, file)
# output = {}
# output["latent_tensor"] = samples["samples"].contiguous()
# output["latent_format_version_0"] = torch.tensor([])
# comfy.utils.save_torch_file(output, file, metadata=metadata)
# self.values = values
# def as_dict(self):
# values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
# return {
# "latents": values,
# }
class PreviewAudio(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class PreviewUI3D(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}

View File

@ -1,5 +1,5 @@
import torch import torch
from comfy_api.v3 import io from comfy_api.v3 import io, ui, resources
import logging import logging
import folder_paths import folder_paths
import comfy.utils import comfy.utils
@ -96,14 +96,10 @@ class V3TestNode(io.ComfyNodeV3):
if hasattr(cls, "doohickey"): if hasattr(cls, "doohickey"):
raise Exception("The 'cls' variable leaked state on class properties between runs!") raise Exception("The 'cls' variable leaked state on class properties between runs!")
cls.doohickey = "LOLJK" cls.doohickey = "LOLJK"
return io.NodeOutput(some_int, image) return io.NodeOutput(some_int, image, ui=ui.PreviewImage(image))
class V3LoraLoader(io.ComfyNodeV3): class V3LoraLoader(io.ComfyNodeV3):
class State(io.NodeState):
loaded_lora: tuple[str, Any] | None = None
state: State
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return io.SchemaV3( return io.SchemaV3(
@ -147,17 +143,7 @@ class V3LoraLoader(io.ComfyNodeV3):
if strength_model == 0 and strength_clip == 0: if strength_model == 0 and strength_clip == 0:
return io.NodeOutput(model, clip) return io.NodeOutput(model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) lora = cls.resources.get(resources.TorchDictFolderFilename("loras", lora_name))
lora = None
if cls.state.loaded_lora is not None:
if cls.state.loaded_lora[0] == lora_path:
lora = cls.state.loaded_lora[1]
else:
cls.state.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
cls.state.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return io.NodeOutput(model_lora, clip_lora) return io.NodeOutput(model_lora, clip_lora)

View File

@ -28,7 +28,7 @@ from comfy_execution.graph import (
) )
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -224,6 +224,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if obj.local_state is None: if obj.local_state is None:
obj.local_state = NodeStateLocal(class_clone.hidden.unique_id) obj.local_state = NodeStateLocal(class_clone.hidden.unique_id)
class_clone.state = obj.local_state class_clone.state = obj.local_state
# NOTE: this is a mock of resource management; for local, just stores ResourcesLocal on node instance
if hasattr(obj, "local_resources"):
if obj.local_resources is None:
obj.local_resources = ResourcesLocal()
class_clone.resources = obj.local_resources
results.append(getattr(type(obj), func).__func__(class_clone, **inputs)) results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
# V1 # V1
else: else: