From fa9688b1fbc1a6f438281ffe84a68bef15947629 Mon Sep 17 00:00:00 2001 From: bymyself Date: Tue, 20 May 2025 12:15:46 -0700 Subject: [PATCH] [docs] Add OpenAPI specification and test framework --- .gitignore | 1 - openapi.yaml | 904 +++++++++++++++++++++++++++ tests-api/README.md | 74 +++ tests-api/conftest.py | 141 +++++ tests-api/requirements.txt | 6 + tests-api/test_api_by_tag.py | 279 +++++++++ tests-api/test_endpoint_existence.py | 240 +++++++ tests-api/test_schema_validation.py | 440 +++++++++++++ tests-api/test_spec_validation.py | 144 +++++ tests-api/utils/schema_utils.py | 159 +++++ tests-api/utils/validation.py | 180 ++++++ 11 files changed, 2567 insertions(+), 1 deletion(-) create mode 100644 openapi.yaml create mode 100644 tests-api/README.md create mode 100644 tests-api/conftest.py create mode 100644 tests-api/requirements.txt create mode 100644 tests-api/test_api_by_tag.py create mode 100644 tests-api/test_endpoint_existence.py create mode 100644 tests-api/test_schema_validation.py create mode 100644 tests-api/test_spec_validation.py create mode 100644 tests-api/utils/schema_utils.py create mode 100644 tests-api/utils/validation.py diff --git a/.gitignore b/.gitignore index 4e8cea71e..26db138c9 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,5 @@ venv/ *.log web_custom_versions/ .DS_Store -openapi.yaml filtered-openapi.yaml uv.lock diff --git a/openapi.yaml b/openapi.yaml new file mode 100644 index 000000000..82a95bc8c --- /dev/null +++ b/openapi.yaml @@ -0,0 +1,904 @@ +openapi: 3.0.3 +info: + title: ComfyUI API + description: | + API for ComfyUI - A powerful and modular UI for Stable Diffusion. + + This API allows you to interact with ComfyUI programmatically, including: + - Submitting workflows for execution + - Managing the execution queue + - Retrieving generated images + - Managing models + - Retrieving node information + version: 1.0.0 + license: + name: GNU General Public License v3.0 + url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE + +servers: + - url: / + description: Default ComfyUI server + +tags: + - name: workflow + description: Workflow execution and management + - name: queue + description: Queue management + - name: image + description: Image handling + - name: node + description: Node information + - name: model + description: Model management + - name: system + description: System information + - name: internal + description: Internal API routes + +paths: + /prompt: + get: + tags: + - workflow + summary: Get information about current prompt execution + description: Returns information about the current prompt in the execution queue + operationId: getPromptInfo + responses: + '200': + description: Success + content: + application/json: + schema: + $ref: '#/components/schemas/PromptInfo' + post: + tags: + - workflow + summary: Submit a workflow for execution + description: | + Submit a workflow to be executed by the backend. + The workflow is a JSON object describing the nodes and their connections. 
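+
+        A minimal submission from Python might look like the sketch below.
+        The node IDs, class types, and wiring are illustrative placeholders
+        (a real graph is exported from the ComfyUI editor), but the request
+        envelope matches the PromptRequest schema:
+
+        ```python
+        import requests
+
+        # Keys are node IDs; an input value of the form [node_id, slot]
+        # references another node's output (the wiring here is illustrative).
+        prompt = {
+            "1": {"class_type": "CheckpointLoaderSimple",
+                  "inputs": {"ckpt_name": "model.safetensors"}},
+            "2": {"class_type": "SaveImage",
+                  "inputs": {"images": ["1", 0]}},
+        }
+        resp = requests.post("http://127.0.0.1:8188/prompt",
+                             json={"prompt": prompt, "client_id": "example-client"})
+        print(resp.json()["prompt_id"])
+        ```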
+ operationId: executePrompt + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PromptRequest' + responses: + '200': + description: Success - Prompt accepted + content: + application/json: + schema: + $ref: '#/components/schemas/PromptResponse' + '400': + description: Invalid prompt + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /queue: + get: + tags: + - queue + summary: Get queue information + description: Returns information about running and pending items in the queue + operationId: getQueueInfo + responses: + '200': + description: Success + content: + application/json: + schema: + $ref: '#/components/schemas/QueueInfo' + post: + tags: + - queue + summary: Manage queue + description: Clear the queue or delete specific items + operationId: manageQueue + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + clear: + type: boolean + description: If true, clears the entire queue + delete: + type: array + description: Array of prompt IDs to delete from the queue + items: + type: string + format: uuid + responses: + '200': + description: Success + + /interrupt: + post: + tags: + - workflow + summary: Interrupt the current execution + description: Interrupts the currently running workflow execution + operationId: interruptExecution + responses: + '200': + description: Success + + /free: + post: + tags: + - system + summary: Free resources + description: Unload models and/or free memory + operationId: freeResources + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + unload_models: + type: boolean + description: If true, unloads models from memory + free_memory: + type: boolean + description: If true, frees GPU memory + responses: + '200': + description: Success + + /history: + get: + tags: + - workflow + summary: Get execution history + description: Returns the history of executed workflows + operationId: getHistory + parameters: + - name: max_items + in: query + description: Maximum number of history items to return + required: false + schema: + type: integer + format: int32 + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/HistoryItem' + post: + tags: + - workflow + summary: Manage history + description: Clear history or delete specific items + operationId: manageHistory + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + clear: + type: boolean + description: If true, clears the entire history + delete: + type: array + description: Array of prompt IDs to delete from history + items: + type: string + format: uuid + responses: + '200': + description: Success + + /history/{prompt_id}: + get: + tags: + - workflow + summary: Get specific history item + description: Returns a specific history item by ID + operationId: getHistoryItem + parameters: + - name: prompt_id + in: path + description: ID of the prompt to retrieve + required: true + schema: + type: string + format: uuid + responses: + '200': + description: Success + content: + application/json: + schema: + $ref: '#/components/schemas/HistoryItem' + + /object_info: + get: + tags: + - node + summary: Get all node information + description: Returns information about all available nodes + operationId: getNodeInfo + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + 
additionalProperties: + $ref: '#/components/schemas/NodeInfo' + + /object_info/{node_class}: + get: + tags: + - node + summary: Get specific node information + description: Returns information about a specific node class + operationId: getNodeClassInfo + parameters: + - name: node_class + in: path + description: Name of the node class + required: true + schema: + type: string + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + additionalProperties: + $ref: '#/components/schemas/NodeInfo' + + /upload/image: + post: + tags: + - image + summary: Upload an image + description: Uploads an image to the server + operationId: uploadImage + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + properties: + image: + type: string + format: binary + description: The image file to upload + overwrite: + type: string + description: Whether to overwrite if file exists (true/false) + type: + type: string + enum: [input, temp, output] + description: Type of directory to store the image in + subfolder: + type: string + description: Subfolder to store the image in + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Filename of the uploaded image + subfolder: + type: string + description: Subfolder the image was stored in + type: + type: string + description: Type of directory the image was stored in + '400': + description: Bad request + + /upload/mask: + post: + tags: + - image + summary: Upload a mask for an image + description: Uploads a mask image and applies it to a referenced original image + operationId: uploadMask + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + properties: + image: + type: string + format: binary + description: The mask image file to upload + original_ref: + type: string + description: JSON string containing reference to the original image + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Filename of the uploaded mask + subfolder: + type: string + description: Subfolder the mask was stored in + type: + type: string + description: Type of directory the mask was stored in + '400': + description: Bad request + + /view: + get: + tags: + - image + summary: View an image + description: Retrieves an image from the server + operationId: viewImage + parameters: + - name: filename + in: query + description: Name of the file to retrieve + required: true + schema: + type: string + - name: type + in: query + description: Type of directory to retrieve from + required: false + schema: + type: string + enum: [input, temp, output] + default: output + - name: subfolder + in: query + description: Subfolder to retrieve from + required: false + schema: + type: string + - name: preview + in: query + description: Preview options (format;quality) + required: false + schema: + type: string + - name: channel + in: query + description: Channel to retrieve (rgb, a, rgba) + required: false + schema: + type: string + enum: [rgb, a, rgba] + default: rgba + responses: + '200': + description: Success + content: + image/*: + schema: + type: string + format: binary + '400': + description: Bad request + '404': + description: File not found + + /view_metadata/{folder_name}: + get: + tags: + - model + summary: View model metadata + description: Retrieves metadata from a safetensors file + 
operationId: viewModelMetadata + parameters: + - name: folder_name + in: path + description: Name of the model folder + required: true + schema: + type: string + - name: filename + in: query + description: Name of the safetensors file + required: true + schema: + type: string + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + '404': + description: File not found + + /models: + get: + tags: + - model + summary: Get model types + description: Returns a list of available model types + operationId: getModelTypes + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + type: string + + /models/{folder}: + get: + tags: + - model + summary: Get models of a specific type + description: Returns a list of available models of a specific type + operationId: getModels + parameters: + - name: folder + in: path + description: Model type folder + required: true + schema: + type: string + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + type: string + '404': + description: Folder not found + + /embeddings: + get: + tags: + - model + summary: Get embeddings + description: Returns a list of available embeddings + operationId: getEmbeddings + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + type: string + + /extensions: + get: + tags: + - system + summary: Get extensions + description: Returns a list of available extensions + operationId: getExtensions + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + type: string + + /system_stats: + get: + tags: + - system + summary: Get system statistics + description: Returns system information including RAM, VRAM, and ComfyUI version + operationId: getSystemStats + responses: + '200': + description: Success + content: + application/json: + schema: + $ref: '#/components/schemas/SystemStats' + + /ws: + get: + tags: + - workflow + summary: WebSocket connection + description: | + Establishes a WebSocket connection for real-time communication. + This endpoint is used for receiving progress updates, status changes, and results from workflow executions. 
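+
+        A minimal listener sketch using the third-party websocket-client
+        package is shown below. The "status" message layout mirrors the
+        PromptInfo schema defined under components; other message types
+        (e.g. progress updates) also exist and are not modeled in this spec:
+
+        ```python
+        import json
+        import websocket  # pip install websocket-client
+
+        ws = websocket.WebSocket()
+        ws.connect("ws://127.0.0.1:8188/ws?clientId=example-client")
+        while True:
+            frame = ws.recv()
+            if isinstance(frame, bytes):
+                continue  # binary frames carry preview image data
+            msg = json.loads(frame)
+            if msg.get("type") == "status":
+                remaining = msg["data"]["status"]["exec_info"]["queue_remaining"]
+                print(f"queue remaining: {remaining}")
+        ```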
+ operationId: webSocketConnect + parameters: + - name: clientId + in: query + description: Optional client ID for reconnection + required: false + schema: + type: string + responses: + '101': + description: Switching Protocols to WebSocket + + /internal/logs: + get: + tags: + - internal + summary: Get logs + description: Returns system logs as a single string + operationId: getLogs + responses: + '200': + description: Success + content: + application/json: + schema: + type: string + + /internal/logs/raw: + get: + tags: + - internal + summary: Get raw logs + description: Returns raw system logs with terminal size information + operationId: getRawLogs + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + properties: + entries: + type: array + items: + type: object + properties: + t: + type: string + description: Timestamp + m: + type: string + description: Message + size: + type: object + properties: + cols: + type: integer + description: Terminal columns + rows: + type: integer + description: Terminal rows + + /internal/logs/subscribe: + patch: + tags: + - internal + summary: Subscribe to logs + description: Subscribe or unsubscribe to log updates + operationId: subscribeToLogs + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + clientId: + type: string + description: Client ID + enabled: + type: boolean + description: Whether to enable or disable subscription + responses: + '200': + description: Success + + /internal/folder_paths: + get: + tags: + - internal + summary: Get folder paths + description: Returns a map of folder names to their paths + operationId: getFolderPaths + responses: + '200': + description: Success + content: + application/json: + schema: + type: object + additionalProperties: + type: string + + /internal/files/{directory_type}: + get: + tags: + - internal + summary: Get files + description: Returns a list of files in a specific directory type + operationId: getFiles + parameters: + - name: directory_type + in: path + description: Type of directory (output, input, temp) + required: true + schema: + type: string + enum: [output, input, temp] + responses: + '200': + description: Success + content: + application/json: + schema: + type: array + items: + type: string + '400': + description: Invalid directory type + +components: + schemas: + PromptRequest: + type: object + required: + - prompt + properties: + prompt: + type: object + description: The workflow graph to execute + additionalProperties: true + number: + type: number + description: Priority number for the queue (lower numbers have higher priority) + front: + type: boolean + description: If true, adds the prompt to the front of the queue + extra_data: + type: object + description: Extra data to be associated with the prompt + additionalProperties: true + client_id: + type: string + description: Client ID for attribution of the prompt + + PromptResponse: + type: object + properties: + prompt_id: + type: string + format: uuid + description: Unique identifier for the prompt execution + number: + type: number + description: Priority number in the queue + node_errors: + type: object + description: Any errors in the nodes of the prompt + additionalProperties: true + + ErrorResponse: + type: object + properties: + error: + type: object + properties: + type: + type: string + description: Error type + message: + type: string + description: Error message + details: + type: string + description: Detailed error information + 
extra_info: + type: object + description: Additional error information + additionalProperties: true + node_errors: + type: object + description: Node-specific errors + additionalProperties: true + + PromptInfo: + type: object + properties: + exec_info: + type: object + properties: + queue_remaining: + type: integer + description: Number of items remaining in the queue + + QueueInfo: + type: object + properties: + queue_running: + type: array + items: + type: object + description: Currently running items + additionalProperties: true + queue_pending: + type: array + items: + type: object + description: Pending items in the queue + additionalProperties: true + + HistoryItem: + type: object + properties: + prompt_id: + type: string + format: uuid + description: Unique identifier for the prompt + prompt: + type: object + description: The workflow graph that was executed + additionalProperties: true + extra_data: + type: object + description: Additional data associated with the execution + additionalProperties: true + outputs: + type: object + description: Output data from the execution + additionalProperties: true + + NodeInfo: + type: object + properties: + input: + type: object + description: Input specifications for the node + additionalProperties: true + input_order: + type: object + description: Order of inputs for display + additionalProperties: + type: array + items: + type: string + output: + type: array + items: + type: string + description: Output types of the node + output_is_list: + type: array + items: + type: boolean + description: Whether each output is a list + output_name: + type: array + items: + type: string + description: Names of the outputs + name: + type: string + description: Internal name of the node + display_name: + type: string + description: Display name of the node + description: + type: string + description: Description of the node + python_module: + type: string + description: Python module implementing the node + category: + type: string + description: Category of the node + output_node: + type: boolean + description: Whether this is an output node + output_tooltips: + type: array + items: + type: string + description: Tooltips for outputs + deprecated: + type: boolean + description: Whether the node is deprecated + experimental: + type: boolean + description: Whether the node is experimental + api_node: + type: boolean + description: Whether this is an API node + + SystemStats: + type: object + properties: + system: + type: object + properties: + os: + type: string + description: Operating system + ram_total: + type: number + description: Total system RAM in bytes + ram_free: + type: number + description: Free system RAM in bytes + comfyui_version: + type: string + description: ComfyUI version + python_version: + type: string + description: Python version + pytorch_version: + type: string + description: PyTorch version + embedded_python: + type: boolean + description: Whether using embedded Python + argv: + type: array + items: + type: string + description: Command line arguments + devices: + type: array + items: + type: object + properties: + name: + type: string + description: Device name + type: + type: string + description: Device type + index: + type: integer + description: Device index + vram_total: + type: number + description: Total VRAM in bytes + vram_free: + type: number + description: Free VRAM in bytes + torch_vram_total: + type: number + description: Total VRAM as reported by PyTorch + torch_vram_free: + type: number + description: Free VRAM as 
reported by PyTorch \ No newline at end of file diff --git a/tests-api/README.md b/tests-api/README.md new file mode 100644 index 000000000..259211b8f --- /dev/null +++ b/tests-api/README.md @@ -0,0 +1,74 @@ +# ComfyUI API Testing + +This directory contains tests for validating the ComfyUI OpenAPI specification against a running instance of ComfyUI. + +## Setup + +1. Install the required dependencies: + +```bash +pip install -r requirements.txt +``` + +2. Make sure you have a running instance of ComfyUI (default: http://127.0.0.1:8188) + +## Running the Tests + +Run all tests with pytest: + +```bash +cd tests-api +pytest +``` + +Run specific test files: + +```bash +pytest test_spec_validation.py +pytest test_endpoint_existence.py +pytest test_schema_validation.py +pytest test_api_by_tag.py +``` + +Run tests with more verbose output: + +```bash +pytest -v +``` + +## Test Categories + +The tests are organized into several categories: + +1. **Spec Validation**: Validates that the OpenAPI specification is valid. +2. **Endpoint Existence**: Tests that the endpoints defined in the spec exist on the server. +3. **Schema Validation**: Tests that the server responses match the schemas defined in the spec. +4. **Tag-Based Tests**: Tests that the API's tag organization is consistent. + +## Using a Different Server + +By default, the tests connect to `http://127.0.0.1:8188`. To test against a different server, set the `COMFYUI_SERVER_URL` environment variable: + +```bash +COMFYUI_SERVER_URL=http://example.com:8188 pytest +``` + +## Test Structure + +- `conftest.py`: Contains pytest fixtures used by the tests. +- `utils/`: Contains utility functions for working with the OpenAPI spec. +- `test_*.py`: The actual test files. +- `resources/`: Contains resources used by the tests (e.g., sample workflows). + +## Extending the Tests + +To add new tests: + +1. For testing new endpoints, add them to the appropriate test file based on their category. +2. For testing more complex functionality, create a new test file following the established patterns. + +## Notes + +- Tests that require a running server will be skipped if the server is not available. +- Some tests may fail if the server doesn't match the specification exactly. +- The tests don't modify any data on the server (they're read-only). 
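+
+## Example: Manually Validating a Response
+
+The test suite performs this check automatically, but the same check can be
+run by hand. The sketch below validates `/system_stats` against the
+`SystemStats` schema; it assumes a server at the default URL and that it is
+run from inside `tests-api/`. Since `SystemStats` contains no `$ref`s, the
+schema can be passed to `jsonschema` directly:
+
+```python
+import jsonschema
+import requests
+import yaml
+
+with open("../openapi.yaml") as f:
+    spec = yaml.safe_load(f)
+
+schema = spec["components"]["schemas"]["SystemStats"]
+stats = requests.get("http://127.0.0.1:8188/system_stats", timeout=5).json()
+jsonschema.validate(instance=stats, schema=schema)  # raises ValidationError on mismatch
+print("Response matches the SystemStats schema")
+```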
\ No newline at end of file diff --git a/tests-api/conftest.py b/tests-api/conftest.py new file mode 100644 index 000000000..fa64bb535 --- /dev/null +++ b/tests-api/conftest.py @@ -0,0 +1,141 @@ +""" +Test fixtures for API testing +""" +import os +import pytest +import yaml +import requests +import logging +from typing import Dict, Any, Generator, Optional +from urllib.parse import urljoin + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default server configuration +DEFAULT_SERVER_URL = "http://127.0.0.1:8188" + + +@pytest.fixture(scope="session") +def api_spec_path() -> str: + """ + Get the path to the OpenAPI specification file + + Returns: + Path to the OpenAPI specification file + """ + return os.path.abspath(os.path.join( + os.path.dirname(__file__), + "..", + "openapi.yaml" + )) + + +@pytest.fixture(scope="session") +def api_spec(api_spec_path: str) -> Dict[str, Any]: + """ + Load the OpenAPI specification + + Args: + api_spec_path: Path to the spec file + + Returns: + Parsed OpenAPI specification + """ + with open(api_spec_path, 'r') as f: + return yaml.safe_load(f) + + +@pytest.fixture(scope="session") +def base_url() -> str: + """ + Get the base URL for the API server + + Returns: + Base URL string + """ + # Allow overriding via environment variable + return os.environ.get("COMFYUI_SERVER_URL", DEFAULT_SERVER_URL) + + +@pytest.fixture(scope="session") +def server_available(base_url: str) -> bool: + """ + Check if the server is available + + Args: + base_url: Base URL for the API + + Returns: + True if the server is available, False otherwise + """ + try: + response = requests.get(base_url, timeout=2) + return response.status_code == 200 + except requests.RequestException: + logger.warning(f"Server at {base_url} is not available") + return False + + +@pytest.fixture +def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]: + """ + Create a requests session for API testing + + Args: + base_url: Base URL for the API + + Yields: + Requests session configured for the API + """ + session = requests.Session() + + # Helper function to construct URLs + def get_url(path: str) -> str: + return urljoin(base_url, path) + + # Add url helper to the session + session.get_url = get_url # type: ignore + + yield session + + # Cleanup + session.close() + + +@pytest.fixture +def api_get_json(api_client: requests.Session): + """ + Helper fixture for making GET requests and parsing JSON responses + + Args: + api_client: API client session + + Returns: + Function that makes GET requests and returns JSON + """ + def _get_json(path: str, **kwargs): + url = api_client.get_url(path) # type: ignore + response = api_client.get(url, **kwargs) + + if response.status_code == 200: + try: + return response.json() + except ValueError: + return None + return None + + return _get_json + + +@pytest.fixture +def require_server(server_available): + """ + Skip tests if server is not available + + Args: + server_available: Whether the server is available + """ + if not server_available: + pytest.skip("Server is not available") \ No newline at end of file diff --git a/tests-api/requirements.txt b/tests-api/requirements.txt new file mode 100644 index 000000000..6f311f2f5 --- /dev/null +++ b/tests-api/requirements.txt @@ -0,0 +1,6 @@ +pytest>=7.0.0 +pytest-asyncio>=0.21.0 +openapi-spec-validator>=0.5.0 +jsonschema>=4.17.0 +requests>=2.28.0 +pyyaml>=6.0.0 \ No newline at end of file diff --git a/tests-api/test_api_by_tag.py 
b/tests-api/test_api_by_tag.py new file mode 100644 index 000000000..cc22fc387 --- /dev/null +++ b/tests-api/test_api_by_tag.py @@ -0,0 +1,279 @@ +""" +Tests for API endpoints grouped by tags +""" +import pytest +import logging +import sys +import os +from typing import Dict, Any, List, Set + +# Use a direct import with the full path +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) + +# Define functions inline to avoid import issues +def get_all_endpoints(spec): + """ + Extract all endpoints from an OpenAPI spec + """ + endpoints = [] + + for path, path_item in spec['paths'].items(): + for method, operation in path_item.items(): + if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: + continue + + endpoints.append({ + 'path': path, + 'method': method.lower(), + 'tags': operation.get('tags', []), + 'operation_id': operation.get('operationId', ''), + 'summary': operation.get('summary', '') + }) + + return endpoints + +def get_all_tags(spec): + """ + Get all tags used in the API spec + """ + tags = set() + + for path_item in spec['paths'].values(): + for operation in path_item.values(): + if isinstance(operation, dict) and 'tags' in operation: + tags.update(operation['tags']) + + return tags + +def extract_endpoints_by_tag(spec, tag): + """ + Extract all endpoints with a specific tag + """ + endpoints = [] + + for path, path_item in spec['paths'].items(): + for method, operation in path_item.items(): + if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: + continue + + if tag in operation.get('tags', []): + endpoints.append({ + 'path': path, + 'method': method.lower(), + 'operation_id': operation.get('operationId', ''), + 'summary': operation.get('summary', '') + }) + + return endpoints + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.fixture +def api_tags(api_spec: Dict[str, Any]) -> Set[str]: + """ + Get all tags from the API spec + + Args: + api_spec: Loaded OpenAPI spec + + Returns: + Set of tag names + """ + return get_all_tags(api_spec) + + +def test_api_has_tags(api_tags: Set[str]): + """ + Test that the API has defined tags + + Args: + api_tags: Set of tags + """ + assert len(api_tags) > 0, "API spec should have at least one tag" + + # Log the tags + logger.info(f"API spec has the following tags: {sorted(api_tags)}") + + +@pytest.mark.parametrize("tag", [ + "workflow", + "image", + "model", + "node", + "system" +]) +def test_core_tags_exist(api_tags: Set[str], tag: str): + """ + Test that core tags exist in the API spec + + Args: + api_tags: Set of tags + tag: Tag to check + """ + assert tag in api_tags, f"API spec should have '{tag}' tag" + + +def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'workflow' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ + endpoints = extract_endpoints_by_tag(api_spec, "workflow") + + assert len(endpoints) > 0, "No endpoints found with 'workflow' tag" + + # Check for key workflow endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_image_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'image' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ 
+ endpoints = extract_endpoints_by_tag(api_spec, "image") + + assert len(endpoints) > 0, "No endpoints found with 'image' tag" + + # Check for key image endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint" + assert "/view" in endpoint_paths, "Image tag should include /view endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_model_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'model' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ + endpoints = extract_endpoints_by_tag(api_spec, "model") + + assert len(endpoints) > 0, "No endpoints found with 'model' tag" + + # Check for key model endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/models" in endpoint_paths, "Model tag should include /models endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_node_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'node' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ + endpoints = extract_endpoints_by_tag(api_spec, "node") + + assert len(endpoints) > 0, "No endpoints found with 'node' tag" + + # Check for key node endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_system_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'system' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ + endpoints = extract_endpoints_by_tag(api_spec, "system") + + assert len(endpoints) > 0, "No endpoints found with 'system' tag" + + # Check for key system endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]): + """ + Test that the 'internal' tag has appropriate endpoints + + Args: + api_spec: Loaded OpenAPI spec + """ + endpoints = extract_endpoints_by_tag(api_spec, "internal") + + assert len(endpoints) > 0, "No endpoints found with 'internal' tag" + + # Check for key internal endpoints + endpoint_paths = [e["path"] for e in endpoints] + assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint" + + # Log the endpoints + logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:") + for e in endpoints: + logger.info(f" {e['method'].upper()} {e['path']}") + + +def test_operation_ids_match_tag(api_spec: Dict[str, Any]): + """ + Test that operation IDs follow a consistent pattern with their tag + + Args: + api_spec: Loaded OpenAPI spec + """ + failures = [] + + for path, path_item in api_spec['paths'].items(): + for method, operation in path_item.items(): + if method in ['get', 'post', 'put', 'delete', 'patch']: + if 'operationId' in operation and 
'tags' in operation and operation['tags']: + op_id = operation['operationId'] + primary_tag = operation['tags'][0].lower() + + # Check if operationId starts with primary tag prefix + # This is a common convention, but might need adjusting + if not (op_id.startswith(primary_tag) or + any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])): + failures.append({ + 'path': path, + 'method': method, + 'operationId': op_id, + 'primary_tag': primary_tag + }) + + # Log failures for diagnosis but don't fail the test + # as this is a style/convention check + if failures: + logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:") + for f in failures: + logger.warning(f" {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}") \ No newline at end of file diff --git a/tests-api/test_endpoint_existence.py b/tests-api/test_endpoint_existence.py new file mode 100644 index 000000000..3b5111ab8 --- /dev/null +++ b/tests-api/test_endpoint_existence.py @@ -0,0 +1,240 @@ +""" +Tests for endpoint existence and basic response codes +""" +import pytest +import requests +import logging +import sys +import os +from typing import Dict, Any, List +from urllib.parse import urljoin + +# Use a direct import with the full path +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) + +# Define get_all_endpoints function inline to avoid import issues +def get_all_endpoints(spec): + """ + Extract all endpoints from an OpenAPI spec + + Args: + spec: Parsed OpenAPI specification + + Returns: + List of dicts with path, method, and tags for each endpoint + """ + endpoints = [] + + for path, path_item in spec['paths'].items(): + for method, operation in path_item.items(): + if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: + continue + + endpoints.append({ + 'path': path, + 'method': method.lower(), + 'tags': operation.get('tags', []), + 'operation_id': operation.get('operationId', ''), + 'summary': operation.get('summary', '') + }) + + return endpoints + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.fixture +def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Get all endpoints from the API spec + + Args: + api_spec: Loaded OpenAPI spec + + Returns: + List of endpoint information + """ + return get_all_endpoints(api_spec) + + +def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]): + """ + Test that endpoints are defined in the spec + + Args: + all_endpoints: List of endpoint information + """ + # Simple check that we have endpoints defined + assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec" + + # Log the endpoints for informational purposes + logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec") + for endpoint in all_endpoints: + logger.info(f"{endpoint['method'].upper()} {endpoint['path']} - {endpoint['summary']}") + + +@pytest.mark.parametrize("endpoint_path", [ + "/", # Root path + "/prompt", # Get prompt info + "/queue", # Get queue + "/models", # Get model types + "/object_info", # Get node info + "/system_stats" # Get system stats +]) +def test_basic_get_endpoints(require_server, api_client, endpoint_path: str): + """ + Test that basic GET endpoints exist and respond + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + endpoint_path: Path to test + """ + url = 
api_client.get_url(endpoint_path) # type: ignore + + try: + response = api_client.get(url) + + # We're just checking that the endpoint exists and returns some kind of response + # Not necessarily a 200 status code + assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist" + + logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}") + + except requests.RequestException as e: + pytest.fail(f"Request to {endpoint_path} failed: {str(e)}") + + +def test_websocket_endpoint_exists(require_server, base_url: str): + """ + Test that the WebSocket endpoint exists + + Args: + require_server: Fixture that skips if server is not available + base_url: Base server URL + """ + ws_url = urljoin(base_url, "/ws") + + # For WebSocket, we can't use a normal GET request + # Instead, we make a HEAD request to check if the endpoint exists + try: + response = requests.head(ws_url) + + # WebSocket endpoints often return a 400 Bad Request for HEAD requests + # but a 404 would indicate the endpoint doesn't exist + assert response.status_code != 404, "WebSocket endpoint /ws does not exist" + + logger.info(f"WebSocket endpoint exists with status code {response.status_code}") + + except requests.RequestException as e: + pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}") + + +def test_api_models_folder_endpoint(require_server, api_client): + """ + Test that the /models/{folder} endpoint exists and responds + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + """ + # First get available model types + models_url = api_client.get_url("/models") # type: ignore + + try: + models_response = api_client.get(models_url) + assert models_response.status_code == 200, "Failed to get model types" + + model_types = models_response.json() + + # Skip if no model types available + if not model_types: + pytest.skip("No model types available to test") + + # Test with the first model type + model_type = model_types[0] + models_folder_url = api_client.get_url(f"/models/{model_type}") # type: ignore + + folder_response = api_client.get(models_folder_url) + + # We're just checking that the endpoint exists + assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist" + + logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}") + + except requests.RequestException as e: + pytest.fail(f"Request failed: {str(e)}") + except (ValueError, KeyError, IndexError) as e: + pytest.fail(f"Failed to process response: {str(e)}") + + +def test_api_object_info_node_endpoint(require_server, api_client): + """ + Test that the /object_info/{node_class} endpoint exists and responds + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + """ + # First get available node classes + objects_url = api_client.get_url("/object_info") # type: ignore + + try: + objects_response = api_client.get(objects_url) + assert objects_response.status_code == 200, "Failed to get object info" + + node_classes = objects_response.json() + + # Skip if no node classes available + if not node_classes: + pytest.skip("No node classes available to test") + + # Test with the first node class + node_class = next(iter(node_classes.keys())) + node_url = api_client.get_url(f"/object_info/{node_class}") # type: ignore + + node_response = api_client.get(node_url) + + # We're just checking that the endpoint exists + assert node_response.status_code 
!= 404, f"Endpoint /object_info/{node_class} does not exist" + + logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}") + + except requests.RequestException as e: + pytest.fail(f"Request failed: {str(e)}") + except (ValueError, KeyError, StopIteration) as e: + pytest.fail(f"Failed to process response: {str(e)}") + + +def test_internal_endpoints_exist(require_server, api_client): + """ + Test that internal endpoints exist + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + """ + internal_endpoints = [ + "/internal/logs", + "/internal/logs/raw", + "/internal/folder_paths", + "/internal/files/output" + ] + + for endpoint in internal_endpoints: + url = api_client.get_url(endpoint) # type: ignore + + try: + response = api_client.get(url) + + # We're just checking that the endpoint exists + assert response.status_code != 404, f"Endpoint {endpoint} does not exist" + + logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}") + + except requests.RequestException as e: + logger.warning(f"Request to {endpoint} failed: {str(e)}") + # Don't fail the test as internal endpoints might be restricted \ No newline at end of file diff --git a/tests-api/test_schema_validation.py b/tests-api/test_schema_validation.py new file mode 100644 index 000000000..87a7f27c6 --- /dev/null +++ b/tests-api/test_schema_validation.py @@ -0,0 +1,440 @@ +""" +Tests for validating API responses against OpenAPI schema +""" +import pytest +import requests +import logging +import sys +import os +import json +from typing import Dict, Any, List + +# Use a direct import with the full path +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) + +# Define validation functions inline to avoid import issues +def get_endpoint_schema( + spec, + path, + method, + status_code = '200' +): + """ + Extract response schema for a specific endpoint from OpenAPI spec + """ + method = method.lower() + + # Handle path not found + if path not in spec['paths']: + return None + + # Handle method not found + if method not in spec['paths'][path]: + return None + + # Handle status code not found + responses = spec['paths'][path][method].get('responses', {}) + if status_code not in responses: + return None + + # Handle no content defined + if 'content' not in responses[status_code]: + return None + + # Get schema from first content type + content_types = responses[status_code]['content'] + first_content_type = next(iter(content_types)) + + if 'schema' not in content_types[first_content_type]: + return None + + return content_types[first_content_type]['schema'] + +def resolve_schema_refs(schema, spec): + """ + Resolve $ref references in a schema + """ + if not isinstance(schema, dict): + return schema + + result = {} + + for key, value in schema.items(): + if key == '$ref' and isinstance(value, str) and value.startswith('#/'): + # Handle reference + ref_path = value[2:].split('/') + ref_value = spec + for path_part in ref_path: + ref_value = ref_value.get(path_part, {}) + + # Recursively resolve any refs in the referenced schema + ref_value = resolve_schema_refs(ref_value, spec) + result.update(ref_value) + elif isinstance(value, dict): + # Recursively resolve refs in nested dictionaries + result[key] = resolve_schema_refs(value, spec) + elif isinstance(value, list): + # Recursively resolve refs in list items + result[key] = [ + resolve_schema_refs(item, spec) if isinstance(item, dict) else 
item + for item in value + ] + else: + # Pass through other values + result[key] = value + + return result + +def validate_response( + response_data, + spec, + path, + method, + status_code = '200' +): + """ + Validate a response against the OpenAPI schema + """ + schema = get_endpoint_schema(spec, path, method, status_code) + + if schema is None: + return { + 'valid': False, + 'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"] + } + + # Resolve any $ref in the schema + resolved_schema = resolve_schema_refs(schema, spec) + + try: + import jsonschema + jsonschema.validate(instance=response_data, schema=resolved_schema) + return {'valid': True, 'errors': []} + except jsonschema.exceptions.ValidationError as e: + # Extract more detailed error information + path = ".".join(str(p) for p in e.path) if e.path else "root" + instance = e.instance if not isinstance(e.instance, dict) else "..." + schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown" + + detailed_error = ( + f"Validation error at path: {path}\n" + f"Schema path: {schema_path}\n" + f"Error message: {e.message}\n" + f"Failed instance: {instance}\n" + ) + + return {'valid': False, 'errors': [detailed_error]} + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("endpoint_path,method", [ + ("/system_stats", "get"), + ("/prompt", "get"), + ("/queue", "get"), + ("/models", "get"), + ("/embeddings", "get") +]) +def test_response_schema_validation( + require_server, + api_client, + api_spec: Dict[str, Any], + endpoint_path: str, + method: str +): + """ + Test that API responses match the defined schema + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + api_spec: Loaded OpenAPI spec + endpoint_path: Path to test + method: HTTP method to test + """ + url = api_client.get_url(endpoint_path) # type: ignore + + # Skip if no schema defined + schema = get_endpoint_schema(api_spec, endpoint_path, method) + if not schema: + pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}") + + try: + if method.lower() == "get": + response = api_client.get(url) + else: + pytest.skip(f"Method {method} not implemented for automated testing") + return + + # Skip if response is not 200 + if response.status_code != 200: + pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}") + return + + # Skip if response is not JSON + try: + response_data = response.json() + except ValueError: + pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON") + return + + # Validate the response + validation_result = validate_response( + response_data, + api_spec, + endpoint_path, + method + ) + + if validation_result['valid']: + logger.info(f"Response from {method.upper()} {endpoint_path} matches schema") + else: + for error in validation_result['errors']: + logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}") + + assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema" + + except requests.RequestException as e: + pytest.fail(f"Request to {endpoint_path} failed: {str(e)}") + + +def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]): + """ + Test the system_stats endpoint response in detail + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + api_spec: Loaded OpenAPI spec + """ + url = 
api_client.get_url("/system_stats") # type: ignore + + try: + response = api_client.get(url) + + assert response.status_code == 200, "Failed to get system stats" + + # Parse response + stats = response.json() + + # Validate high-level structure + assert 'system' in stats, "Response missing 'system' field" + assert 'devices' in stats, "Response missing 'devices' field" + + # Validate system fields + system = stats['system'] + assert 'os' in system, "System missing 'os' field" + assert 'ram_total' in system, "System missing 'ram_total' field" + assert 'ram_free' in system, "System missing 'ram_free' field" + assert 'comfyui_version' in system, "System missing 'comfyui_version' field" + + # Validate devices fields + devices = stats['devices'] + assert isinstance(devices, list), "Devices should be a list" + + if devices: + device = devices[0] + assert 'name' in device, "Device missing 'name' field" + assert 'type' in device, "Device missing 'type' field" + assert 'vram_total' in device, "Device missing 'vram_total' field" + assert 'vram_free' in device, "Device missing 'vram_free' field" + + # Perform schema validation + validation_result = validate_response( + stats, + api_spec, + "/system_stats", + "get" + ) + + # Print detailed error if validation fails + if not validation_result['valid']: + for error in validation_result['errors']: + logger.error(f"Validation error for /system_stats: {error}") + + # Print schema details for debugging + schema = get_endpoint_schema(api_spec, "/system_stats", "get") + if schema: + logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") + + # Print sample of the response + logger.error(f"Response:\n{json.dumps(stats, indent=2)}") + + assert validation_result['valid'], "System stats response does not match schema" + + except requests.RequestException as e: + pytest.fail(f"Request to /system_stats failed: {str(e)}") + + +def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]): + """ + Test the models endpoint response + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + api_spec: Loaded OpenAPI spec + """ + url = api_client.get_url("/models") # type: ignore + + try: + response = api_client.get(url) + + assert response.status_code == 200, "Failed to get models" + + # Parse response + models = response.json() + + # Validate it's a list + assert isinstance(models, list), "Models response should be a list" + + # Each item should be a string + for model in models: + assert isinstance(model, str), "Each model type should be a string" + + # Perform schema validation + validation_result = validate_response( + models, + api_spec, + "/models", + "get" + ) + + # Print detailed error if validation fails + if not validation_result['valid']: + for error in validation_result['errors']: + logger.error(f"Validation error for /models: {error}") + + # Print schema details for debugging + schema = get_endpoint_schema(api_spec, "/models", "get") + if schema: + logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") + + # Print response + sample_models = models[:5] if isinstance(models, list) else models + logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}") + + assert validation_result['valid'], "Models response does not match schema" + + except requests.RequestException as e: + pytest.fail(f"Request to /models failed: {str(e)}") + + +def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]): + """ + Test the object_info endpoint 
response + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + api_spec: Loaded OpenAPI spec + """ + url = api_client.get_url("/object_info") # type: ignore + + try: + response = api_client.get(url) + + assert response.status_code == 200, "Failed to get object info" + + # Parse response + objects = response.json() + + # Validate it's an object + assert isinstance(objects, dict), "Object info response should be an object" + + # Check if we have any objects + if objects: + # Get the first object + first_obj_name = next(iter(objects.keys())) + first_obj = objects[first_obj_name] + + # Validate first object has required fields + assert 'input' in first_obj, "Object missing 'input' field" + assert 'output' in first_obj, "Object missing 'output' field" + assert 'name' in first_obj, "Object missing 'name' field" + + # Perform schema validation + validation_result = validate_response( + objects, + api_spec, + "/object_info", + "get" + ) + + # Print detailed error if validation fails + if not validation_result['valid']: + for error in validation_result['errors']: + logger.error(f"Validation error for /object_info: {error}") + + # Print schema details for debugging + schema = get_endpoint_schema(api_spec, "/object_info", "get") + if schema: + logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") + + # Also print a small sample of the response + sample = dict(list(objects.items())[:1]) if objects else {} + logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}") + + assert validation_result['valid'], "Object info response does not match schema" + + except requests.RequestException as e: + pytest.fail(f"Request to /object_info failed: {str(e)}") + except (KeyError, StopIteration) as e: + pytest.fail(f"Failed to process response: {str(e)}") + + +def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]): + """ + Test the queue endpoint response + + Args: + require_server: Fixture that skips if server is not available + api_client: API client fixture + api_spec: Loaded OpenAPI spec + """ + url = api_client.get_url("/queue") # type: ignore + + try: + response = api_client.get(url) + + assert response.status_code == 200, "Failed to get queue" + + # Parse response + queue = response.json() + + # Validate structure + assert 'queue_running' in queue, "Queue missing 'queue_running' field" + assert 'queue_pending' in queue, "Queue missing 'queue_pending' field" + + # Each should be a list + assert isinstance(queue['queue_running'], list), "queue_running should be a list" + assert isinstance(queue['queue_pending'], list), "queue_pending should be a list" + + # Perform schema validation + validation_result = validate_response( + queue, + api_spec, + "/queue", + "get" + ) + + # Print detailed error if validation fails + if not validation_result['valid']: + for error in validation_result['errors']: + logger.error(f"Validation error for /queue: {error}") + + # Print schema details for debugging + schema = get_endpoint_schema(api_spec, "/queue", "get") + if schema: + logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") + + # Print response + logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}") + + assert validation_result['valid'], "Queue response does not match schema" + + except requests.RequestException as e: + pytest.fail(f"Request to /queue failed: {str(e)}") \ No newline at end of file diff --git a/tests-api/test_spec_validation.py b/tests-api/test_spec_validation.py new file mode 100644 
index 000000000..9fc9db6f3 --- /dev/null +++ b/tests-api/test_spec_validation.py @@ -0,0 +1,144 @@ +""" +Tests for validating the OpenAPI specification +""" +import pytest +from openapi_spec_validator import validate_spec +from openapi_spec_validator.exceptions import OpenAPISpecValidatorError +from typing import Dict, Any + + +def test_openapi_spec_is_valid(api_spec: Dict[str, Any]): + """ + Test that the OpenAPI specification is valid + + Args: + api_spec: Loaded OpenAPI spec + """ + try: + validate_spec(api_spec) + except OpenAPISpecValidatorError as e: + pytest.fail(f"OpenAPI spec validation failed: {str(e)}") + + +def test_spec_has_info(api_spec: Dict[str, Any]): + """ + Test that the OpenAPI spec has the required info section + + Args: + api_spec: Loaded OpenAPI spec + """ + assert 'info' in api_spec, "Spec must have info section" + assert 'title' in api_spec['info'], "Info must have title" + assert 'version' in api_spec['info'], "Info must have version" + + +def test_spec_has_paths(api_spec: Dict[str, Any]): + """ + Test that the OpenAPI spec has paths defined + + Args: + api_spec: Loaded OpenAPI spec + """ + assert 'paths' in api_spec, "Spec must have paths section" + assert len(api_spec['paths']) > 0, "Spec must have at least one path" + + +def test_spec_has_components(api_spec: Dict[str, Any]): + """ + Test that the OpenAPI spec has components defined + + Args: + api_spec: Loaded OpenAPI spec + """ + assert 'components' in api_spec, "Spec must have components section" + assert 'schemas' in api_spec['components'], "Components must have schemas" + + +def test_workflow_endpoints_exist(api_spec: Dict[str, Any]): + """ + Test that core workflow endpoints are defined + + Args: + api_spec: Loaded OpenAPI spec + """ + assert '/prompt' in api_spec['paths'], "Spec must define /prompt endpoint" + assert 'post' in api_spec['paths']['/prompt'], "Spec must define POST /prompt" + assert 'get' in api_spec['paths']['/prompt'], "Spec must define GET /prompt" + + +def test_image_endpoints_exist(api_spec: Dict[str, Any]): + """ + Test that core image endpoints are defined + + Args: + api_spec: Loaded OpenAPI spec + """ + assert '/upload/image' in api_spec['paths'], "Spec must define /upload/image endpoint" + assert '/view' in api_spec['paths'], "Spec must define /view endpoint" + + +def test_model_endpoints_exist(api_spec: Dict[str, Any]): + """ + Test that core model endpoints are defined + + Args: + api_spec: Loaded OpenAPI spec + """ + assert '/models' in api_spec['paths'], "Spec must define /models endpoint" + assert '/models/{folder}' in api_spec['paths'], "Spec must define /models/{folder} endpoint" + + +def test_operation_ids_are_unique(api_spec: Dict[str, Any]): + """ + Test that all operationIds are unique + + Args: + api_spec: Loaded OpenAPI spec + """ + operation_ids = [] + + for path, path_item in api_spec['paths'].items(): + for method, operation in path_item.items(): + if method in ['get', 'post', 'put', 'delete', 'patch']: + if 'operationId' in operation: + operation_ids.append(operation['operationId']) + + # Check for duplicates + duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1]) + assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}" + + +def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]): + """ + Test that all endpoints have operationIds + + Args: + api_spec: Loaded OpenAPI spec + """ + missing = [] + + for path, path_item in api_spec['paths'].items(): + for method, operation in path_item.items(): + if 
method in ['get', 'post', 'put', 'delete', 'patch']: + if 'operationId' not in operation: + missing.append(f"{method.upper()} {path}") + + assert len(missing) == 0, f"Found endpoints without operationIds: {missing}" + + +def test_all_endpoints_have_tags(api_spec: Dict[str, Any]): + """ + Test that all endpoints have tags + + Args: + api_spec: Loaded OpenAPI spec + """ + missing = [] + + for path, path_item in api_spec['paths'].items(): + for method, operation in path_item.items(): + if method in ['get', 'post', 'put', 'delete', 'patch']: + if 'tags' not in operation or not operation['tags']: + missing.append(f"{method.upper()} {path}") + + assert len(missing) == 0, f"Found endpoints without tags: {missing}" \ No newline at end of file diff --git a/tests-api/utils/schema_utils.py b/tests-api/utils/schema_utils.py new file mode 100644 index 000000000..c354f11b4 --- /dev/null +++ b/tests-api/utils/schema_utils.py @@ -0,0 +1,159 @@ +""" +Utilities for working with OpenAPI schemas +""" +import json +import os +from typing import Any, Dict, List, Optional, Set, Tuple + + +def extract_required_parameters( + spec: Dict[str, Any], + path: str, + method: str +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Extract required parameters for a specific endpoint + + Args: + spec: Parsed OpenAPI specification + path: API path (e.g., '/prompt') + method: HTTP method (e.g., 'get', 'post') + + Returns: + Tuple of (path_params, query_params) containing required parameters + """ + method = method.lower() + path_params = [] + query_params = [] + + # Handle path not found + if path not in spec['paths']: + return path_params, query_params + + # Handle method not found + if method not in spec['paths'][path]: + return path_params, query_params + + # Get parameters + params = spec['paths'][path][method].get('parameters', []) + + for param in params: + if param.get('required', False): + if param.get('in') == 'path': + path_params.append(param) + elif param.get('in') == 'query': + query_params.append(param) + + return path_params, query_params + + +def get_request_body_schema( + spec: Dict[str, Any], + path: str, + method: str +) -> Optional[Dict[str, Any]]: + """ + Get request body schema for a specific endpoint + + Args: + spec: Parsed OpenAPI specification + path: API path (e.g., '/prompt') + method: HTTP method (e.g., 'get', 'post') + + Returns: + Request body schema or None if not found + """ + method = method.lower() + + # Handle path not found + if path not in spec['paths']: + return None + + # Handle method not found + if method not in spec['paths'][path]: + return None + + # Handle no request body + request_body = spec['paths'][path][method].get('requestBody', {}) + if not request_body or 'content' not in request_body: + return None + + # Get schema from first content type + content_types = request_body['content'] + first_content_type = next(iter(content_types)) + + if 'schema' not in content_types[first_content_type]: + return None + + return content_types[first_content_type]['schema'] + + +def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]: + """ + Extract all endpoints with a specific tag + + Args: + spec: Parsed OpenAPI specification + tag: Tag to filter by + + Returns: + List of endpoint details + """ + endpoints = [] + + for path, path_item in spec['paths'].items(): + for method, operation in path_item.items(): + if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: + continue + + if tag in operation.get('tags', []): + endpoints.append({ + 
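+                    # Each record keeps only what the tag-driven tests need:
+                    # the path template, the lowercased HTTP verb, and optional
+                    # operation metadata (empty strings when the spec omits it).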
'path': path, + 'method': method.lower(), + 'operation_id': operation.get('operationId', ''), + 'summary': operation.get('summary', '') + }) + + return endpoints + + +def get_all_tags(spec: Dict[str, Any]) -> Set[str]: + """ + Get all tags used in the API spec + + Args: + spec: Parsed OpenAPI specification + + Returns: + Set of tag names + """ + tags = set() + + for path_item in spec['paths'].values(): + for operation in path_item.values(): + if isinstance(operation, dict) and 'tags' in operation: + tags.update(operation['tags']) + + return tags + + +def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]: + """ + Extract all examples from component schemas + + Args: + spec: Parsed OpenAPI specification + + Returns: + Dict mapping schema names to examples + """ + examples = {} + + if 'components' not in spec or 'schemas' not in spec['components']: + return examples + + for name, schema in spec['components']['schemas'].items(): + if 'example' in schema: + examples[name] = schema['example'] + + return examples \ No newline at end of file diff --git a/tests-api/utils/validation.py b/tests-api/utils/validation.py new file mode 100644 index 000000000..9e07663ae --- /dev/null +++ b/tests-api/utils/validation.py @@ -0,0 +1,180 @@ +""" +Utilities for API response validation against OpenAPI spec +""" +import json +import os +import yaml +import jsonschema +from typing import Any, Dict, List, Optional, Union + + +def load_openapi_spec(spec_path: str) -> Dict[str, Any]: + """ + Load the OpenAPI specification from a YAML file + + Args: + spec_path: Path to the OpenAPI specification file + + Returns: + Dict containing the parsed OpenAPI spec + """ + with open(spec_path, 'r') as f: + return yaml.safe_load(f) + + +def get_endpoint_schema( + spec: Dict[str, Any], + path: str, + method: str, + status_code: str = '200' +) -> Optional[Dict[str, Any]]: + """ + Extract response schema for a specific endpoint from OpenAPI spec + + Args: + spec: Parsed OpenAPI specification + path: API path (e.g., '/prompt') + method: HTTP method (e.g., 'get', 'post') + status_code: HTTP status code to get schema for + + Returns: + Schema dict or None if not found + """ + method = method.lower() + + # Handle path not found + if path not in spec['paths']: + return None + + # Handle method not found + if method not in spec['paths'][path]: + return None + + # Handle status code not found + responses = spec['paths'][path][method].get('responses', {}) + if status_code not in responses: + return None + + # Handle no content defined + if 'content' not in responses[status_code]: + return None + + # Get schema from first content type + content_types = responses[status_code]['content'] + first_content_type = next(iter(content_types)) + + if 'schema' not in content_types[first_content_type]: + return None + + return content_types[first_content_type]['schema'] + + +def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]: + """ + Resolve $ref references in a schema + + Args: + schema: Schema that may contain references + spec: Full OpenAPI spec with component definitions + + Returns: + Schema with references resolved + """ + if not isinstance(schema, dict): + return schema + + result = {} + + for key, value in schema.items(): + if key == '$ref' and isinstance(value, str) and value.startswith('#/'): + # Handle reference + ref_path = value[2:].split('/') + ref_value = spec + for path_part in ref_path: + ref_value = ref_value.get(path_part, {}) + + # Recursively resolve any refs in the referenced schema + 
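+            # Only internal '#/...' references reach this point (checked
+            # above). There is no cycle guard, so a self-referential schema
+            # would recurse without terminating; the spec is assumed acyclic.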
ref_value = resolve_schema_refs(ref_value, spec) + result.update(ref_value) + elif isinstance(value, dict): + # Recursively resolve refs in nested dictionaries + result[key] = resolve_schema_refs(value, spec) + elif isinstance(value, list): + # Recursively resolve refs in list items + result[key] = [ + resolve_schema_refs(item, spec) if isinstance(item, dict) else item + for item in value + ] + else: + # Pass through other values + result[key] = value + + return result + + +def validate_response( + response_data: Union[Dict[str, Any], List[Any]], + spec: Dict[str, Any], + path: str, + method: str, + status_code: str = '200' +) -> Dict[str, Any]: + """ + Validate a response against the OpenAPI schema + + Args: + response_data: Response data to validate + spec: Parsed OpenAPI specification + path: API path (e.g., '/prompt') + method: HTTP method (e.g., 'get', 'post') + status_code: HTTP status code to validate against + + Returns: + Dict with validation result containing: + - valid: bool indicating if validation passed + - errors: List of validation errors if any + """ + schema = get_endpoint_schema(spec, path, method, status_code) + + if schema is None: + return { + 'valid': False, + 'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"] + } + + # Resolve any $ref in the schema + resolved_schema = resolve_schema_refs(schema, spec) + + try: + jsonschema.validate(instance=response_data, schema=resolved_schema) + return {'valid': True, 'errors': []} + except jsonschema.exceptions.ValidationError as e: + return {'valid': False, 'errors': [str(e)]} + + +def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Extract all endpoints from an OpenAPI spec + + Args: + spec: Parsed OpenAPI specification + + Returns: + List of dicts with path, method, and tags for each endpoint + """ + endpoints = [] + + for path, path_item in spec['paths'].items(): + for method, operation in path_item.items(): + if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: + continue + + endpoints.append({ + 'path': path, + 'method': method.lower(), + 'tags': operation.get('tags', []), + 'operation_id': operation.get('operationId', ''), + 'summary': operation.get('summary', '') + }) + + return endpoints \ No newline at end of file
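
Usage sketch for the helpers above (illustrative, not part of the patch): load
the spec, fetch a live response, and validate it. The spec path and the server
URL are assumptions here, based on the repo-root openapi.yaml and ComfyUI's
default port:

    # assumes the working directory is tests-api/, so utils/ is importable
    import requests
    from utils.validation import load_openapi_spec, validate_response

    spec = load_openapi_spec("../openapi.yaml")         # assumed location
    resp = requests.get("http://127.0.0.1:8188/queue")  # assumed default server
    result = validate_response(resp.json(), spec, "/queue", "get")
    assert result["valid"], result["errors"]

Note that validate_response expands $ref references via resolve_schema_refs
before handing the schema to jsonschema.validate, so callers get a fully
resolved schema check and plain error strings rather than raw exceptions.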