Replaced io_type with direct strings instead of using node_typing.py's IO class

This commit is contained in:
Jedrzej Kosinski 2025-06-28 11:14:18 -07:00
parent 0122bc43ea
commit f4ece6731b

View File

@ -4,7 +4,6 @@ from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from collections import Counter from collections import Counter
from comfy.comfy_types.node_typing import IO
# used for type hinting # used for type hinting
import torch import torch
from spandrel import ImageModelDescriptor from spandrel import ImageModelDescriptor
@ -136,7 +135,7 @@ def comfytype(io_type: str, **kwargs):
return new_cls return new_cls
return decorator return decorator
def Custom(io_type: IO | str) -> type[ComfyType]: def Custom(io_type: str) -> type[ComfyType]:
'''Create a ComfyType for a custom io_type.''' '''Create a ComfyType for a custom io_type.'''
@comfytype(io_type=io_type) @comfytype(io_type=io_type)
class CustomComfyType(ComfyTypeIO): class CustomComfyType(ComfyTypeIO):
@ -282,7 +281,7 @@ class NodeStateLocal(NodeState):
def __delitem__(self, key: str): def __delitem__(self, key: str):
del self.local_state[key] del self.local_state[key]
@comfytype(io_type=IO.BOOLEAN) @comfytype(io_type="BOOLEAN")
class Boolean: class Boolean:
Type = bool Type = bool
@ -305,7 +304,7 @@ class Boolean:
class Output(OutputV3): class Output(OutputV3):
... ...
@comfytype(io_type=IO.INT) @comfytype(io_type="INT")
class Int: class Int:
Type = int Type = int
@ -334,8 +333,8 @@ class Int:
class Output(OutputV3): class Output(OutputV3):
... ...
@comfytype(io_type=IO.FLOAT) @comfytype(io_type="FLOAT")
class Float: class Float(ComfyTypeIO):
Type = float Type = float
class Input(WidgetInputV3): class Input(WidgetInputV3):
@ -360,11 +359,8 @@ class Float:
"display": self.display_mode, "display": self.display_mode,
}) })
class Output(OutputV3): @comfytype(io_type="STRING")
... class String(ComfyTypeIO):
@comfytype(io_type=IO.STRING)
class String:
Type = str Type = str
class Input(WidgetInputV3): class Input(WidgetInputV3):
@ -383,11 +379,8 @@ class String:
"placeholder": self.placeholder, "placeholder": self.placeholder,
}) })
class Output(OutputV3): @comfytype(io_type="COMBO")
... class Combo(ComfyType):
@comfytype(io_type=IO.COMBO)
class Combo:
Type = str Type = str
class Input(WidgetInputV3): class Input(WidgetInputV3):
'''Combo input (dropdown).''' '''Combo input (dropdown).'''
@ -416,8 +409,9 @@ class Combo:
"remote": self.remote.as_dict() if self.remote else None, "remote": self.remote.as_dict() if self.remote else None,
}) })
@comfytype(io_type=IO.COMBO)
class MultiCombo: @comfytype(io_type="COMBO")
class MultiCombo(ComfyType):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
Type = list[str] Type = list[str]
@ -439,16 +433,15 @@ class MultiCombo:
}) })
return to_return return to_return
@comfytype(io_type="IMAGE")
@comfytype(io_type=IO.IMAGE)
class Image(ComfyTypeIO): class Image(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@comfytype(io_type=IO.MASK) @comfytype(io_type="MASK")
class Mask(ComfyTypeIO): class Mask(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@comfytype(io_type=IO.LATENT) @comfytype(io_type="LATENT")
class Latent(ComfyTypeIO): class Latent(ComfyTypeIO):
'''Latents are stored as a dictionary.''' '''Latents are stored as a dictionary.'''
class LatentDict(TypedDict): class LatentDict(TypedDict):
@ -460,7 +453,7 @@ class Latent(ComfyTypeIO):
'''Only needed if dealing with these types: audio, hunyuan3dv2''' '''Only needed if dealing with these types: audio, hunyuan3dv2'''
Type = LatentDict Type = LatentDict
@comfytype(io_type=IO.CONDITIONING) @comfytype(io_type="CONDITIONING")
class Conditioning(ComfyTypeIO): class Conditioning(ComfyTypeIO):
class PooledDict(TypedDict): class PooledDict(TypedDict):
pooled_output: torch.Tensor pooled_output: torch.Tensor
@ -577,67 +570,67 @@ class Conditioning(ComfyTypeIO):
CondList = list[tuple[torch.Tensor, PooledDict]] CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList Type = CondList
@comfytype(io_type=IO.SAMPLER) @comfytype(io_type="SAMPLER")
class Sampler(ComfyTypeIO): class Sampler(ComfyTypeIO):
Type = Sampler Type = Sampler
@comfytype(io_type=IO.SIGMAS) @comfytype(io_type="SIGMAS")
class Sigmas(ComfyTypeIO): class Sigmas(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@comfytype(io_type=IO.NOISE) @comfytype(io_type="NOISE")
class Noise(ComfyTypeIO): class Noise(ComfyTypeIO):
Type = torch.Tensor Type = torch.Tensor
@comfytype(io_type=IO.GUIDER) @comfytype(io_type="GUIDER")
class Guider(ComfyTypeIO): class Guider(ComfyTypeIO):
Type = CFGGuider Type = CFGGuider
@comfytype(io_type=IO.CLIP) @comfytype(io_type="CLIP")
class Clip(ComfyTypeIO): class Clip(ComfyTypeIO):
Type = CLIP Type = CLIP
@comfytype(io_type=IO.CONTROL_NET) @comfytype(io_type="CONTROL_NET")
class ControlNet(ComfyTypeIO): class ControlNet(ComfyTypeIO):
Type = ControlNet Type = ControlNet
@comfytype(io_type=IO.VAE) @comfytype(io_type="VAE")
class Vae(ComfyTypeIO): class Vae(ComfyTypeIO):
Type = VAE Type = VAE
@comfytype(io_type=IO.MODEL) @comfytype(io_type="MODEL")
class Model(ComfyTypeIO): class Model(ComfyTypeIO):
Type = ModelPatcher Type = ModelPatcher
@comfytype(io_type=IO.CLIP_VISION) @comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO): class ClipVision(ComfyTypeIO):
Type = ClipVisionModel Type = ClipVisionModel
@comfytype(io_type=IO.CLIP_VISION_OUTPUT) @comfytype(io_type="CLIP_VISION_OUTPUT")
class ClipVisionOutput(ComfyTypeIO): class ClipVisionOutput(ComfyTypeIO):
Type = ClipVisionOutput_ Type = ClipVisionOutput_
@comfytype(io_type=IO.STYLE_MODEL) @comfytype(io_type="STYLE_MODEL")
class StyleModel(ComfyTypeIO): class StyleModel(ComfyTypeIO):
Type = StyleModel_ Type = StyleModel_
@comfytype(io_type=IO.GLIGEN) @comfytype(io_type="GLIGEN")
class Gligen(ComfyTypeIO): class Gligen(ComfyTypeIO):
'''ModelPatcher that wraps around a 'Gligen' model.''' '''ModelPatcher that wraps around a 'Gligen' model.'''
Type = ModelPatcher Type = ModelPatcher
@comfytype(io_type=IO.UPSCALE_MODEL) @comfytype(io_type="UPSCALE_MODEL")
class UpscaleModel(ComfyTypeIO): class UpscaleModel(ComfyTypeIO):
Type = ImageModelDescriptor Type = ImageModelDescriptor
@comfytype(io_type=IO.AUDIO) @comfytype(io_type="AUDIO")
class Audio(ComfyTypeIO): class Audio(ComfyTypeIO):
class AudioDict(TypedDict): class AudioDict(TypedDict):
waveform: torch.Tensor waveform: torch.Tensor
sampler_rate: int sampler_rate: int
Type = AudioDict Type = AudioDict
@comfytype(io_type=IO.VIDEO) @comfytype(io_type="VIDEO")
class Video(ComfyTypeIO): class Video(ComfyTypeIO):
Type = VideoInput Type = VideoInput
@ -645,11 +638,11 @@ class Video(ComfyTypeIO):
class SVG(ComfyTypeIO): class SVG(ComfyTypeIO):
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
@comfytype(io_type=IO.LORA_MODEL) @comfytype(io_type="LORA_MODEL")
class LoraModel(ComfyTypeIO): class LoraModel(ComfyTypeIO):
Type = dict[str, torch.Tensor] Type = dict[str, torch.Tensor]
@comfytype(io_type=IO.LOSS_MAP) @comfytype(io_type="LOSS_MAP")
class LossMap(ComfyTypeIO): class LossMap(ComfyTypeIO):
class LossMapDict(TypedDict): class LossMapDict(TypedDict):
loss: list[torch.Tensor] loss: list[torch.Tensor]
@ -696,19 +689,19 @@ class Accumulation(ComfyTypeIO):
class Load3DCamera(ComfyTypeIO): class Load3DCamera(ComfyTypeIO):
Type = Any # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type Type = Any # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
@comfytype(io_type=IO.POINT) @comfytype(io_type="POINT")
class Point(ComfyTypeIO): class Point(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
@comfytype(io_type=IO.FACE_ANALYSIS) @comfytype(io_type="FACE_ANALYSIS")
class FaceAnalysis(ComfyTypeIO): class FaceAnalysis(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
@comfytype(io_type=IO.BBOX) @comfytype(io_type="BBOX")
class BBOX(ComfyTypeIO): class BBOX(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
@comfytype(io_type=IO.SEGS) @comfytype(io_type="SEGS")
class SEGS(ComfyTypeIO): class SEGS(ComfyTypeIO):
Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist?
@ -944,7 +937,7 @@ class SchemaV3:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
class Serializer: class Serializer:
def __init_subclass__(cls, io_type: IO | str, **kwargs): def __init_subclass__(cls, io_type: str, **kwargs):
cls.io_type = io_type cls.io_type = io_type
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
@ -1324,7 +1317,7 @@ class TestNode(ComfyNodeV3):
if __name__ == "__main__": if __name__ == "__main__":
print("hello there") print("hello there")
inputs: list[InputV3] = [ inputs: list[InputV3] = [
Int.Input("tessfes", widgetType=IO.STRING), Int.Input("tessfes", widgetType=String.io_type),
Int.Input("my_int"), Int.Input("my_int"),
Custom("XYZ").Input("xyz"), Custom("XYZ").Input("xyz"),
Custom("MODEL_M").Input("model1"), Custom("MODEL_M").Input("model1"),