From b6754d935b314d54cbe3f3f515e5ce3d55052426 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Fri, 25 Jul 2025 19:24:57 -0700 Subject: [PATCH] Fix generated stubs differing by Python version --- comfy_api/internal/async_to_sync.py | 79 +++++++++++++++---- comfy_api/latest/__init__.py | 7 ++ .../latest/generated/ComfyAPISyncStub.pyi | 2 +- comfy_api/v0_0_1/__init__.py | 8 ++ .../v0_0_1/generated/ComfyAPISyncStub.pyi | 2 +- comfy_api/v0_0_2/__init__.py | 8 ++ .../v0_0_2/generated/ComfyAPISyncStub.pyi | 2 +- 7 files changed, 88 insertions(+), 20 deletions(-) diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py index f6bf04230..dbc62255c 100644 --- a/comfy_api/internal/async_to_sync.py +++ b/comfy_api/internal/async_to_sync.py @@ -398,14 +398,18 @@ class AsyncToSyncConverter: origin_name = origin_name.split(".")[-1] # Special handling for types.UnionType (Python 3.10+ pipe operator) - if origin_name == "UnionType": + # Convert to old-style Union for compatibility + if str(origin) == "" or origin_name == "UnionType": origin_name = "Union" # Format arguments recursively if args: - formatted_args = [ - cls._format_type_annotation(arg, type_tracker) for arg in args - ] + formatted_args = [] + for arg in args: + # Track each type in the union + if type_tracker: + type_tracker.track_type(arg) + formatted_args.append(cls._format_type_annotation(arg, type_tracker)) return f"{origin_name}[{', '.join(formatted_args)}]" else: return origin_name @@ -489,23 +493,27 @@ class AsyncToSyncConverter: cls, sig: inspect.Signature, skip_self: bool = True, + type_hints: Optional[dict] = None, type_tracker: Optional[TypeTracker] = None, ) -> str: """Format method parameters for stub files.""" params = [] + if type_hints is None: + type_hints = {} for i, (param_name, param) in enumerate(sig.parameters.items()): if i == 0 and param_name == "self" and skip_self: params.append("self") else: - # Get type annotation - type_str = cls._format_type_annotation(param.annotation, type_tracker) + # Get type annotation from type hints if available, otherwise from signature + annotation = type_hints.get(param_name, param.annotation) + type_str = cls._format_type_annotation(annotation, type_tracker) # Get default value default_str = cls._format_parameter_default(param.default) # Combine parameter parts - if param.annotation is inspect.Parameter.empty: + if annotation is inspect.Parameter.empty: params.append(f"{param_name}: Any{default_str}") else: params.append(f"{param_name}: {type_str}{default_str}") @@ -522,14 +530,22 @@ class AsyncToSyncConverter: ) -> str: """Generate a complete method signature for stub files.""" sig = inspect.signature(method) + + # Try to get evaluated type hints to resolve string annotations + try: + from typing import get_type_hints + type_hints = get_type_hints(method) + except Exception: + # Fallback to empty dict if we can't get type hints + type_hints = {} # For async methods, extract the actual return type - return_annotation = sig.return_annotation + return_annotation = type_hints.get('return', sig.return_annotation) if is_async and inspect.iscoroutinefunction(method): return_annotation = cls._extract_coroutine_return_type(return_annotation) - # Format parameters - params_str = cls._format_method_parameters(sig, type_tracker=type_tracker) + # Format parameters with type hints + params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker) # Format return type return_type = cls._format_type_annotation(return_annotation, type_tracker) @@ -556,8 +572,18 @@ class AsyncToSyncConverter: additional_types = [] if module: + # Check if module has __all__ defined + module_all = getattr(module, "__all__", None) + for name, obj in sorted(inspect.getmembers(module)): if isinstance(obj, type): + # Skip if __all__ is defined and this name isn't in it + # unless it's already been tracked as used in type annotations + if module_all is not None and name not in module_all: + # Check if this type was actually used in annotations + if name not in type_tracker.discovered_types: + continue + # Check for NamedTuple if issubclass(obj, tuple) and hasattr(obj, "_fields"): additional_types.append(name) @@ -636,9 +662,17 @@ class AsyncToSyncConverter: try: init_method = getattr(attr, "__init__") init_sig = inspect.signature(init_method) + + # Try to get type hints + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + # Format parameters params_str = cls._format_method_parameters( - init_sig, type_tracker=type_tracker + init_sig, type_hints=init_hints, type_tracker=type_tracker ) # Add __init__ docstring if available (before the method) if hasattr(init_method, "__doc__") and init_method.__doc__: @@ -790,9 +824,17 @@ class AsyncToSyncConverter: try: init_method = async_class.__init__ init_signature = inspect.signature(init_method) + + # Try to get type hints for __init__ + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + # Format parameters params_str = cls._format_method_parameters( - init_signature, type_tracker=type_tracker + init_signature, type_hints=init_hints, type_tracker=type_tracker ) # Add __init__ docstring if available (before the method) if hasattr(init_method, "__doc__") and init_method.__doc__: @@ -875,18 +917,21 @@ class AsyncToSyncConverter: for attr_name, attr_type in sorted(all_annotations.items()): for class_name, class_type in class_attributes: # If the class type matches the annotated type - if attr_type == class_type or ( - hasattr(attr_type, "__name__") - and attr_type.__name__ == class_name + if ( + attr_type == class_type + or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name) + or (isinstance(attr_type, str) and attr_type == class_name) ): attribute_mappings[class_name] = attr_name # Remove the extra checking - annotations should be sufficient # Add the attribute declarations with proper names - for class_name, _ in class_attributes: - # Use the attribute name if found in mappings, otherwise use class name + for class_name, class_type in class_attributes: + # Check if there's a mapping from annotation attr_name = attribute_mappings.get(class_name, class_name) + # Use the annotation name if it exists, even if the attribute doesn't exist yet + # This is because the attribute might be created at runtime stub_content.append(f" {attr_name}: {class_name}Sync") stub_content.append("") # Add a final newline diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index bcf09ffbf..e1f3a3655 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -97,3 +97,10 @@ if TYPE_CHECKING: ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync = create_sync_class(ComfyAPI_latest) +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", +] diff --git a/comfy_api/latest/generated/ComfyAPISyncStub.pyi b/comfy_api/latest/generated/ComfyAPISyncStub.pyi index 280893ddb..525c074dd 100644 --- a/comfy_api/latest/generated/ComfyAPISyncStub.pyi +++ b/comfy_api/latest/generated/ComfyAPISyncStub.pyi @@ -15,6 +15,6 @@ class ComfyAPISyncStub: Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK """ - def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... execution: ExecutionSync diff --git a/comfy_api/v0_0_1/__init__.py b/comfy_api/v0_0_1/__init__.py index ab6dc2b42..93608771d 100644 --- a/comfy_api/v0_0_1/__init__.py +++ b/comfy_api/v0_0_1/__init__.py @@ -32,3 +32,11 @@ if TYPE_CHECKING: ComfyAPISync: Type[ComfyAPISyncStub] ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", +] diff --git a/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi index c31461f17..270030324 100644 --- a/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi +++ b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi @@ -15,6 +15,6 @@ class ComfyAPISyncStub: Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK """ - def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... execution: ExecutionSync diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index 1b68bcc97..ea83833fb 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -33,3 +33,11 @@ if TYPE_CHECKING: ComfyAPISync: Type[ComfyAPISyncStub] ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", +] diff --git a/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi index b3ad0e3df..7fcec685e 100644 --- a/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi +++ b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi @@ -15,6 +15,6 @@ class ComfyAPISyncStub: Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK """ - def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ... + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... execution: ExecutionSync