import asyncio import concurrent.futures import contextvars import functools import inspect import logging import os import textwrap import threading from enum import Enum from typing import Optional, Type, get_origin, get_args class TypeTracker: """Tracks types discovered during stub generation for automatic import generation.""" def __init__(self): self.discovered_types = {} # type_name -> (module, qualname) self.builtin_types = { "Any", "Dict", "List", "Optional", "Tuple", "Union", "Set", "Sequence", "cast", "NamedTuple", "str", "int", "float", "bool", "None", "bytes", "object", "type", "dict", "list", "tuple", "set", } self.already_imported = ( set() ) # Track types already imported to avoid duplicates def track_type(self, annotation): """Track a type annotation and record its module/import info.""" if annotation is None or annotation is type(None): return # Skip builtins and typing module types we already import type_name = getattr(annotation, "__name__", None) if type_name and ( type_name in self.builtin_types or type_name in self.already_imported ): return # Get module and qualname module = getattr(annotation, "__module__", None) qualname = getattr(annotation, "__qualname__", type_name or "") # Skip types from typing module (they're already imported) if module == "typing": return # Skip UnionType and GenericAlias from types module as they're handled specially if module == "types" and type_name in ("UnionType", "GenericAlias"): return if module and module not in ["builtins", "__main__"]: # Store the type info if type_name: self.discovered_types[type_name] = (module, qualname) def get_imports(self, main_module_name: str) -> list[str]: """Generate import statements for all discovered types.""" imports = [] imports_by_module = {} for type_name, (module, qualname) in sorted(self.discovered_types.items()): # Skip types from the main module (they're already imported) if main_module_name and module == main_module_name: continue if module not in imports_by_module: imports_by_module[module] = [] if type_name not in imports_by_module[module]: # Avoid duplicates imports_by_module[module].append(type_name) # Generate import statements for module, types in sorted(imports_by_module.items()): if len(types) == 1: imports.append(f"from {module} import {types[0]}") else: imports.append(f"from {module} import {', '.join(sorted(set(types)))}") return imports class AsyncToSyncConverter: """ Provides utilities to convert async classes to sync classes with proper type hints. """ _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None _thread_pool_lock = threading.Lock() _thread_pool_initialized = False @classmethod def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor: """Get or create the shared thread pool with proper thread-safe initialization.""" # Fast path - check if already initialized without acquiring lock if cls._thread_pool_initialized: assert cls._thread_pool is not None, "Thread pool should be initialized" return cls._thread_pool # Slow path - acquire lock and create pool if needed with cls._thread_pool_lock: if not cls._thread_pool_initialized: cls._thread_pool = concurrent.futures.ThreadPoolExecutor( max_workers=max_workers, thread_name_prefix="async_to_sync_" ) cls._thread_pool_initialized = True # This should never be None at this point, but add assertion for type checker assert cls._thread_pool is not None return cls._thread_pool @classmethod def run_async_in_thread(cls, coro_func, *args, **kwargs): """ Run an async function in a separate thread from the thread pool. Blocks until the async function completes. Properly propagates contextvars between threads and manages event loops. """ # Capture current context - this includes all context variables context = contextvars.copy_context() # Store the result and any exception that occurs result_container: dict = {"result": None, "exception": None} # Function that runs in the thread pool def run_in_thread(): # Create new event loop for this thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: # Create the coroutine within the context async def run_with_context(): # The coroutine function might access context variables return await coro_func(*args, **kwargs) # Run the coroutine with the captured context # This ensures all context variables are available in the async function result = context.run(loop.run_until_complete, run_with_context()) result_container["result"] = result except Exception as e: # Store the exception to re-raise in the calling thread result_container["exception"] = e finally: # Ensure event loop is properly closed to prevent warnings try: # Cancel any remaining tasks pending = asyncio.all_tasks(loop) for task in pending: task.cancel() # Run the loop briefly to handle cancellations if pending: loop.run_until_complete( asyncio.gather(*pending, return_exceptions=True) ) except Exception: pass # Ignore errors during cleanup # Close the event loop loop.close() # Clear the event loop from the thread asyncio.set_event_loop(None) # Submit to thread pool and wait for result thread_pool = cls.get_thread_pool() future = thread_pool.submit(run_in_thread) future.result() # Wait for completion # Re-raise any exception that occurred in the thread if result_container["exception"] is not None: raise result_container["exception"] return result_container["result"] @classmethod def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type: """ Creates a new class with synchronous versions of all async methods. Args: async_class: The async class to convert thread_pool_size: Size of thread pool to use Returns: A new class with sync versions of all async methods """ sync_class_name = "ComfyAPISyncStub" cls.get_thread_pool(thread_pool_size) # Create a proper class with docstrings and proper base classes sync_class_dict = { "__doc__": async_class.__doc__, "__module__": async_class.__module__, "__qualname__": sync_class_name, "__orig_class__": async_class, # Store original class for typing references } # Create __init__ method def __init__(self, *args, **kwargs): self._async_instance = async_class(*args, **kwargs) # Handle annotated class attributes (like execution: Execution) # Get all annotations from the class hierarchy all_annotations = {} for base_class in reversed(inspect.getmro(async_class)): if hasattr(base_class, "__annotations__"): all_annotations.update(base_class.__annotations__) # For each annotated attribute, check if it needs to be created or wrapped for attr_name, attr_type in all_annotations.items(): if hasattr(self._async_instance, attr_name): # Attribute exists on the instance attr = getattr(self._async_instance, attr_name) # Check if this attribute needs a sync wrapper if hasattr(attr, "__class__"): from comfy_api.internal.singleton import ProxiedSingleton if isinstance(attr, ProxiedSingleton): # Create a sync version of this attribute try: sync_attr_class = cls.create_sync_class(attr.__class__) # Create instance of the sync wrapper with the async instance sync_attr = object.__new__(sync_attr_class) # type: ignore sync_attr._async_instance = attr setattr(self, attr_name, sync_attr) except Exception: # If we can't create a sync version, keep the original setattr(self, attr_name, attr) else: # Not async, just copy the reference setattr(self, attr_name, attr) else: # Attribute doesn't exist, but is annotated - create it # This handles cases like execution: Execution if isinstance(attr_type, type): # Check if the type is defined as an inner class if hasattr(async_class, attr_type.__name__): inner_class = getattr(async_class, attr_type.__name__) from comfy_api.internal.singleton import ProxiedSingleton # Create an instance of the inner class try: # For ProxiedSingleton classes, get or create the singleton instance if issubclass(inner_class, ProxiedSingleton): async_instance = inner_class.get_instance() else: async_instance = inner_class() # Create sync wrapper sync_attr_class = cls.create_sync_class(inner_class) sync_attr = object.__new__(sync_attr_class) # type: ignore sync_attr._async_instance = async_instance setattr(self, attr_name, sync_attr) # Also set on the async instance for consistency setattr(self._async_instance, attr_name, async_instance) except Exception as e: logging.warning( f"Failed to create instance for {attr_name}: {e}" ) # Handle other instance attributes that might not be annotated for name, attr in inspect.getmembers(self._async_instance): if name.startswith("_") or hasattr(self, name): continue # If attribute is an instance of a class, and that class is defined in the original class # we need to check if it needs a sync wrapper if isinstance(attr, object) and not isinstance( attr, (str, int, float, bool, list, dict, tuple) ): from comfy_api.internal.singleton import ProxiedSingleton if isinstance(attr, ProxiedSingleton): # Create a sync version of this nested class try: sync_attr_class = cls.create_sync_class(attr.__class__) # Create instance of the sync wrapper with the async instance sync_attr = object.__new__(sync_attr_class) # type: ignore sync_attr._async_instance = attr setattr(self, name, sync_attr) except Exception: # If we can't create a sync version, keep the original setattr(self, name, attr) sync_class_dict["__init__"] = __init__ # Process methods from the async class for name, method in inspect.getmembers( async_class, predicate=inspect.isfunction ): if name.startswith("_"): continue # Extract the actual return type from a coroutine if inspect.iscoroutinefunction(method): # Create sync version of async method with proper signature @functools.wraps(method) def sync_method(self, *args, _method_name=name, **kwargs): async_method = getattr(self._async_instance, _method_name) return AsyncToSyncConverter.run_async_in_thread( async_method, *args, **kwargs ) # Add to the class dict sync_class_dict[name] = sync_method else: # For regular methods, create a proxy method @functools.wraps(method) def proxy_method(self, *args, _method_name=name, **kwargs): method = getattr(self._async_instance, _method_name) return method(*args, **kwargs) # Add to the class dict sync_class_dict[name] = proxy_method # Handle property access for name, prop in inspect.getmembers( async_class, lambda x: isinstance(x, property) ): def make_property(name, prop_obj): def getter(self): value = getattr(self._async_instance, name) if inspect.iscoroutinefunction(value): def sync_fn(*args, **kwargs): return AsyncToSyncConverter.run_async_in_thread( value, *args, **kwargs ) return sync_fn return value def setter(self, value): setattr(self._async_instance, name, value) return property(getter, setter if prop_obj.fset else None) sync_class_dict[name] = make_property(name, prop) # Create the class sync_class = type(sync_class_name, (object,), sync_class_dict) return sync_class @classmethod def _format_type_annotation( cls, annotation, type_tracker: Optional[TypeTracker] = None ) -> str: """Convert a type annotation to its string representation for stub files.""" if ( annotation is inspect.Parameter.empty or annotation is inspect.Signature.empty ): return "Any" # Handle None type if annotation is type(None): return "None" # Track the type if we have a tracker if type_tracker: type_tracker.track_type(annotation) # Try using typing.get_origin/get_args for Python 3.8+ try: origin = get_origin(annotation) args = get_args(annotation) if origin is not None: # Track the origin type if type_tracker: type_tracker.track_type(origin) # Get the origin name origin_name = getattr(origin, "__name__", str(origin)) if "." in origin_name: origin_name = origin_name.split(".")[-1] # Special handling for types.UnionType (Python 3.10+ pipe operator) if origin_name == "UnionType": origin_name = "Union" # Format arguments recursively if args: formatted_args = [ cls._format_type_annotation(arg, type_tracker) for arg in args ] return f"{origin_name}[{', '.join(formatted_args)}]" else: return origin_name except (AttributeError, TypeError): # Fallback for older Python versions or non-generic types pass # Handle generic types the old way for compatibility if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"): origin = annotation.__origin__ origin_name = ( origin.__name__ if hasattr(origin, "__name__") else str(origin).split("'")[1] ) # Format each type argument args = [] for arg in annotation.__args__: args.append(cls._format_type_annotation(arg, type_tracker)) return f"{origin_name}[{', '.join(args)}]" # Handle regular types with __name__ if hasattr(annotation, "__name__"): return annotation.__name__ # Handle special module types (like types from typing module) if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"): # For types like typing.Literal, typing.TypedDict, etc. return annotation.__qualname__ # Last resort: string conversion with cleanup type_str = str(annotation) # Clean up common patterns more robustly if type_str.startswith(""): type_str = type_str[8:-2] # Remove "" # Remove module prefixes for common modules for prefix in ["typing.", "builtins.", "types."]: if type_str.startswith(prefix): type_str = type_str[len(prefix) :] # Handle special cases if type_str in ("_empty", "inspect._empty"): return "None" # Fix NoneType (this should rarely be needed now) if type_str == "NoneType": return "None" return type_str @classmethod def _extract_coroutine_return_type(cls, annotation): """Extract the actual return type from a Coroutine annotation.""" if hasattr(annotation, "__args__") and len(annotation.__args__) > 2: # Coroutine[Any, Any, ReturnType] -> extract ReturnType return annotation.__args__[2] return annotation @classmethod def _format_parameter_default(cls, default_value) -> str: """Format a parameter's default value for stub files.""" if default_value is inspect.Parameter.empty: return "" elif default_value is None: return " = None" elif isinstance(default_value, bool): return f" = {default_value}" elif default_value == {}: return " = {}" elif default_value == []: return " = []" else: return f" = {default_value}" @classmethod def _format_method_parameters( cls, sig: inspect.Signature, skip_self: bool = True, type_tracker: Optional[TypeTracker] = None, ) -> str: """Format method parameters for stub files.""" params = [] for i, (param_name, param) in enumerate(sig.parameters.items()): if i == 0 and param_name == "self" and skip_self: params.append("self") else: # Get type annotation type_str = cls._format_type_annotation(param.annotation, type_tracker) # Get default value default_str = cls._format_parameter_default(param.default) # Combine parameter parts if param.annotation is inspect.Parameter.empty: params.append(f"{param_name}: Any{default_str}") else: params.append(f"{param_name}: {type_str}{default_str}") return ", ".join(params) @classmethod def _generate_method_signature( cls, method_name: str, method, is_async: bool = False, type_tracker: Optional[TypeTracker] = None, ) -> str: """Generate a complete method signature for stub files.""" sig = inspect.signature(method) # For async methods, extract the actual return type return_annotation = sig.return_annotation if is_async and inspect.iscoroutinefunction(method): return_annotation = cls._extract_coroutine_return_type(return_annotation) # Format parameters params_str = cls._format_method_parameters(sig, type_tracker=type_tracker) # Format return type return_type = cls._format_type_annotation(return_annotation, type_tracker) if return_annotation is inspect.Signature.empty: return_type = "None" return f"def {method_name}({params_str}) -> {return_type}: ..." @classmethod def _generate_imports( cls, async_class: Type, type_tracker: TypeTracker ) -> list[str]: """Generate import statements for the stub file.""" imports = [] # Add standard typing imports imports.append( "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple" ) # Add imports from the original module if async_class.__module__ != "builtins": module = inspect.getmodule(async_class) additional_types = [] if module: for name, obj in sorted(inspect.getmembers(module)): if isinstance(obj, type): # Check for NamedTuple if issubclass(obj, tuple) and hasattr(obj, "_fields"): additional_types.append(name) # Mark as already imported type_tracker.already_imported.add(name) # Check for Enum elif issubclass(obj, Enum) and name != "Enum": additional_types.append(name) # Mark as already imported type_tracker.already_imported.add(name) if additional_types: type_imports = ", ".join([async_class.__name__] + additional_types) imports.append(f"from {async_class.__module__} import {type_imports}") else: imports.append( f"from {async_class.__module__} import {async_class.__name__}" ) # Add imports for all discovered types # Pass the main module name to avoid duplicate imports imports.extend( type_tracker.get_imports(main_module_name=async_class.__module__) ) # Add base module import if needed if hasattr(inspect.getmodule(async_class), "__name__"): module_name = inspect.getmodule(async_class).__name__ if "." in module_name: base_module = module_name.split(".")[0] # Only add if not already importing from it if not any(imp.startswith(f"from {base_module}") for imp in imports): imports.append(f"import {base_module}") return imports @classmethod def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]: """Extract class attributes that are classes themselves.""" class_attributes = [] # Look for class attributes that are classes for name, attr in sorted(inspect.getmembers(async_class)): if isinstance(attr, type) and not name.startswith("_"): class_attributes.append((name, attr)) elif ( hasattr(async_class, "__annotations__") and name in async_class.__annotations__ ): annotation = async_class.__annotations__[name] if isinstance(annotation, type): class_attributes.append((name, annotation)) return class_attributes @classmethod def _generate_inner_class_stub( cls, name: str, attr: Type, indent: str = " ", type_tracker: Optional[TypeTracker] = None, ) -> list[str]: """Generate stub for an inner class.""" stub_lines = [] stub_lines.append(f"{indent}class {name}Sync:") # Add docstring if available if hasattr(attr, "__doc__") and attr.__doc__: stub_lines.extend( cls._format_docstring_for_stub(attr.__doc__, f"{indent} ") ) # Add __init__ if it exists if hasattr(attr, "__init__"): try: init_method = getattr(attr, "__init__") init_sig = inspect.signature(init_method) # Format parameters params_str = cls._format_method_parameters( init_sig, type_tracker=type_tracker ) # Add __init__ docstring if available (before the method) if hasattr(init_method, "__doc__") and init_method.__doc__: stub_lines.extend( cls._format_docstring_for_stub( init_method.__doc__, f"{indent} " ) ) stub_lines.append( f"{indent} def __init__({params_str}) -> None: ..." ) except (ValueError, TypeError): stub_lines.append( f"{indent} def __init__(self, *args, **kwargs) -> None: ..." ) # Add methods to the inner class has_methods = False for method_name, method in sorted( inspect.getmembers(attr, predicate=inspect.isfunction) ): if method_name.startswith("_"): continue has_methods = True try: # Add method docstring if available (before the method signature) if method.__doc__: stub_lines.extend( cls._format_docstring_for_stub(method.__doc__, f"{indent} ") ) method_sig = cls._generate_method_signature( method_name, method, is_async=True, type_tracker=type_tracker ) stub_lines.append(f"{indent} {method_sig}") except (ValueError, TypeError): stub_lines.append( f"{indent} def {method_name}(self, *args, **kwargs): ..." ) if not has_methods: stub_lines.append(f"{indent} pass") return stub_lines @classmethod def _format_docstring_for_stub( cls, docstring: str, indent: str = " " ) -> list[str]: """Format a docstring for inclusion in a stub file with proper indentation.""" if not docstring: return [] # First, dedent the docstring to remove any existing indentation dedented = textwrap.dedent(docstring).strip() # Split into lines lines = dedented.split("\n") # Build the properly indented docstring result = [] result.append(f'{indent}"""') for line in lines: if line.strip(): # Non-empty line result.append(f"{indent}{line}") else: # Empty line result.append("") result.append(f'{indent}"""') return result @classmethod def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]: """Post-process stub content to fix any remaining issues.""" processed = [] for line in stub_content: # Skip processing imports if line.startswith(("from ", "import ")): processed.append(line) continue # Fix method signatures missing return types if ( line.strip().startswith("def ") and line.strip().endswith(": ...") and ") -> " not in line ): # Add -> None for methods without return annotation line = line.replace(": ...", " -> None: ...") processed.append(line) return processed @classmethod def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None: """ Generate a .pyi stub file for the sync class to help IDEs with type checking. """ try: # Only generate stub if we can determine module path if async_class.__module__ == "__main__": return module = inspect.getmodule(async_class) if not module: return module_path = module.__file__ if not module_path: return # Create stub file path in a 'generated' subdirectory module_dir = os.path.dirname(module_path) stub_dir = os.path.join(module_dir, "generated") # Ensure the generated directory exists os.makedirs(stub_dir, exist_ok=True) module_name = os.path.basename(module_path) if module_name.endswith(".py"): module_name = module_name[:-3] sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi") # Create a type tracker for this stub generation type_tracker = TypeTracker() stub_content = [] # We'll generate imports after processing all methods to capture all types # Leave a placeholder for imports imports_placeholder_index = len(stub_content) stub_content.append("") # Will be replaced with imports later # Class definition stub_content.append(f"class {sync_class.__name__}:") # Docstring if async_class.__doc__: stub_content.extend( cls._format_docstring_for_stub(async_class.__doc__, " ") ) # Generate __init__ try: init_method = async_class.__init__ init_signature = inspect.signature(init_method) # Format parameters params_str = cls._format_method_parameters( init_signature, type_tracker=type_tracker ) # Add __init__ docstring if available (before the method) if hasattr(init_method, "__doc__") and init_method.__doc__: stub_content.extend( cls._format_docstring_for_stub(init_method.__doc__, " ") ) stub_content.append(f" def __init__({params_str}) -> None: ...") except (ValueError, TypeError): stub_content.append( " def __init__(self, *args, **kwargs) -> None: ..." ) stub_content.append("") # Add newline after __init__ # Get class attributes class_attributes = cls._get_class_attributes(async_class) # Generate inner classes for name, attr in class_attributes: inner_class_stub = cls._generate_inner_class_stub( name, attr, type_tracker=type_tracker ) stub_content.extend(inner_class_stub) stub_content.append("") # Add newline after the inner class # Add methods to the main class processed_methods = set() # Keep track of methods we've processed for name, method in sorted( inspect.getmembers(async_class, predicate=inspect.isfunction) ): if name.startswith("_") or name in processed_methods: continue processed_methods.add(name) try: method_sig = cls._generate_method_signature( name, method, is_async=True, type_tracker=type_tracker ) # Add docstring if available (before the method signature for proper formatting) if method.__doc__: stub_content.extend( cls._format_docstring_for_stub(method.__doc__, " ") ) stub_content.append(f" {method_sig}") stub_content.append("") # Add newline after each method except (ValueError, TypeError): # If we can't get the signature, just add a simple stub stub_content.append(f" def {name}(self, *args, **kwargs): ...") stub_content.append("") # Add newline # Add properties for name, prop in sorted( inspect.getmembers(async_class, lambda x: isinstance(x, property)) ): stub_content.append(" @property") stub_content.append(f" def {name}(self) -> Any: ...") if prop.fset: stub_content.append(f" @{name}.setter") stub_content.append( f" def {name}(self, value: Any) -> None: ..." ) stub_content.append("") # Add newline after each property # Add placeholders for the nested class instances # Check the actual attribute names from class annotations and attributes attribute_mappings = {} # First check annotations for typed attributes (including from parent classes) # Collect all annotations from the class hierarchy all_annotations = {} for base_class in reversed(inspect.getmro(async_class)): if hasattr(base_class, "__annotations__"): all_annotations.update(base_class.__annotations__) for attr_name, attr_type in sorted(all_annotations.items()): for class_name, class_type in class_attributes: # If the class type matches the annotated type if attr_type == class_type or ( hasattr(attr_type, "__name__") and attr_type.__name__ == class_name ): attribute_mappings[class_name] = attr_name # Remove the extra checking - annotations should be sufficient # Add the attribute declarations with proper names for class_name, _ in class_attributes: # Use the attribute name if found in mappings, otherwise use class name attr_name = attribute_mappings.get(class_name, class_name) stub_content.append(f" {attr_name}: {class_name}Sync") stub_content.append("") # Add a final newline # Now generate imports with all discovered types imports = cls._generate_imports(async_class, type_tracker) # Deduplicate imports while preserving order seen = set() unique_imports = [] for imp in imports: if imp not in seen: seen.add(imp) unique_imports.append(imp) else: logging.warning(f"Duplicate import detected: {imp}") # Replace the placeholder with actual imports stub_content[imports_placeholder_index : imports_placeholder_index + 1] = ( unique_imports ) # Post-process stub content stub_content = cls._post_process_stub_content(stub_content) # Write stub file with open(sync_stub_path, "w") as f: f.write("\n".join(stub_content)) logging.info(f"Generated stub file: {sync_stub_path}") except Exception as e: # If stub generation fails, log the error but don't break the main functionality logging.error( f"Error generating stub file for {sync_class.__name__}: {str(e)}" ) import traceback logging.error(traceback.format_exc()) def create_sync_class(async_class: Type, thread_pool_size=10) -> Type: """ Creates a sync version of an async class Args: async_class: The async class to convert thread_pool_size: Size of thread pool to use Returns: A new class with sync versions of all async methods """ return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)