Fix generated stubs differing by Python version

This commit is contained in:
Jacob Segal 2025-07-25 19:24:57 -07:00
parent 689db36073
commit b6754d935b
7 changed files with 88 additions and 20 deletions

View File

@ -398,14 +398,18 @@ class AsyncToSyncConverter:
origin_name = origin_name.split(".")[-1] origin_name = origin_name.split(".")[-1]
# Special handling for types.UnionType (Python 3.10+ pipe operator) # Special handling for types.UnionType (Python 3.10+ pipe operator)
if origin_name == "UnionType": # Convert to old-style Union for compatibility
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
origin_name = "Union" origin_name = "Union"
# Format arguments recursively # Format arguments recursively
if args: if args:
formatted_args = [ formatted_args = []
cls._format_type_annotation(arg, type_tracker) for arg in args for arg in args:
] # Track each type in the union
if type_tracker:
type_tracker.track_type(arg)
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(formatted_args)}]" return f"{origin_name}[{', '.join(formatted_args)}]"
else: else:
return origin_name return origin_name
@ -489,23 +493,27 @@ class AsyncToSyncConverter:
cls, cls,
sig: inspect.Signature, sig: inspect.Signature,
skip_self: bool = True, skip_self: bool = True,
type_hints: Optional[dict] = None,
type_tracker: Optional[TypeTracker] = None, type_tracker: Optional[TypeTracker] = None,
) -> str: ) -> str:
"""Format method parameters for stub files.""" """Format method parameters for stub files."""
params = [] params = []
if type_hints is None:
type_hints = {}
for i, (param_name, param) in enumerate(sig.parameters.items()): for i, (param_name, param) in enumerate(sig.parameters.items()):
if i == 0 and param_name == "self" and skip_self: if i == 0 and param_name == "self" and skip_self:
params.append("self") params.append("self")
else: else:
# Get type annotation # Get type annotation from type hints if available, otherwise from signature
type_str = cls._format_type_annotation(param.annotation, type_tracker) annotation = type_hints.get(param_name, param.annotation)
type_str = cls._format_type_annotation(annotation, type_tracker)
# Get default value # Get default value
default_str = cls._format_parameter_default(param.default) default_str = cls._format_parameter_default(param.default)
# Combine parameter parts # Combine parameter parts
if param.annotation is inspect.Parameter.empty: if annotation is inspect.Parameter.empty:
params.append(f"{param_name}: Any{default_str}") params.append(f"{param_name}: Any{default_str}")
else: else:
params.append(f"{param_name}: {type_str}{default_str}") params.append(f"{param_name}: {type_str}{default_str}")
@ -523,13 +531,21 @@ class AsyncToSyncConverter:
"""Generate a complete method signature for stub files.""" """Generate a complete method signature for stub files."""
sig = inspect.signature(method) sig = inspect.signature(method)
# Try to get evaluated type hints to resolve string annotations
try:
from typing import get_type_hints
type_hints = get_type_hints(method)
except Exception:
# Fallback to empty dict if we can't get type hints
type_hints = {}
# For async methods, extract the actual return type # For async methods, extract the actual return type
return_annotation = sig.return_annotation return_annotation = type_hints.get('return', sig.return_annotation)
if is_async and inspect.iscoroutinefunction(method): if is_async and inspect.iscoroutinefunction(method):
return_annotation = cls._extract_coroutine_return_type(return_annotation) return_annotation = cls._extract_coroutine_return_type(return_annotation)
# Format parameters # Format parameters with type hints
params_str = cls._format_method_parameters(sig, type_tracker=type_tracker) params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
# Format return type # Format return type
return_type = cls._format_type_annotation(return_annotation, type_tracker) return_type = cls._format_type_annotation(return_annotation, type_tracker)
@ -556,8 +572,18 @@ class AsyncToSyncConverter:
additional_types = [] additional_types = []
if module: if module:
# Check if module has __all__ defined
module_all = getattr(module, "__all__", None)
for name, obj in sorted(inspect.getmembers(module)): for name, obj in sorted(inspect.getmembers(module)):
if isinstance(obj, type): if isinstance(obj, type):
# Skip if __all__ is defined and this name isn't in it
# unless it's already been tracked as used in type annotations
if module_all is not None and name not in module_all:
# Check if this type was actually used in annotations
if name not in type_tracker.discovered_types:
continue
# Check for NamedTuple # Check for NamedTuple
if issubclass(obj, tuple) and hasattr(obj, "_fields"): if issubclass(obj, tuple) and hasattr(obj, "_fields"):
additional_types.append(name) additional_types.append(name)
@ -636,9 +662,17 @@ class AsyncToSyncConverter:
try: try:
init_method = getattr(attr, "__init__") init_method = getattr(attr, "__init__")
init_sig = inspect.signature(init_method) init_sig = inspect.signature(init_method)
# Try to get type hints
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters # Format parameters
params_str = cls._format_method_parameters( params_str = cls._format_method_parameters(
init_sig, type_tracker=type_tracker init_sig, type_hints=init_hints, type_tracker=type_tracker
) )
# Add __init__ docstring if available (before the method) # Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__: if hasattr(init_method, "__doc__") and init_method.__doc__:
@ -790,9 +824,17 @@ class AsyncToSyncConverter:
try: try:
init_method = async_class.__init__ init_method = async_class.__init__
init_signature = inspect.signature(init_method) init_signature = inspect.signature(init_method)
# Try to get type hints for __init__
try:
from typing import get_type_hints
init_hints = get_type_hints(init_method)
except Exception:
init_hints = {}
# Format parameters # Format parameters
params_str = cls._format_method_parameters( params_str = cls._format_method_parameters(
init_signature, type_tracker=type_tracker init_signature, type_hints=init_hints, type_tracker=type_tracker
) )
# Add __init__ docstring if available (before the method) # Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__: if hasattr(init_method, "__doc__") and init_method.__doc__:
@ -875,18 +917,21 @@ class AsyncToSyncConverter:
for attr_name, attr_type in sorted(all_annotations.items()): for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes: for class_name, class_type in class_attributes:
# If the class type matches the annotated type # If the class type matches the annotated type
if attr_type == class_type or ( if (
hasattr(attr_type, "__name__") attr_type == class_type
and attr_type.__name__ == class_name or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
or (isinstance(attr_type, str) and attr_type == class_name)
): ):
attribute_mappings[class_name] = attr_name attribute_mappings[class_name] = attr_name
# Remove the extra checking - annotations should be sufficient # Remove the extra checking - annotations should be sufficient
# Add the attribute declarations with proper names # Add the attribute declarations with proper names
for class_name, _ in class_attributes: for class_name, class_type in class_attributes:
# Use the attribute name if found in mappings, otherwise use class name # Check if there's a mapping from annotation
attr_name = attribute_mappings.get(class_name, class_name) attr_name = attribute_mappings.get(class_name, class_name)
# Use the annotation name if it exists, even if the attribute doesn't exist yet
# This is because the attribute might be created at runtime
stub_content.append(f" {attr_name}: {class_name}Sync") stub_content.append(f" {attr_name}: {class_name}Sync")
stub_content.append("") # Add a final newline stub_content.append("") # Add a final newline

View File

@ -97,3 +97,10 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest) ComfyAPISync = create_sync_class(ComfyAPI_latest)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
]

View File

@ -15,6 +15,6 @@ class ComfyAPISyncStub:
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
""" """
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync execution: ExecutionSync

View File

@ -32,3 +32,11 @@ if TYPE_CHECKING:
ComfyAPISync: Type[ComfyAPISyncStub] ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1) ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
]

View File

@ -15,6 +15,6 @@ class ComfyAPISyncStub:
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
""" """
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync execution: ExecutionSync

View File

@ -33,3 +33,11 @@ if TYPE_CHECKING:
ComfyAPISync: Type[ComfyAPISyncStub] ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2) ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
__all__ = [
"ComfyAPI",
"ComfyAPISync",
"Input",
"InputImpl",
"Types",
]

View File

@ -15,6 +15,6 @@ class ComfyAPISyncStub:
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
""" """
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
execution: ExecutionSync execution: ExecutionSync