Added initial schema validation

This commit is contained in:
Jedrzej Kosinski 2025-06-19 04:54:49 -05:00
parent aac91caf1a
commit b52154f382

View File

@ -3,6 +3,7 @@ from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cas
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from collections import Counter
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
# used for type hinting # used for type hinting
import torch import torch
@ -699,6 +700,26 @@ class SchemaV3:
not_idempotent: bool=False not_idempotent: bool=False
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
def validate(self):
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
issues = []
# verify ids are unique per list
if len(input_set) != len(input_ids):
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
if len(output_set) != len(output_ids):
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
# verify ids are unique between lists
intersection = input_set & output_set
if len(intersection) > 0:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
class Serializer: class Serializer:
def __init_subclass__(cls, io_type: IO | str, **kwargs): def __init_subclass__(cls, io_type: IO | str, **kwargs):
@ -889,6 +910,7 @@ class ComfyNodeV3:
def GET_SCHEMA(cls) -> SchemaV3: def GET_SCHEMA(cls) -> SchemaV3:
cls.VALIDATE_CLASS() cls.VALIDATE_CLASS()
schema = cls.DEFINE_SCHEMA() schema = cls.DEFINE_SCHEMA()
schema.validate()
if cls._DESCRIPTION is None: if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description cls._DESCRIPTION = schema.description
if cls._CATEGORY is None: if cls._CATEGORY is None: