diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py
index b0a4931c7..2dd9da8ee 100644
--- a/comfy_api/v3/io.py
+++ b/comfy_api/v3/io.py
@@ -1,13 +1,21 @@
 from __future__ import annotations
-from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, override
+from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict
+from typing_extensions import NotRequired  # NotRequired lives in typing only on Python >= 3.11
 from enum import Enum
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, asdict
 from collections import Counter
-from comfy.comfy_types.node_typing import IO  # used for type hinting
 import torch
+from spandrel import ImageModelDescriptor
 from comfy.model_patcher import ModelPatcher
+from comfy.samplers import Sampler, CFGGuider
+from comfy.sd import CLIP, VAE, StyleModel as StyleModel_
+from comfy.controlnet import ControlNet
+from comfy.clip_vision import ClipVisionModel, Output as ClipVisionOutput_
+from comfy_api.input import VideoInput
+from comfy.hooks import HookGroup, HookKeyframeGroup
+# from comfy_extras.nodes_images import SVG as SVG_  # NOTE: needs to be moved before it can be imported due to a circular reference
 
 
 class FolderType(str, Enum):
@@ -125,7 +135,7 @@ def comfytype(io_type: str, **kwargs):
         return new_cls
     return decorator
 
-def Custom(io_type: IO | str) -> type[ComfyType]:
+def Custom(io_type: str) -> type[ComfyType]:
     '''Create a ComfyType for a custom io_type.'''
     @comfytype(io_type=io_type)
     class CustomComfyType(ComfyTypeIO):
@@ -271,7 +281,7 @@ class NodeStateLocal(NodeState):
     def __delitem__(self, key: str):
         del self.local_state[key]
 
-@comfytype(io_type=IO.BOOLEAN)
+@comfytype(io_type="BOOLEAN")
 class Boolean:
     Type = bool
 
@@ -294,7 +304,7 @@ class Boolean:
     class Output(OutputV3):
         ...
 
-@comfytype(io_type=IO.INT)
+@comfytype(io_type="INT")
 class Int:
     Type = int
 
@@ -323,8 +333,8 @@ class Int:
     class Output(OutputV3):
         ...
 
-@comfytype(io_type=IO.FLOAT)
-class Float:
+@comfytype(io_type="FLOAT")
+class Float(ComfyTypeIO):
     Type = float
 
     class Input(WidgetInputV3):
@@ -349,11 +359,8 @@ class Float:
                 "display": self.display_mode,
             })
 
-    class Output(OutputV3):
-        ...
-
-@comfytype(io_type=IO.STRING)
-class String:
+@comfytype(io_type="STRING")
+class String(ComfyTypeIO):
     Type = str
 
     class Input(WidgetInputV3):
@@ -372,11 +379,8 @@ class String:
                 "placeholder": self.placeholder,
             })
 
-    class Output(OutputV3):
-        ...
-
-@comfytype(io_type=IO.COMBO)
-class Combo:
+@comfytype(io_type="COMBO")
+class Combo(ComfyType):
     Type = str
     class Input(WidgetInputV3):
         '''Combo input (dropdown).'''
@@ -405,8 +409,9 @@ class Combo:
             "remote": self.remote.as_dict() if self.remote else None,
         })
 
-@comfytype(io_type=IO.COMBO)
-class MultiCombo:
+
+@comfytype(io_type="COMBO")
+class MultiCombo(ComfyType):
     '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
     # TODO: something is wrong with the serialization, the frontend does not recognize it as multiselect
     Type = list[str]
@@ -428,98 +433,278 @@ class MultiCombo:
         })
         return to_return
 
-
-@comfytype(io_type=IO.IMAGE)
+@comfytype(io_type="IMAGE")
 class Image(ComfyTypeIO):
     Type = torch.Tensor
 
-@comfytype(io_type=IO.MASK)
+@comfytype(io_type="MASK")
 class Mask(ComfyTypeIO):
     Type = torch.Tensor
 
-@comfytype(io_type=IO.LATENT)
+@comfytype(io_type="LATENT")
 class Latent(ComfyTypeIO):
-    Type = Any  # TODO: make Type a TypedDict
+    '''Latents are stored as a dictionary.'''
+    class LatentDict(TypedDict):
+        samples: torch.Tensor
+        '''Latent tensors.'''
+        noise_mask: NotRequired[torch.Tensor]
+        batch_index: NotRequired[list[int]]
+        type: NotRequired[str]
+        '''Only needed if dealing with these types: audio, hunyuan3dv2'''
+    Type = LatentDict
 
-@comfytype(io_type=IO.CONDITIONING)
+@comfytype(io_type="CONDITIONING")
 class Conditioning(ComfyTypeIO):
-    Type = Any
+    class PooledDict(TypedDict):
+        pooled_output: torch.Tensor
+        '''Pooled output from CLIP.'''
+        control: NotRequired[ControlNet]
+        '''ControlNet to apply to conditioning.'''
+        control_apply_to_uncond: NotRequired[bool]
+        '''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.'''
+        cross_attn_controlnet: NotRequired[torch.Tensor]
+        '''CrossAttn from CLIP to use for controlnet only.'''
+        pooled_output_controlnet: NotRequired[torch.Tensor]
+        '''Pooled output from CLIP to use for controlnet only.'''
+        gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]]
+        '''GLIGEN to apply to conditioning.'''
+        area: NotRequired[tuple[int, ...] | tuple[str, float, ...]]
+        '''Sets the area of conditioning. The first half of the values applies to dimensions, the second half to coordinates.
+        By default, the dimensions are based on total pixel amount, but the first value can be set to "percentage" to use a percentage of the image size instead.
-@comfytype(io_type=IO.SAMPLER)
+        (1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels.
+
+        ("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.'''  # TODO: verify it's actually top-left
+        strength: NotRequired[float]
+        '''Strength of conditioning. Default strength is 1.0.'''
+        mask: NotRequired[torch.Tensor]
+        '''Mask to apply conditioning to.'''
+        mask_strength: NotRequired[float]
+        '''Strength of conditioning mask. Default strength is 1.0.'''
+        set_area_to_bounds: NotRequired[bool]
+        '''Whether the conditioning mask should determine the bounds of the area - if set to False, latents are sampled at full resolution and the result is applied within the mask.'''
+        concat_latent_image: NotRequired[torch.Tensor]
+        '''Used for inpainting and specific models.'''
+        concat_mask: NotRequired[torch.Tensor]
+        '''Used for inpainting and specific models.'''
+        concat_image: NotRequired[torch.Tensor]
+        '''Used by SD_4XUpscale_Conditioning.'''
+        noise_augmentation: NotRequired[float]
+        '''Used by SD_4XUpscale_Conditioning.'''
+        hooks: NotRequired[HookGroup]
+        '''Applies hooks to conditioning.'''
+        default: NotRequired[bool]
+        '''Whether this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.'''
+        start_percent: NotRequired[float]
+        '''Determines the relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.'''
+        end_percent: NotRequired[float]
+        '''Determines the relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.'''
+        clip_start_percent: NotRequired[float]
+        '''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.'''
+        clip_end_percent: NotRequired[float]
+        '''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.'''
+        attention_mask: NotRequired[torch.Tensor]
+        '''Masks text conditioning; used by StyleModel among others.'''
+        attention_mask_img_shape: NotRequired[tuple[int, ...]]
+        '''Image shape for the attention mask; used by StyleModel among others.'''
+        unclip_conditioning: NotRequired[list[dict]]
+        '''Used by unCLIP.'''
+        conditioning_lyrics: NotRequired[torch.Tensor]
+        '''Used by AceT5Model.'''
+        seconds_start: NotRequired[float]
+        '''Used by StableAudio.'''
+        seconds_total: NotRequired[float]
+        '''Used by StableAudio.'''
+        lyrics_strength: NotRequired[float]
+        '''Used by AceStepAudio.'''
+        width: NotRequired[int]
+        '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
+        height: NotRequired[int]
+        '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
+        aesthetic_score: NotRequired[float]
+        '''Used by CLIPTextEncodeSDXL/Refiner.'''
+        crop_w: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        crop_h: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        target_width: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        target_height: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        reference_latents: NotRequired[list[torch.Tensor]]
+        '''Used by ReferenceLatent.'''
+        guidance: NotRequired[float]
+        '''Used by Flux-like models with guidance embed.'''
+        guiding_frame_index: NotRequired[int]
+        '''Used by Hunyuan ImageToVideo.'''
+        ref_latent: NotRequired[torch.Tensor]
+        '''Used by Hunyuan ImageToVideo.'''
+        keyframe_idxs: NotRequired[list[int]]
+        '''Used by LTXV.'''
+        frame_rate: NotRequired[float]
+        '''Used by LTXV.'''
+        stable_cascade_prior: NotRequired[torch.Tensor]
+        '''Used by StableCascade.'''
+        elevation: NotRequired[list[float]]
+        '''Used by SV3D.'''
+        azimuth: NotRequired[list[float]]
+        '''Used by SV3D.'''
+        motion_bucket_id: NotRequired[int]
+        '''Used by SVD-like models.'''
+        fps: NotRequired[int]
+        '''Used by SVD-like models.'''
+        augmentation_level: NotRequired[float]
+        '''Used by SVD-like models.'''
+        clip_vision_output: NotRequired[ClipVisionOutput_]
+        '''Used by WAN-like models.'''
+        vace_frames: NotRequired[torch.Tensor]
+        '''Used by WAN VACE.'''
+        vace_mask: NotRequired[torch.Tensor]
+        '''Used by WAN VACE.'''
+        vace_strength: NotRequired[float]
+        '''Used by WAN VACE.'''
+        camera_conditions: NotRequired[Any]  # TODO: assign proper type once defined
+        '''Used by WAN Camera.'''
+        time_dim_concat: NotRequired[torch.Tensor]
+        '''Used by WAN Phantom Subject.'''
+
+    CondList = list[tuple[torch.Tensor, PooledDict]]
+    Type = CondList
+
+@comfytype(io_type="SAMPLER")
 class Sampler(ComfyTypeIO):
-    Type = Any
+    Type = Sampler
 
-@comfytype(io_type=IO.SIGMAS)
+@comfytype(io_type="SIGMAS")
 class Sigmas(ComfyTypeIO):
-    Type = Any
+    Type = torch.Tensor
 
-@comfytype(io_type=IO.NOISE)
+@comfytype(io_type="NOISE")
 class Noise(ComfyTypeIO):
-    Type = Any
+    Type = torch.Tensor
 
-@comfytype(io_type=IO.GUIDER)
+@comfytype(io_type="GUIDER")
 class Guider(ComfyTypeIO):
-    Type = Any
+    Type = CFGGuider
 
-@comfytype(io_type=IO.CLIP)
+@comfytype(io_type="CLIP")
 class Clip(ComfyTypeIO):
-    Type = Any
+    Type = CLIP
 
-@comfytype(io_type=IO.CONTROL_NET)
+@comfytype(io_type="CONTROL_NET")
 class ControlNet(ComfyTypeIO):
-    Type = Any
+    Type = ControlNet
 
-@comfytype(io_type=IO.VAE)
+@comfytype(io_type="VAE")
 class Vae(ComfyTypeIO):
-    Type = Any
+    Type = VAE
 
-@comfytype(io_type=IO.MODEL)
+@comfytype(io_type="MODEL")
 class Model(ComfyTypeIO):
     Type = ModelPatcher
 
-@comfytype(io_type=IO.CLIP_VISION)
+@comfytype(io_type="CLIP_VISION")
 class ClipVision(ComfyTypeIO):
-    Type = Any
+    Type = ClipVisionModel
 
-@comfytype(io_type=IO.CLIP_VISION_OUTPUT)
+@comfytype(io_type="CLIP_VISION_OUTPUT")
 class ClipVisionOutput(ComfyTypeIO):
-    Type = Any
+    Type = ClipVisionOutput_
 
-@comfytype(io_type=IO.STYLE_MODEL)
+@comfytype(io_type="STYLE_MODEL")
 class StyleModel(ComfyTypeIO):
-    Type = Any
+    Type = StyleModel_
 
-@comfytype(io_type=IO.GLIGEN)
+@comfytype(io_type="GLIGEN")
 class Gligen(ComfyTypeIO):
-    Type = Any
+    '''ModelPatcher that wraps around a 'Gligen' model.'''
+    Type = ModelPatcher
 
-@comfytype(io_type=IO.UPSCALE_MODEL)
+@comfytype(io_type="UPSCALE_MODEL")
 class UpscaleModel(ComfyTypeIO):
-    Type = Any
+    Type = ImageModelDescriptor
 
-@comfytype(io_type=IO.AUDIO)
+@comfytype(io_type="AUDIO")
 class Audio(ComfyTypeIO):
-    Type = Any
+    class AudioDict(TypedDict):
+        waveform: torch.Tensor
+        sample_rate: int
+    Type = AudioDict
 
-@comfytype(io_type=IO.POINT)
-class Point(ComfyTypeIO):
-    Type = Any
-
-@comfytype(io_type=IO.FACE_ANALYSIS)
-class FaceAnalysis(ComfyTypeIO):
-    Type = Any
-
-@comfytype(io_type=IO.BBOX)
-class BBOX(ComfyTypeIO):
-    Type = Any
-
-@comfytype(io_type=IO.SEGS)
-class SEGS(ComfyTypeIO):
-    Type = Any
-
-@comfytype(io_type=IO.VIDEO)
+@comfytype(io_type="VIDEO")
 class Video(ComfyTypeIO):
-    Type = Any
+    Type = VideoInput
+
+@comfytype(io_type="SVG")
+class SVG(ComfyTypeIO):
+    Type = Any  # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing a circular reference; it should be moved elsewhere before being referenced directly in v3
+
+@comfytype(io_type="LORA_MODEL")
+class LoraModel(ComfyTypeIO):
+    Type = dict[str, torch.Tensor]
+
+@comfytype(io_type="LOSS_MAP")
+class LossMap(ComfyTypeIO):
+    class LossMapDict(TypedDict):
+        loss: list[torch.Tensor]
+    Type = LossMapDict
+
+@comfytype(io_type="VOXEL")
+class Voxel(ComfyTypeIO):
+    Type = Any  # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; it should be moved elsewhere before being referenced directly in v3
+
+@comfytype(io_type="MESH")
+class Mesh(ComfyTypeIO):
+    Type = Any  # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; it should be moved elsewhere before being referenced directly in v3
+
+@comfytype(io_type="HOOKS")
+class Hooks(ComfyTypeIO):
+    Type = HookGroup
+
+@comfytype(io_type="HOOK_KEYFRAMES")
+class HookKeyframes(ComfyTypeIO):
+    Type = HookKeyframeGroup
+
+@comfytype(io_type="TIMESTEPS_RANGE")
+class TimestepsRange(ComfyTypeIO):
+    '''Range defined by start and end point, each between 0.0 and 1.0.'''
+    Type = tuple[float, float]
+
+@comfytype(io_type="LATENT_OPERATION")
+class LatentOperation(ComfyTypeIO):
+    Type = Callable[[torch.Tensor], torch.Tensor]
+
+@comfytype(io_type="FLOW_CONTROL")
+class FlowControl(ComfyTypeIO):
+    # NOTE: only used in testing_nodes right now
+    Type = tuple[str, Any]
+
+@comfytype(io_type="ACCUMULATION")
+class Accumulation(ComfyTypeIO):
+    # NOTE: only used in testing_nodes right now
+    class AccumulationDict(TypedDict):
+        accum: list[Any]
+    Type = AccumulationDict
+
+@comfytype(io_type="LOAD3D_CAMERA")
+class Load3DCamera(ComfyTypeIO):
+    Type = Any  # TODO: figure out the type for this; in code it is only described as image['camera_info'], obtained from a LOAD_3D or LOAD_3D_ANIMATION type
+
+@comfytype(io_type="POINT")
+class Point(ComfyTypeIO):
+    Type = Any  # NOTE: I couldn't find any references in core code to the POINT io_type. Does this exist?
+
+@comfytype(io_type="FACE_ANALYSIS")
+class FaceAnalysis(ComfyTypeIO):
+    Type = Any  # NOTE: I couldn't find any references in core code to the FACE_ANALYSIS io_type. Does this exist?
+
+@comfytype(io_type="BBOX")
+class BBOX(ComfyTypeIO):
+    Type = Any  # NOTE: I couldn't find any references in core code to the BBOX io_type. Does this exist?
+
+@comfytype(io_type="SEGS")
+class SEGS(ComfyTypeIO):
+    Type = Any  # NOTE: I couldn't find any references in core code to the SEGS io_type. Does this exist?
 
 @comfytype(io_type="COMFY_MULTITYPED_V3")
 class MultiType:
@@ -528,7 +713,7 @@ class MultiType:
     '''
    Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
     '''
-    def __init__(self, id: str | ComfyType.Input, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+    def __init__(self, id: str | InputV3, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
         # if id is an Input, then use that Input with overridden values
         self.input_override = None
         if isinstance(id, InputV3):
@@ -561,7 +746,6 @@ class MultiType:
             str_types.insert(0, self.input_override.get_io_type_V1())
         return ",".join(list(dict.fromkeys(str_types)))
 
-    @override
     def as_dict_V1(self):
         if self.input_override is not None:
             return self.input_override.as_dict_V1() | super().as_dict_V1()
@@ -754,7 +938,7 @@ class SchemaV3:
         raise ValueError("\n".join(issues))
 
 class Serializer:
-    def __init_subclass__(cls, io_type: IO | str, **kwargs):
+    def __init_subclass__(cls, io_type: str, **kwargs):
         cls.io_type = io_type
         super().__init_subclass__(**kwargs)
 
@@ -1134,7 +1318,7 @@ class TestNode(ComfyNodeV3):
 if __name__ == "__main__":
     print("hello there")
     inputs: list[InputV3] = [
-        Int.Input("tessfes", widgetType=IO.STRING),
+        Int.Input("tessfes", widgetType=String.io_type),
         Int.Input("my_int"),
         Custom("XYZ").Input("xyz"),
         Custom("MODEL_M").Input("model1"),
diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py
index 9f385f952..1bfc8dc37 100644
--- a/comfy_extras/nodes_v3_test.py
+++ b/comfy_extras/nodes_v3_test.py
@@ -36,7 +36,7 @@ class V3TestNode(io.ComfyNodeV3):
             io.Image.Input("image", display_name="new_image"),
             XYZ.Input("xyz", optional=True),
             io.Custom("JKL").Input("jkl", optional=True),
-            io.Mask.Input("mask", optional=True),
+            io.Mask.Input("mask", display_name="mask haha", optional=True),
             io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42, tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider),
             io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"),
diff --git a/server.py b/server.py
index 1a135fca7..28cf6b5ae 100644
--- a/server.py
+++ b/server.py
@@ -555,7 +555,7 @@ class PromptServer():
 
         def node_info(node_class):
             obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
-            if isinstance(obj_class, ComfyNodeV3):
+            if issubclass(obj_class, ComfyNodeV3):
                 return obj_class.GET_NODE_INFO_V1()
             info = {}
             info['input'] = obj_class.INPUT_TYPES()
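
Notes (not part of the diff): the typed containers above replace the old `Type = Any` placeholders, so node code can now build LATENT and AUDIO payloads against checkable keys. A minimal sketch of what construction might look like; only the dict keys come from LatentDict/AudioDict in the diff, while the helper names and the 4-channel SD-style latent shape are assumptions:

# Sketch only: helper names are invented; dict keys come from io.py above.
import torch
from comfy_api.v3 import io

def empty_latent(width: int, height: int, batch: int = 1) -> io.Latent.Type:
    # Only "samples" is required; noise_mask/batch_index/type may stay absent.
    return {"samples": torch.zeros([batch, 4, height // 8, width // 8])}

def silent_audio(seconds: float, sample_rate: int = 44100) -> io.Audio.Type:
    # AudioDict requires both keys; waveform is [batch, channels, samples].
    return {"waveform": torch.zeros([1, 2, int(seconds * sample_rate)]),
            "sample_rate": sample_rate}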
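Similarly, Conditioning.CondList documents the long-standing convention that a conditioning is a list of (cross-attention tensor, extras-dict) pairs. A hedged sketch of a helper that rewrites one PooledDict key without mutating the caller's list; the helper name is invented, while the tuple layout and the "strength" key come from the definitions above:

from comfy_api.v3 import io

def with_strength(cond: io.Conditioning.Type, strength: float) -> io.Conditioning.Type:
    # Copy each extras dict so the original conditioning list is left untouched.
    return [(t, {**extras, "strength": strength}) for t, extras in cond]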
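On the server.py hunk: NODE_CLASS_MAPPINGS stores node classes, not instances, so `isinstance(obj_class, ComfyNodeV3)` was always False and V3 nodes fell through to the V1 info path; `issubclass` is the correct check. The distinction, reduced to a two-assert repro:

# Values in NODE_CLASS_MAPPINGS are class objects, not instances:
class Base: ...
class Node(Base): ...

assert not isinstance(Node, Base)  # a class object is not an instance of Base
assert issubclass(Node, Base)      # this is the check node_info() needs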