Merge pull request #8718 from comfyanonymous/v3-definition-wip

V3 definition update - fix v3 node schema parsing, add missing Types
Jedrzej Kosinski 2025-06-28 11:45:14 -07:00 committed by GitHub
commit 1ad8a72fe9
3 changed files with 261 additions and 77 deletions


@@ -1,13 +1,23 @@
from __future__ import annotations
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, override
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict, NotRequired
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from collections import Counter
from comfy.comfy_types.node_typing import IO
# used for type hinting
import torch
from spandrel import ImageModelDescriptor
from comfy.model_patcher import ModelPatcher
from comfy.samplers import Sampler, CFGGuider
from comfy.sd import CLIP
from comfy.controlnet import ControlNet
from comfy.sd import VAE
from comfy.sd import StyleModel as StyleModel_
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy_api.input import VideoInput
from comfy.hooks import HookGroup, HookKeyframeGroup
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
class FolderType(str, Enum):
@@ -125,7 +135,7 @@ def comfytype(io_type: str, **kwargs):
return new_cls
return decorator
def Custom(io_type: IO | str) -> type[ComfyType]:
def Custom(io_type: str) -> type[ComfyType]:
'''Create a ComfyType for a custom io_type.'''
@comfytype(io_type=io_type)
class CustomComfyType(ComfyTypeIO):
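The narrowed signature reflects that v3 declarations now use plain strings rather than the `IO` enum. A minimal sketch of how `Custom` gets used, mirroring the test node further down (editor's illustration, not part of the diff; assumes `from comfy_api.v3 import io` as in the test file):

from comfy_api.v3 import io

XYZ = io.Custom("XYZ")                       # builds a ComfyType keyed by the "XYZ" io_type
xyz_input = XYZ.Input("xyz", optional=True)  # same Input/Output API as the built-in types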
@@ -271,7 +281,7 @@ class NodeStateLocal(NodeState):
def __delitem__(self, key: str):
del self.local_state[key]
@comfytype(io_type=IO.BOOLEAN)
@comfytype(io_type="BOOLEAN")
class Boolean:
Type = bool
@@ -294,7 +304,7 @@ class Boolean:
class Output(OutputV3):
...
@comfytype(io_type=IO.INT)
@comfytype(io_type="INT")
class Int:
Type = int
@@ -323,8 +333,8 @@ class Int:
class Output(OutputV3):
...
@comfytype(io_type=IO.FLOAT)
class Float:
@comfytype(io_type="FLOAT")
class Float(ComfyTypeIO):
Type = float
class Input(WidgetInputV3):
@@ -349,11 +359,8 @@ class Float:
"display": self.display_mode,
})
class Output(OutputV3):
...
@comfytype(io_type=IO.STRING)
class String:
@comfytype(io_type="STRING")
class String(ComfyTypeIO):
Type = str
class Input(WidgetInputV3):
@@ -372,11 +379,8 @@ class String:
"placeholder": self.placeholder,
})
class Output(OutputV3):
...
@comfytype(io_type=IO.COMBO)
class Combo:
@comfytype(io_type="COMBO")
class Combo(ComfyType):
Type = str
class Input(WidgetInputV3):
'''Combo input (dropdown).'''
@@ -405,8 +409,9 @@ class Combo:
"remote": self.remote.as_dict() if self.remote else None,
})
@comfytype(io_type=IO.COMBO)
class MultiCombo:
@comfytype(io_type="COMBO")
class MultiCombo(ComfyType):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
Type = list[str]
@@ -428,98 +433,278 @@ class MultiCombo:
})
return to_return
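Both classes share the V1 "COMBO" io_type; `MultiCombo` differs by serializing multi-select metadata (see the TODO above about the frontend not yet recognizing it) and by typing its value as `list[str]`. A hedged declaration sketch (editor's illustration; the `options` parameter matches the test node's `Combo` usage below, and `MultiCombo` is assumed to accept the same):

from comfy_api.v3 import io

single = io.Combo.Input("combo", options=["a", "b", "c"])        # value arrives as str
multi = io.MultiCombo.Input("letters", options=["a", "b", "c"])  # value arrives as list[str]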
@comfytype(io_type=IO.IMAGE)
@comfytype(io_type="IMAGE")
class Image(ComfyTypeIO):
Type = torch.Tensor
@comfytype(io_type=IO.MASK)
@comfytype(io_type="MASK")
class Mask(ComfyTypeIO):
Type = torch.Tensor
@comfytype(io_type=IO.LATENT)
@comfytype(io_type="LATENT")
class Latent(ComfyTypeIO):
Type = Any # TODO: make Type a TypedDict
'''Latents are stored as a dictionary.'''
class LatentDict(TypedDict):
samples: torch.Tensor
'''Latent tensors.'''
noise_mask: NotRequired[torch.Tensor]
batch_index: NotRequired[list[int]]
type: NotRequired[str]
'''Only needed if dealing with these types: audio, hunyuan3dv2'''
Type = LatentDict
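`LatentDict` formalizes the dictionary shape that v1 code has always passed around untyped: only `samples` is mandatory. A construction sketch (editor's illustration; the 1x4x64x64 shape is just a typical SD latent, not a requirement of the type):

import torch
from comfy_api.v3 import io

latent: io.Latent.Type = {"samples": torch.zeros(1, 4, 64, 64)}  # the one required key
latent["noise_mask"] = torch.ones(1, 1, 8, 8)                    # NotRequired keys may be added freely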
@comfytype(io_type=IO.CONDITIONING)
@comfytype(io_type="CONDITIONING")
class Conditioning(ComfyTypeIO):
Type = Any
class PooledDict(TypedDict):
pooled_output: torch.Tensor
'''Pooled output from CLIP.'''
control: NotRequired[ControlNet]
'''ControlNet to apply to conditioning.'''
control_apply_to_uncond: NotRequired[bool]
'''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.'''
cross_attn_controlnet: NotRequired[torch.Tensor]
'''CrossAttn from CLIP to use for controlnet only.'''
pooled_output_controlnet: NotRequired[torch.Tensor]
'''Pooled output from CLIP to use for controlnet only.'''
gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]]
'''GLIGEN to apply to conditioning.'''
area: NotRequired[tuple[int, ...] | tuple[str, float, ...]]
'''Sets the area of conditioning. The first half of the values applies to dimensions, the second half to coordinates.
By default, the dimensions are based on total pixel count, but the first value can be set to "percentage" to use a percentage of the image size instead.
(1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels.
("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.''' # TODO: verify it's actually top-left
strength: NotRequired[float]
'''Strength of conditioning. Default strength is 1.0.'''
mask: NotRequired[torch.Tensor]
'''Mask to apply conditioning to.'''
mask_strength: NotRequired[float]
'''Strength of conditioning mask. Default strength is 1.0.'''
set_area_to_bounds: NotRequired[bool]
'''Whether the conditioning mask should determine the bounds of the area; if set to false, latents are sampled at full resolution and the result is applied within the mask.'''
concat_latent_image: NotRequired[torch.Tensor]
'''Used for inpainting and specific models.'''
concat_mask: NotRequired[torch.Tensor]
'''Used for inpainting and specific models.'''
concat_image: NotRequired[torch.Tensor]
'''Used by SD_4XUpscale_Conditioning.'''
noise_augmentation: NotRequired[float]
'''Used by SD_4XUpscale_Conditioning.'''
hooks: NotRequired[HookGroup]
'''Applies hooks to conditioning.'''
default: NotRequired[bool]
'''Whether this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.'''
start_percent: NotRequired[float]
'''Determines relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.'''
end_percent: NotRequired[float]
'''Determines relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.'''
clip_start_percent: NotRequired[float]
'''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.'''
clip_end_percent: NotRequired[float]
'''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.'''
attention_mask: NotRequired[torch.Tensor]
'''Masks text conditioning; used by StyleModel among others.'''
attention_mask_img_shape: NotRequired[tuple[int, ...]]
'''Masks text conditioning; used by StyleModel among others.'''
unclip_conditioning: NotRequired[list[dict]]
'''Used by unCLIP.'''
conditioning_lyrics: NotRequired[torch.Tensor]
'''Used by AceT5Model.'''
seconds_start: NotRequired[float]
'''Used by StableAudio.'''
seconds_total: NotRequired[float]
'''Used by StableAudio.'''
lyrics_strength: NotRequired[float]
'''Used by AceStepAudio.'''
width: NotRequired[int]
'''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
height: NotRequired[int]
'''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
aesthetic_score: NotRequired[float]
'''Used by CLIPTextEncodeSDXL/Refiner.'''
crop_w: NotRequired[int]
'''Used by CLIPTextEncodeSDXL.'''
crop_h: NotRequired[int]
'''Used by CLIPTextEncodeSDXL.'''
target_width: NotRequired[int]
'''Used by CLIPTextEncodeSDXL.'''
target_height: NotRequired[int]
'''Used by CLIPTextEncodeSDXL.'''
reference_latents: NotRequired[list[torch.Tensor]]
'''Used by ReferenceLatent.'''
guidance: NotRequired[float]
'''Used by Flux-like models with guidance embed.'''
guiding_frame_index: NotRequired[int]
'''Used by Hunyuan ImageToVideo.'''
ref_latent: NotRequired[torch.Tensor]
'''Used by Hunyuan ImageToVideo.'''
keyframe_idxs: NotRequired[list[int]]
'''Used by LTXV.'''
frame_rate: NotRequired[float]
'''Used by LTXV.'''
stable_cascade_prior: NotRequired[torch.Tensor]
'''Used by StableCascade.'''
elevation: NotRequired[list[float]]
'''Used by SV3D.'''
azimuth: NotRequired[list[float]]
'''Used by SV3D.'''
motion_bucket_id: NotRequired[int]
'''Used by SVD-like models.'''
fps: NotRequired[int]
'''Used by SVD-like models.'''
augmentation_level: NotRequired[float]
'''Used by SVD-like models.'''
clip_vision_output: NotRequired[ClipVisionOutput_]
'''Used by WAN-like models.'''
vace_frames: NotRequired[torch.Tensor]
'''Used by WAN VACE.'''
vace_mask: NotRequired[torch.Tensor]
'''Used by WAN VACE.'''
vace_strength: NotRequired[float]
'''Used by WAN VACE.'''
camera_conditions: NotRequired[Any] # TODO: assign proper type once defined
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
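`CondList` documents the runtime shape of conditioning: a list of `(cross_attention_tensor, options)` pairs in which `pooled_output` is the only required option. A sketch with dummy tensors (editor's illustration; the shapes are illustrative, not prescribed by the type):

import torch
from comfy_api.v3 import io

cond: io.Conditioning.Type = [(
    torch.zeros(1, 77, 768),                # cross-attention conditioning
    {"pooled_output": torch.zeros(1, 768),  # required by PooledDict
     "strength": 0.8},                      # any NotRequired key may be added
)]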
@comfytype(io_type="SAMPLER")
class Sampler(ComfyTypeIO):
Type = Any
Type = Sampler
@comfytype(io_type=IO.SIGMAS)
@comfytype(io_type="SIGMAS")
class Sigmas(ComfyTypeIO):
Type = Any
Type = torch.Tensor
@comfytype(io_type=IO.NOISE)
@comfytype(io_type="NOISE")
class Noise(ComfyTypeIO):
Type = Any
Type = torch.Tensor
@comfytype(io_type=IO.GUIDER)
@comfytype(io_type="GUIDER")
class Guider(ComfyTypeIO):
Type = Any
Type = CFGGuider
@comfytype(io_type=IO.CLIP)
@comfytype(io_type="CLIP")
class Clip(ComfyTypeIO):
Type = Any
Type = CLIP
@comfytype(io_type=IO.CONTROL_NET)
@comfytype(io_type="CONTROL_NET")
class ControlNet(ComfyTypeIO):
Type = Any
Type = ControlNet
@comfytype(io_type=IO.VAE)
@comfytype(io_type="VAE")
class Vae(ComfyTypeIO):
Type = Any
Type = VAE
@comfytype(io_type=IO.MODEL)
@comfytype(io_type="MODEL")
class Model(ComfyTypeIO):
Type = ModelPatcher
@comfytype(io_type=IO.CLIP_VISION)
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
Type = Any
Type = ClipVisionModel
@comfytype(io_type=IO.CLIP_VISION_OUTPUT)
@comfytype(io_type="CLIP_VISION_OUTPUT")
class ClipVisionOutput(ComfyTypeIO):
Type = Any
Type = ClipVisionOutput_
@comfytype(io_type=IO.STYLE_MODEL)
@comfytype(io_type="STYLE_MODEL")
class StyleModel(ComfyTypeIO):
Type = Any
Type = StyleModel_
@comfytype(io_type=IO.GLIGEN)
@comfytype(io_type="GLIGEN")
class Gligen(ComfyTypeIO):
Type = Any
'''ModelPatcher that wraps around a 'Gligen' model.'''
Type = ModelPatcher
@comfytype(io_type=IO.UPSCALE_MODEL)
@comfytype(io_type="UPSCALE_MODEL")
class UpscaleModel(ComfyTypeIO):
Type = Any
Type = ImageModelDescriptor
@comfytype(io_type=IO.AUDIO)
@comfytype(io_type="AUDIO")
class Audio(ComfyTypeIO):
Type = Any
class AudioDict(TypedDict):
waveform: torch.Tensor
sample_rate: int
Type = AudioDict
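A sketch of the audio payload (editor's illustration; the (batch, channels, samples) layout and 44.1 kHz rate are the usual ComfyUI convention, not something the TypedDict itself enforces):

import torch
from comfy_api.v3 import io

audio: io.Audio.Type = {
    "waveform": torch.zeros(1, 2, 44100),  # assumed (batch, channels, samples)
    "sample_rate": 44100,
}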
@comfytype(io_type=IO.POINT)
class Point(ComfyTypeIO):
Type = Any
@comfytype(io_type=IO.FACE_ANALYSIS)
class FaceAnalysis(ComfyTypeIO):
Type = Any
@comfytype(io_type=IO.BBOX)
class BBOX(ComfyTypeIO):
Type = Any
@comfytype(io_type=IO.SEGS)
class SEGS(ComfyTypeIO):
Type = Any
@comfytype(io_type=IO.VIDEO)
@comfytype(io_type="VIDEO")
class Video(ComfyTypeIO):
Type = Any
Type = VideoInput
@comfytype(io_type="SVG")
class SVG(ComfyTypeIO):
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type="LORA_MODEL")
class LoraModel(ComfyTypeIO):
Type = dict[str, torch.Tensor]
@comfytype(io_type="LOSS_MAP")
class LossMap(ComfyTypeIO):
class LossMapDict(TypedDict):
loss: list[torch.Tensor]
Type = LossMapDict
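And the corresponding sketch for a loss map (editor's illustration):

import torch
from comfy_api.v3 import io

losses: io.LossMap.Type = {"loss": [torch.tensor(0.25), torch.tensor(0.18)]}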
@comfytype(io_type="VOXEL")
class Voxel(ComfyTypeIO):
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type="MESH")
class Mesh(ComfyTypeIO):
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
Type = HookGroup
@comfytype(io_type="HOOK_KEYFRAMES")
class HookKeyframes(ComfyTypeIO):
Type = HookKeyframeGroup
@comfytype(io_type="TIMESTEPS_RANGE")
class TimestepsRange(ComfyTypeIO):
'''Range defined by start and end points, between 0.0 and 1.0.'''
Type = tuple[int, int]
@comfytype(io_type="LATENT_OPERATION")
class LatentOperation(ComfyTypeIO):
Type = Callable[[torch.Tensor], torch.Tensor]
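Because `LatentOperation` is typed as a plain tensor-to-tensor callable, any matching function or lambda qualifies (editor's sketch):

from comfy_api.v3 import io

scale_op: io.LatentOperation.Type = lambda latents: latents * 1.1  # trivially rescales latents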
@comfytype(io_type="FLOW_CONTROL")
class FlowControl(ComfyTypeIO):
# NOTE: only used in testing_nodes right now
Type = tuple[str, Any]
@comfytype(io_type="ACCUMULATION")
class Accumulation(ComfyTypeIO):
# NOTE: only used in testing_nodes right now
class AccumulationDict(TypedDict):
accum: list[Any]
Type = AccumulationDict
@comfytype(io_type="LOAD3D_CAMERA")
class Load3DCamera(ComfyTypeIO):
Type = Any # TODO: figure out type for this; in code, only described as image['camera_info'], obtained from a LOAD_3D or LOAD_3D_ANIMATION type
@comfytype(io_type="POINT")
class Point(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
@comfytype(io_type="FACE_ANALYSIS")
class FaceAnalysis(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to FACE_ANALYSIS io_type. Does this exist?
@comfytype(io_type="BBOX")
class BBOX(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to BBOX io_type. Does this exist?
@comfytype(io_type="SEGS")
class SEGS(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to SEGS io_type. Does this exist?
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
@@ -528,7 +713,7 @@ class MultiType:
'''
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
'''
def __init__(self, id: str | ComfyType.Input, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
def __init__(self, id: str | InputV3, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
# if id is an Input, then use that Input with overridden values
self.input_override = None
if isinstance(id, InputV3):
@@ -561,7 +746,6 @@ class MultiType:
str_types.insert(0, self.input_override.get_io_type_V1())
return ",".join(list(dict.fromkeys(str_types)))
@override
def as_dict_V1(self):
if self.input_override is not None:
return self.input_override.as_dict_V1() | super().as_dict_V1()
@@ -754,7 +938,7 @@ class SchemaV3:
raise ValueError("\n".join(issues))
class Serializer:
def __init_subclass__(cls, io_type: IO | str, **kwargs):
def __init_subclass__(cls, io_type: str, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
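As with `Custom`, serializer subclasses now register a plain-string io_type at class-creation time via `__init_subclass__`. A sketch (editor's illustration; `JsonSerializer` and `"JSON_DATA"` are hypothetical names):

from comfy_api.v3.io import Serializer

class JsonSerializer(Serializer, io_type="JSON_DATA"):
    pass  # __init_subclass__ has set cls.io_type == "JSON_DATA"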
@@ -1134,7 +1318,7 @@ class TestNode(ComfyNodeV3):
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
Int.Input("tessfes", widgetType=IO.STRING),
Int.Input("tessfes", widgetType=String.io_type),
Int.Input("my_int"),
Custom("XYZ").Input("xyz"),
Custom("MODEL_M").Input("model1"),


@@ -36,7 +36,7 @@ class V3TestNode(io.ComfyNodeV3):
io.Image.Input("image", display_name="new_image"),
XYZ.Input("xyz", optional=True),
io.Custom("JKL").Input("jkl", optional=True),
io.Mask.Input("mask", optional=True),
io.Mask.Input("mask", display_name="mask haha", optional=True),
io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42,
tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider),
io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"),


@@ -555,7 +555,7 @@ class PromptServer():
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
if isinstance(obj_class, ComfyNodeV3):
if issubclass(obj_class, ComfyNodeV3):
return obj_class.GET_NODE_INFO_V1()
info = {}
info['input'] = obj_class.INPUT_TYPES()
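This small change fixes a real bug: `NODE_CLASS_MAPPINGS` holds node classes, not instances, so the old `isinstance` check was always False and V3 nodes fell through to the generic V1 info path below. The distinction in miniature (editor's illustration):

class Base: ...
class Node(Base): ...

isinstance(Node, Base)   # False - Node is itself a class object, not an instance of Base
issubclass(Node, Base)   # True  - the relationship node_info actually needs to test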