Defined TypedDict hints for Latent, Conditioning, and Audio types

Author: Jedrzej Kosinski
Date: 2025-06-27 16:57:55 -07:00
parent ba857bd8a0
commit d0c077423a


@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast
+from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict, NotRequired
 from enum import Enum
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, asdict
@@ -17,6 +17,7 @@ from comfy.sd import StyleModel as StyleModel_
 from comfy.clip_vision import ClipVisionModel
 from comfy.clip_vision import Output as ClipVisionOutput_
 from comfy_api.input import VideoInput
+from comfy.hooks import HookGroup
 
 class FolderType(str, Enum):
@@ -448,11 +449,132 @@ class Mask(ComfyTypeIO):
 @comfytype(io_type=IO.LATENT)
 class Latent(ComfyTypeIO):
-    Type = Any # TODO: make Type a TypedDict
+    '''Latents are stored as a dictionary.'''
+    class LatentDict(TypedDict):
+        samples: torch.Tensor
+        '''Latent tensors.'''
+        noise_mask: NotRequired[torch.Tensor]
+        batch_index: NotRequired[list[int]]
+        type: NotRequired[str]
+        '''Only needed if dealing with these types: audio, hunyuan3dv2'''
+    Type = LatentDict
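
For reference, a minimal sketch of a value matching the new LatentDict shape. The tensor shapes and torch.zeros placeholders are illustrative assumptions (a typical SD-style latent), not part of this commit:

import torch

# Illustrative LatentDict-shaped value: batch 1, 4 latent channels,
# 64x64 spatial size (i.e. a 512x512 image at 8x downscale).
latent = {
    "samples": torch.zeros(1, 4, 64, 64),    # required latent tensors
    "noise_mask": torch.ones(1, 1, 64, 64),  # optional (NotRequired)
}
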
 @comfytype(io_type=IO.CONDITIONING)
 class Conditioning(ComfyTypeIO):
-    Type = Any # TODO: make Type a TypedDict
+    class PooledDict(TypedDict):
+        pooled_output: torch.Tensor
+        '''Pooled output from CLIP.'''
+        control: NotRequired[ControlNet]
+        '''ControlNet to apply to conditioning.'''
+        control_apply_to_uncond: NotRequired[bool]
+        '''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.'''
+        cross_attn_controlnet: NotRequired[torch.Tensor]
+        '''CrossAttn from CLIP to use for controlnet only.'''
+        pooled_output_controlnet: NotRequired[torch.Tensor]
+        '''Pooled output from CLIP to use for controlnet only.'''
+        gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]]
+        '''GLIGEN to apply to conditioning.'''
+        area: NotRequired[tuple[int, ...] | tuple[str, float, ...]]
+        '''Set area of conditioning. The first half of the values set the dimensions, the second half set the coordinates.
+        By default, the dimensions are based on total pixel amount, but the first value can be set to "percentage" to use a percentage of the image size instead.
+        (1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels.
+        ("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.''' # TODO: verify it's actually top-left
+        strength: NotRequired[float]
+        '''Strength of conditioning. Default strength is 1.0.'''
+        mask: NotRequired[torch.Tensor]
+        '''Mask to apply conditioning to.'''
+        mask_strength: NotRequired[float]
+        '''Strength of conditioning mask. Default strength is 1.0.'''
+        set_area_to_bounds: NotRequired[bool]
+        '''Whether conditioning mask should determine bounds of area - if set to false, latents are sampled at full resolution and result is applied in mask.'''
+        concat_latent_image: NotRequired[torch.Tensor]
+        '''Used for inpainting and specific models.'''
+        concat_mask: NotRequired[torch.Tensor]
+        '''Used for inpainting and specific models.'''
+        concat_image: NotRequired[torch.Tensor]
+        '''Used by SD_4XUpscale_Conditioning.'''
+        noise_augmentation: NotRequired[float]
+        '''Used by SD_4XUpscale_Conditioning.'''
+        hooks: NotRequired[HookGroup]
+        '''Applies hooks to conditioning.'''
+        default: NotRequired[bool]
+        '''Whether this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.'''
+        start_percent: NotRequired[float]
+        '''Determines relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.'''
+        end_percent: NotRequired[float]
+        '''Determines relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.'''
+        clip_start_percent: NotRequired[float]
+        '''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.'''
+        clip_end_percent: NotRequired[float]
+        '''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.'''
+        attention_mask: NotRequired[torch.Tensor]
+        '''Used by StyleModel.'''
+        attention_mask_img_shape: NotRequired[tuple[int, ...]]
+        '''Used by StyleModel.'''
+        unclip_conditioning: NotRequired[list[dict]]
+        '''Used by unCLIP.'''
+        conditioning_lyrics: NotRequired[torch.Tensor]
+        seconds_start: NotRequired[float]
+        '''Used by StableAudio.'''
+        seconds_total: NotRequired[float]
+        '''Used by StableAudio.'''
+        lyrics_strength: NotRequired[float]
+        '''Used by AceStepAudio.'''
+        width: NotRequired[int]
+        '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
+        height: NotRequired[int]
+        '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).'''
+        aesthetic_score: NotRequired[float]
+        '''Used by CLIPTextEncodeSDXL/Refiner.'''
+        crop_w: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        crop_h: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        target_width: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        target_height: NotRequired[int]
+        '''Used by CLIPTextEncodeSDXL.'''
+        reference_latents: NotRequired[list[torch.Tensor]]
+        '''Used by ReferenceLatent.'''
+        guidance: NotRequired[float]
+        '''Used by Flux-like models with guidance embed.'''
+        guiding_frame_index: NotRequired[int]
+        '''Used by Hunyuan ImageToVideo.'''
+        ref_latent: NotRequired[torch.Tensor]
+        '''Used by Hunyuan ImageToVideo.'''
+        keyframe_idxs: NotRequired[list[int]]
+        '''Used by LTXV.'''
+        frame_rate: NotRequired[float]
+        '''Used by LTXV.'''
+        stable_cascade_prior: NotRequired[torch.Tensor]
+        '''Used by StableCascade.'''
+        elevation: NotRequired[list[float]]
+        '''Used by SV3D.'''
+        azimuth: NotRequired[list[float]]
+        '''Used by SV3D.'''
+        motion_bucket_id: NotRequired[int]
+        '''Used by SVD-like models.'''
+        fps: NotRequired[int]
+        '''Used by SVD-like models.'''
+        augmentation_level: NotRequired[float]
+        '''Used by SVD-like models.'''
+        clip_vision_output: NotRequired[ClipVisionOutput_]
+        '''Used by WAN-like models.'''
+        vace_frames: NotRequired[torch.Tensor]
+        '''Used by WAN VACE.'''
+        vace_mask: NotRequired[torch.Tensor]
+        '''Used by WAN VACE.'''
+        vace_strength: NotRequired[float]
+        '''Used by WAN VACE.'''
+        camera_conditions: NotRequired[Any] # TODO: assign proper type once defined
+        '''Used by WAN Camera.'''
+        time_dim_concat: NotRequired[torch.Tensor]
+        '''Used by WAN Phantom Subject.'''
+    CondList = list[tuple[torch.Tensor, PooledDict]]
+    Type = CondList
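
For orientation, a hedged sketch of a value matching the new CondList shape: a list of (cross-attention tensor, PooledDict) pairs. The tensor shapes and torch.zeros placeholders below assume an SD1.x-style CLIP (77 tokens, 768 dims) purely for illustration; they are not part of this commit.

import torch

# One conditioning entry: the cross-attention tensor plus its options dict.
conditioning = [
    (
        torch.zeros(1, 77, 768),                   # cross-attn conditioning
        {
            "pooled_output": torch.zeros(1, 768),  # required PooledDict key
            "strength": 1.0,                       # optional (NotRequired)
        },
    )
]
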
 @comfytype(io_type=IO.SAMPLER)
 class Sampler(ComfyTypeIO):
@@ -509,7 +631,10 @@ class UpscaleModel(ComfyTypeIO):
 @comfytype(io_type=IO.AUDIO)
 class Audio(ComfyTypeIO):
-    Type = Any # TODO: make Type a TypedDict
+    class AudioDict(TypedDict):
+        waveform: torch.Tensor
+        sample_rate: int
+    Type = AudioDict
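
A minimal sketch of a value matching AudioDict; the contents (one second of silent stereo audio at 44.1 kHz) are chosen purely for illustration and are not part of this commit.

import torch

# Illustrative AudioDict-shaped value: (batch, channels, samples) waveform.
audio = {
    "waveform": torch.zeros(1, 2, 44100),
    "sample_rate": 44100,
}
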
 @comfytype(io_type=IO.POINT)
 class Point(ComfyTypeIO):