Merge pull request #8724 from comfyanonymous/v3-definition-wip

V3 definition update - Resource management + Preview helper
This commit is contained in:
Jedrzej Kosinski 2025-06-28 16:50:44 -07:00 committed by GitHub
commit aff5271291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 223 additions and 89 deletions

View File

@ -4,6 +4,7 @@ from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from collections import Counter
from comfy_api.v3.resources import Resources, ResourcesLocal
# used for type hinting
import torch
from spandrel import ImageModelDescriptor
@ -281,6 +282,7 @@ class NodeStateLocal(NodeState):
def __delitem__(self, key: str):
del self.local_state[key]
@comfytype(io_type="BOOLEAN")
class Boolean:
Type = bool
@ -966,6 +968,7 @@ class ComfyNodeV3:
# filled in during execution
state: NodeState = None
resources: Resources = None
hidden: HiddenHolder = None
@classmethod
@ -995,6 +998,7 @@ class ComfyNodeV3:
def __init__(self):
self.local_state: NodeStateLocal = None
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS()
@classmethod
@ -1207,7 +1211,7 @@ class NodeOutput:
'''
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
'''
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
def __init__(self, *args: Any, ui: _UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
self.args = args
self.ui = ui
self.expand = expand
@ -1219,21 +1223,7 @@ class NodeOutput:
# TODO: use kwargs to refer to outputs by id + organize in proper order
return self.args if len(self.args) > 0 else None
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: FolderType):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type.value
}
class UIOutput(ABC):
class _UIOutput(ABC):
def __init__(self):
pass
@ -1241,61 +1231,6 @@ class UIOutput(ABC):
def as_dict(self) -> dict:
... # TODO: finish
class UIImages(UIOutput):
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
self.values = values
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class UILatents(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"latents": values,
}
class UIAudio(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class UI3D(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class UIText(UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
def create_image_preview(image: Image.Type) -> UIImages:
# TODO: finish, right now is just Cursor's hallucination
return UIImages([SavedResult("preview.png", "comfy_org", FolderType.output)])
class TestNode(ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):

65
comfy_api/v3/resources.py Normal file
View File

@ -0,0 +1,65 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")

143
comfy_api/v3/ui.py Normal file
View File

@ -0,0 +1,143 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from comfy_api.v3.io import Image, Mask, FolderType, _UIOutput
# used for image preview
import folder_paths
import random
from PIL import Image as PILImage
import os
import numpy as np
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: FolderType):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type
}
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool=False, **kwargs):
output_dir = folder_paths.get_temp_directory()
type = "temp"
prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
compress_level = 1
filename_prefix = "ComfyUI"
filename_prefix += prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, image[0].shape[1], image[0].shape[0])
results = list()
for (batch_number, image) in enumerate(image):
i = 255. * image.cpu().numpy()
img = PILImage.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
# if not args.disable_metadata:
# metadata = PngInfo()
# if prompt is not None:
# metadata.add_text("prompt", json.dumps(prompt))
# if extra_pnginfo is not None:
# for x in extra_pnginfo:
# metadata.add_text(x, json.dumps(extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, type))
counter += 1
self.values = results
self.animated = animated
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"images": values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, **kwargs)
# class UILatent(_UIOutput):
# def __init__(self, values: list[SavedResult | dict], **kwargs):
# output_dir = folder_paths.get_temp_directory()
# type = "temp"
# prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
# compress_level = 1
# filename_prefix = "ComfyUI"
# full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# # support save metadata for latent sharing
# prompt_info = ""
# if prompt is not None:
# prompt_info = json.dumps(prompt)
# metadata = None
# if not args.disable_metadata:
# metadata = {"prompt": prompt_info}
# if extra_pnginfo is not None:
# for x in extra_pnginfo:
# metadata[x] = json.dumps(extra_pnginfo[x])
# file = f"{filename}_{counter:05}_.latent"
# results: list[FileLocator] = []
# results.append({
# "filename": file,
# "subfolder": subfolder,
# "type": "output"
# })
# file = os.path.join(full_output_folder, file)
# output = {}
# output["latent_tensor"] = samples["samples"].contiguous()
# output["latent_format_version_0"] = torch.tensor([])
# comfy.utils.save_torch_file(output, file, metadata=metadata)
# self.values = values
# def as_dict(self):
# values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
# return {
# "latents": values,
# }
class PreviewAudio(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class PreviewUI3D(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}

View File

@ -1,5 +1,5 @@
import torch
from comfy_api.v3 import io
from comfy_api.v3 import io, ui, resources
import logging
import folder_paths
import comfy.utils
@ -96,14 +96,10 @@ class V3TestNode(io.ComfyNodeV3):
if hasattr(cls, "doohickey"):
raise Exception("The 'cls' variable leaked state on class properties between runs!")
cls.doohickey = "LOLJK"
return io.NodeOutput(some_int, image)
return io.NodeOutput(some_int, image, ui=ui.PreviewImage(image))
class V3LoraLoader(io.ComfyNodeV3):
class State(io.NodeState):
loaded_lora: tuple[str, Any] | None = None
state: State
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
@ -147,17 +143,7 @@ class V3LoraLoader(io.ComfyNodeV3):
if strength_model == 0 and strength_clip == 0:
return io.NodeOutput(model, clip)
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
if cls.state.loaded_lora is not None:
if cls.state.loaded_lora[0] == lora_path:
lora = cls.state.loaded_lora[1]
else:
cls.state.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
cls.state.loaded_lora = (lora_path, lora)
lora = cls.resources.get(resources.TorchDictFolderFilename("loras", lora_name))
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return io.NodeOutput(model_lora, clip_lora)

View File

@ -28,7 +28,7 @@ from comfy_execution.graph import (
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal
class ExecutionResult(Enum):
@ -224,6 +224,11 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if obj.local_state is None:
obj.local_state = NodeStateLocal(class_clone.hidden.unique_id)
class_clone.state = obj.local_state
# NOTE: this is a mock of resource management; for local, just stores ResourcesLocal on node instance
if hasattr(obj, "local_resources"):
if obj.local_resources is None:
obj.local_resources = ResourcesLocal()
class_clone.resources = obj.local_resources
results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
# V1
else: