Fix linting issues in API tests

This commit is contained in:
bymyself 2025-05-20 12:26:56 -07:00
parent fa9688b1fb
commit e8a92e4c9b
7 changed files with 288 additions and 292 deletions

View File

@ -21,13 +21,13 @@ DEFAULT_SERVER_URL = "http://127.0.0.1:8188"
def api_spec_path() -> str: def api_spec_path() -> str:
""" """
Get the path to the OpenAPI specification file Get the path to the OpenAPI specification file
Returns: Returns:
Path to the OpenAPI specification file Path to the OpenAPI specification file
""" """
return os.path.abspath(os.path.join( return os.path.abspath(os.path.join(
os.path.dirname(__file__), os.path.dirname(__file__),
"..", "..",
"openapi.yaml" "openapi.yaml"
)) ))
@ -36,10 +36,10 @@ def api_spec_path() -> str:
def api_spec(api_spec_path: str) -> Dict[str, Any]: def api_spec(api_spec_path: str) -> Dict[str, Any]:
""" """
Load the OpenAPI specification Load the OpenAPI specification
Args: Args:
api_spec_path: Path to the spec file api_spec_path: Path to the spec file
Returns: Returns:
Parsed OpenAPI specification Parsed OpenAPI specification
""" """
@ -51,7 +51,7 @@ def api_spec(api_spec_path: str) -> Dict[str, Any]:
def base_url() -> str: def base_url() -> str:
""" """
Get the base URL for the API server Get the base URL for the API server
Returns: Returns:
Base URL string Base URL string
""" """
@ -63,10 +63,10 @@ def base_url() -> str:
def server_available(base_url: str) -> bool: def server_available(base_url: str) -> bool:
""" """
Check if the server is available Check if the server is available
Args: Args:
base_url: Base URL for the API base_url: Base URL for the API
Returns: Returns:
True if the server is available, False otherwise True if the server is available, False otherwise
""" """
@ -82,24 +82,24 @@ def server_available(base_url: str) -> bool:
def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]: def api_client(base_url: str) -> Generator[Optional[requests.Session], None, None]:
""" """
Create a requests session for API testing Create a requests session for API testing
Args: Args:
base_url: Base URL for the API base_url: Base URL for the API
Yields: Yields:
Requests session configured for the API Requests session configured for the API
""" """
session = requests.Session() session = requests.Session()
# Helper function to construct URLs # Helper function to construct URLs
def get_url(path: str) -> str: def get_url(path: str) -> str:
return urljoin(base_url, path) return urljoin(base_url, path)
# Add url helper to the session # Add url helper to the session
session.get_url = get_url # type: ignore session.get_url = get_url # type: ignore
yield session yield session
# Cleanup # Cleanup
session.close() session.close()
@ -108,24 +108,24 @@ def api_client(base_url: str) -> Generator[Optional[requests.Session], None, Non
def api_get_json(api_client: requests.Session): def api_get_json(api_client: requests.Session):
""" """
Helper fixture for making GET requests and parsing JSON responses Helper fixture for making GET requests and parsing JSON responses
Args: Args:
api_client: API client session api_client: API client session
Returns: Returns:
Function that makes GET requests and returns JSON Function that makes GET requests and returns JSON
""" """
def _get_json(path: str, **kwargs): def _get_json(path: str, **kwargs):
url = api_client.get_url(path) # type: ignore url = api_client.get_url(path) # type: ignore
response = api_client.get(url, **kwargs) response = api_client.get(url, **kwargs)
if response.status_code == 200: if response.status_code == 200:
try: try:
return response.json() return response.json()
except ValueError: except ValueError:
return None return None
return None return None
return _get_json return _get_json
@ -133,9 +133,9 @@ def api_get_json(api_client: requests.Session):
def require_server(server_available): def require_server(server_available):
""" """
Skip tests if server is not available Skip tests if server is not available
Args: Args:
server_available: Whether the server is available server_available: Whether the server is available
""" """
if not server_available: if not server_available:
pytest.skip("Server is not available") pytest.skip("Server is not available")

View File

@ -5,7 +5,7 @@ import pytest
import logging import logging
import sys import sys
import os import os
from typing import Dict, Any, List, Set from typing import Dict, Any, Set
# Use a direct import with the full path # Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@ -17,12 +17,12 @@ def get_all_endpoints(spec):
Extract all endpoints from an OpenAPI spec Extract all endpoints from an OpenAPI spec
""" """
endpoints = [] endpoints = []
for path, path_item in spec['paths'].items(): for path, path_item in spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue continue
endpoints.append({ endpoints.append({
'path': path, 'path': path,
'method': method.lower(), 'method': method.lower(),
@ -30,7 +30,7 @@ def get_all_endpoints(spec):
'operation_id': operation.get('operationId', ''), 'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '') 'summary': operation.get('summary', '')
}) })
return endpoints return endpoints
def get_all_tags(spec): def get_all_tags(spec):
@ -38,12 +38,12 @@ def get_all_tags(spec):
Get all tags used in the API spec Get all tags used in the API spec
""" """
tags = set() tags = set()
for path_item in spec['paths'].values(): for path_item in spec['paths'].values():
for operation in path_item.values(): for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation: if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags']) tags.update(operation['tags'])
return tags return tags
def extract_endpoints_by_tag(spec, tag): def extract_endpoints_by_tag(spec, tag):
@ -51,12 +51,12 @@ def extract_endpoints_by_tag(spec, tag):
Extract all endpoints with a specific tag Extract all endpoints with a specific tag
""" """
endpoints = [] endpoints = []
for path, path_item in spec['paths'].items(): for path, path_item in spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue continue
if tag in operation.get('tags', []): if tag in operation.get('tags', []):
endpoints.append({ endpoints.append({
'path': path, 'path': path,
@ -64,7 +64,7 @@ def extract_endpoints_by_tag(spec, tag):
'operation_id': operation.get('operationId', ''), 'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '') 'summary': operation.get('summary', '')
}) })
return endpoints return endpoints
# Setup logging # Setup logging
@ -76,10 +76,10 @@ logger = logging.getLogger(__name__)
def api_tags(api_spec: Dict[str, Any]) -> Set[str]: def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
""" """
Get all tags from the API spec Get all tags from the API spec
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
Returns: Returns:
Set of tag names Set of tag names
""" """
@ -89,12 +89,12 @@ def api_tags(api_spec: Dict[str, Any]) -> Set[str]:
def test_api_has_tags(api_tags: Set[str]): def test_api_has_tags(api_tags: Set[str]):
""" """
Test that the API has defined tags Test that the API has defined tags
Args: Args:
api_tags: Set of tags api_tags: Set of tags
""" """
assert len(api_tags) > 0, "API spec should have at least one tag" assert len(api_tags) > 0, "API spec should have at least one tag"
# Log the tags # Log the tags
logger.info(f"API spec has the following tags: {sorted(api_tags)}") logger.info(f"API spec has the following tags: {sorted(api_tags)}")
@ -109,7 +109,7 @@ def test_api_has_tags(api_tags: Set[str]):
def test_core_tags_exist(api_tags: Set[str], tag: str): def test_core_tags_exist(api_tags: Set[str], tag: str):
""" """
Test that core tags exist in the API spec Test that core tags exist in the API spec
Args: Args:
api_tags: Set of tags api_tags: Set of tags
tag: Tag to check tag: Tag to check
@ -120,18 +120,18 @@ def test_core_tags_exist(api_tags: Set[str], tag: str):
def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]): def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'workflow' tag has appropriate endpoints Test that the 'workflow' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "workflow") endpoints = extract_endpoints_by_tag(api_spec, "workflow")
assert len(endpoints) > 0, "No endpoints found with 'workflow' tag" assert len(endpoints) > 0, "No endpoints found with 'workflow' tag"
# Check for key workflow endpoints # Check for key workflow endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint" assert "/prompt" in endpoint_paths, "Workflow tag should include /prompt endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'workflow' tag:")
for e in endpoints: for e in endpoints:
@ -141,19 +141,19 @@ def test_workflow_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_image_tag_has_endpoints(api_spec: Dict[str, Any]): def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'image' tag has appropriate endpoints Test that the 'image' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "image") endpoints = extract_endpoints_by_tag(api_spec, "image")
assert len(endpoints) > 0, "No endpoints found with 'image' tag" assert len(endpoints) > 0, "No endpoints found with 'image' tag"
# Check for key image endpoints # Check for key image endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint" assert "/upload/image" in endpoint_paths, "Image tag should include /upload/image endpoint"
assert "/view" in endpoint_paths, "Image tag should include /view endpoint" assert "/view" in endpoint_paths, "Image tag should include /view endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'image' tag:")
for e in endpoints: for e in endpoints:
@ -163,18 +163,18 @@ def test_image_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_model_tag_has_endpoints(api_spec: Dict[str, Any]): def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'model' tag has appropriate endpoints Test that the 'model' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "model") endpoints = extract_endpoints_by_tag(api_spec, "model")
assert len(endpoints) > 0, "No endpoints found with 'model' tag" assert len(endpoints) > 0, "No endpoints found with 'model' tag"
# Check for key model endpoints # Check for key model endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/models" in endpoint_paths, "Model tag should include /models endpoint" assert "/models" in endpoint_paths, "Model tag should include /models endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'model' tag:")
for e in endpoints: for e in endpoints:
@ -184,18 +184,18 @@ def test_model_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_node_tag_has_endpoints(api_spec: Dict[str, Any]): def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'node' tag has appropriate endpoints Test that the 'node' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "node") endpoints = extract_endpoints_by_tag(api_spec, "node")
assert len(endpoints) > 0, "No endpoints found with 'node' tag" assert len(endpoints) > 0, "No endpoints found with 'node' tag"
# Check for key node endpoints # Check for key node endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint" assert "/object_info" in endpoint_paths, "Node tag should include /object_info endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'node' tag:")
for e in endpoints: for e in endpoints:
@ -205,18 +205,18 @@ def test_node_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_system_tag_has_endpoints(api_spec: Dict[str, Any]): def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'system' tag has appropriate endpoints Test that the 'system' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "system") endpoints = extract_endpoints_by_tag(api_spec, "system")
assert len(endpoints) > 0, "No endpoints found with 'system' tag" assert len(endpoints) > 0, "No endpoints found with 'system' tag"
# Check for key system endpoints # Check for key system endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint" assert "/system_stats" in endpoint_paths, "System tag should include /system_stats endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'system' tag:")
for e in endpoints: for e in endpoints:
@ -226,18 +226,18 @@ def test_system_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]): def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
""" """
Test that the 'internal' tag has appropriate endpoints Test that the 'internal' tag has appropriate endpoints
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
endpoints = extract_endpoints_by_tag(api_spec, "internal") endpoints = extract_endpoints_by_tag(api_spec, "internal")
assert len(endpoints) > 0, "No endpoints found with 'internal' tag" assert len(endpoints) > 0, "No endpoints found with 'internal' tag"
# Check for key internal endpoints # Check for key internal endpoints
endpoint_paths = [e["path"] for e in endpoints] endpoint_paths = [e["path"] for e in endpoints]
assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint" assert "/internal/logs" in endpoint_paths, "Internal tag should include /internal/logs endpoint"
# Log the endpoints # Log the endpoints
logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:") logger.info(f"Found {len(endpoints)} endpoints with 'internal' tag:")
for e in endpoints: for e in endpoints:
@ -247,22 +247,22 @@ def test_internal_tag_has_endpoints(api_spec: Dict[str, Any]):
def test_operation_ids_match_tag(api_spec: Dict[str, Any]): def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
""" """
Test that operation IDs follow a consistent pattern with their tag Test that operation IDs follow a consistent pattern with their tag
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
failures = [] failures = []
for path, path_item in api_spec['paths'].items(): for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']: if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation and 'tags' in operation and operation['tags']: if 'operationId' in operation and 'tags' in operation and operation['tags']:
op_id = operation['operationId'] op_id = operation['operationId']
primary_tag = operation['tags'][0].lower() primary_tag = operation['tags'][0].lower()
# Check if operationId starts with primary tag prefix # Check if operationId starts with primary tag prefix
# This is a common convention, but might need adjusting # This is a common convention, but might need adjusting
if not (op_id.startswith(primary_tag) or if not (op_id.startswith(primary_tag) or
any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])): any(op_id.lower().startswith(f"{tag.lower()}") for tag in operation['tags'])):
failures.append({ failures.append({
'path': path, 'path': path,
@ -270,10 +270,10 @@ def test_operation_ids_match_tag(api_spec: Dict[str, Any]):
'operationId': op_id, 'operationId': op_id,
'primary_tag': primary_tag 'primary_tag': primary_tag
}) })
# Log failures for diagnosis but don't fail the test # Log failures for diagnosis but don't fail the test
# as this is a style/convention check # as this is a style/convention check
if failures: if failures:
logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:") logger.warning(f"Found {len(failures)} operationIds that don't align with their tags:")
for f in failures: for f in failures:
logger.warning(f" {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}") logger.warning(f" {f['method'].upper()} {f['path']} - operationId: {f['operationId']}, primary tag: {f['primary_tag']}")

View File

@ -17,20 +17,20 @@ sys.path.insert(0, current_dir)
def get_all_endpoints(spec): def get_all_endpoints(spec):
""" """
Extract all endpoints from an OpenAPI spec Extract all endpoints from an OpenAPI spec
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
Returns: Returns:
List of dicts with path, method, and tags for each endpoint List of dicts with path, method, and tags for each endpoint
""" """
endpoints = [] endpoints = []
for path, path_item in spec['paths'].items(): for path, path_item in spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue continue
endpoints.append({ endpoints.append({
'path': path, 'path': path,
'method': method.lower(), 'method': method.lower(),
@ -38,7 +38,7 @@ def get_all_endpoints(spec):
'operation_id': operation.get('operationId', ''), 'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '') 'summary': operation.get('summary', '')
}) })
return endpoints return endpoints
# Setup logging # Setup logging
@ -50,10 +50,10 @@ logger = logging.getLogger(__name__)
def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]: def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
""" """
Get all endpoints from the API spec Get all endpoints from the API spec
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
Returns: Returns:
List of endpoint information List of endpoint information
""" """
@ -63,13 +63,13 @@ def all_endpoints(api_spec: Dict[str, Any]) -> List[Dict[str, Any]]:
def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]): def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
""" """
Test that endpoints are defined in the spec Test that endpoints are defined in the spec
Args: Args:
all_endpoints: List of endpoint information all_endpoints: List of endpoint information
""" """
# Simple check that we have endpoints defined # Simple check that we have endpoints defined
assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec" assert len(all_endpoints) > 0, "No endpoints defined in the OpenAPI spec"
# Log the endpoints for informational purposes # Log the endpoints for informational purposes
logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec") logger.info(f"Found {len(all_endpoints)} endpoints in the OpenAPI spec")
for endpoint in all_endpoints: for endpoint in all_endpoints:
@ -87,23 +87,23 @@ def test_endpoints_exist(all_endpoints: List[Dict[str, Any]]):
def test_basic_get_endpoints(require_server, api_client, endpoint_path: str): def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
""" """
Test that basic GET endpoints exist and respond Test that basic GET endpoints exist and respond
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
endpoint_path: Path to test endpoint_path: Path to test
""" """
url = api_client.get_url(endpoint_path) # type: ignore url = api_client.get_url(endpoint_path) # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
# We're just checking that the endpoint exists and returns some kind of response # We're just checking that the endpoint exists and returns some kind of response
# Not necessarily a 200 status code # Not necessarily a 200 status code
assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist" assert response.status_code not in [404, 405], f"Endpoint {endpoint_path} does not exist"
logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}") logger.info(f"Endpoint {endpoint_path} exists with status code {response.status_code}")
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}") pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
@ -111,24 +111,24 @@ def test_basic_get_endpoints(require_server, api_client, endpoint_path: str):
def test_websocket_endpoint_exists(require_server, base_url: str): def test_websocket_endpoint_exists(require_server, base_url: str):
""" """
Test that the WebSocket endpoint exists Test that the WebSocket endpoint exists
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
base_url: Base server URL base_url: Base server URL
""" """
ws_url = urljoin(base_url, "/ws") ws_url = urljoin(base_url, "/ws")
# For WebSocket, we can't use a normal GET request # For WebSocket, we can't use a normal GET request
# Instead, we make a HEAD request to check if the endpoint exists # Instead, we make a HEAD request to check if the endpoint exists
try: try:
response = requests.head(ws_url) response = requests.head(ws_url)
# WebSocket endpoints often return a 400 Bad Request for HEAD requests # WebSocket endpoints often return a 400 Bad Request for HEAD requests
# but a 404 would indicate the endpoint doesn't exist # but a 404 would indicate the endpoint doesn't exist
assert response.status_code != 404, "WebSocket endpoint /ws does not exist" assert response.status_code != 404, "WebSocket endpoint /ws does not exist"
logger.info(f"WebSocket endpoint exists with status code {response.status_code}") logger.info(f"WebSocket endpoint exists with status code {response.status_code}")
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}") pytest.fail(f"Request to WebSocket endpoint failed: {str(e)}")
@ -136,35 +136,35 @@ def test_websocket_endpoint_exists(require_server, base_url: str):
def test_api_models_folder_endpoint(require_server, api_client): def test_api_models_folder_endpoint(require_server, api_client):
""" """
Test that the /models/{folder} endpoint exists and responds Test that the /models/{folder} endpoint exists and responds
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
""" """
# First get available model types # First get available model types
models_url = api_client.get_url("/models") # type: ignore models_url = api_client.get_url("/models") # type: ignore
try: try:
models_response = api_client.get(models_url) models_response = api_client.get(models_url)
assert models_response.status_code == 200, "Failed to get model types" assert models_response.status_code == 200, "Failed to get model types"
model_types = models_response.json() model_types = models_response.json()
# Skip if no model types available # Skip if no model types available
if not model_types: if not model_types:
pytest.skip("No model types available to test") pytest.skip("No model types available to test")
# Test with the first model type # Test with the first model type
model_type = model_types[0] model_type = model_types[0]
models_folder_url = api_client.get_url(f"/models/{model_type}") # type: ignore models_folder_url = api_client.get_url(f"/models/{model_type}") # type: ignore
folder_response = api_client.get(models_folder_url) folder_response = api_client.get(models_folder_url)
# We're just checking that the endpoint exists # We're just checking that the endpoint exists
assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist" assert folder_response.status_code != 404, f"Endpoint /models/{model_type} does not exist"
logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}") logger.info(f"Endpoint /models/{model_type} exists with status code {folder_response.status_code}")
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}") pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, IndexError) as e: except (ValueError, KeyError, IndexError) as e:
@ -174,35 +174,35 @@ def test_api_models_folder_endpoint(require_server, api_client):
def test_api_object_info_node_endpoint(require_server, api_client): def test_api_object_info_node_endpoint(require_server, api_client):
""" """
Test that the /object_info/{node_class} endpoint exists and responds Test that the /object_info/{node_class} endpoint exists and responds
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
""" """
# First get available node classes # First get available node classes
objects_url = api_client.get_url("/object_info") # type: ignore objects_url = api_client.get_url("/object_info") # type: ignore
try: try:
objects_response = api_client.get(objects_url) objects_response = api_client.get(objects_url)
assert objects_response.status_code == 200, "Failed to get object info" assert objects_response.status_code == 200, "Failed to get object info"
node_classes = objects_response.json() node_classes = objects_response.json()
# Skip if no node classes available # Skip if no node classes available
if not node_classes: if not node_classes:
pytest.skip("No node classes available to test") pytest.skip("No node classes available to test")
# Test with the first node class # Test with the first node class
node_class = next(iter(node_classes.keys())) node_class = next(iter(node_classes.keys()))
node_url = api_client.get_url(f"/object_info/{node_class}") # type: ignore node_url = api_client.get_url(f"/object_info/{node_class}") # type: ignore
node_response = api_client.get(node_url) node_response = api_client.get(node_url)
# We're just checking that the endpoint exists # We're just checking that the endpoint exists
assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist" assert node_response.status_code != 404, f"Endpoint /object_info/{node_class} does not exist"
logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}") logger.info(f"Endpoint /object_info/{node_class} exists with status code {node_response.status_code}")
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request failed: {str(e)}") pytest.fail(f"Request failed: {str(e)}")
except (ValueError, KeyError, StopIteration) as e: except (ValueError, KeyError, StopIteration) as e:
@ -212,7 +212,7 @@ def test_api_object_info_node_endpoint(require_server, api_client):
def test_internal_endpoints_exist(require_server, api_client): def test_internal_endpoints_exist(require_server, api_client):
""" """
Test that internal endpoints exist Test that internal endpoints exist
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
@ -223,18 +223,18 @@ def test_internal_endpoints_exist(require_server, api_client):
"/internal/folder_paths", "/internal/folder_paths",
"/internal/files/output" "/internal/files/output"
] ]
for endpoint in internal_endpoints: for endpoint in internal_endpoints:
url = api_client.get_url(endpoint) # type: ignore url = api_client.get_url(endpoint) # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
# We're just checking that the endpoint exists # We're just checking that the endpoint exists
assert response.status_code != 404, f"Endpoint {endpoint} does not exist" assert response.status_code != 404, f"Endpoint {endpoint} does not exist"
logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}") logger.info(f"Endpoint {endpoint} exists with status code {response.status_code}")
except requests.RequestException as e: except requests.RequestException as e:
logger.warning(f"Request to {endpoint} failed: {str(e)}") logger.warning(f"Request to {endpoint} failed: {str(e)}")
# Don't fail the test as internal endpoints might be restricted # Don't fail the test as internal endpoints might be restricted

View File

@ -7,7 +7,7 @@ import logging
import sys import sys
import os import os
import json import json
from typing import Dict, Any, List from typing import Dict, Any
# Use a direct import with the full path # Use a direct import with the full path
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@ -15,40 +15,40 @@ sys.path.insert(0, current_dir)
# Define validation functions inline to avoid import issues # Define validation functions inline to avoid import issues
def get_endpoint_schema( def get_endpoint_schema(
spec, spec,
path, path,
method, method,
status_code = '200' status_code = '200'
): ):
""" """
Extract response schema for a specific endpoint from OpenAPI spec Extract response schema for a specific endpoint from OpenAPI spec
""" """
method = method.lower() method = method.lower()
# Handle path not found # Handle path not found
if path not in spec['paths']: if path not in spec['paths']:
return None return None
# Handle method not found # Handle method not found
if method not in spec['paths'][path]: if method not in spec['paths'][path]:
return None return None
# Handle status code not found # Handle status code not found
responses = spec['paths'][path][method].get('responses', {}) responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses: if status_code not in responses:
return None return None
# Handle no content defined # Handle no content defined
if 'content' not in responses[status_code]: if 'content' not in responses[status_code]:
return None return None
# Get schema from first content type # Get schema from first content type
content_types = responses[status_code]['content'] content_types = responses[status_code]['content']
first_content_type = next(iter(content_types)) first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]: if 'schema' not in content_types[first_content_type]:
return None return None
return content_types[first_content_type]['schema'] return content_types[first_content_type]['schema']
def resolve_schema_refs(schema, spec): def resolve_schema_refs(schema, spec):
@ -57,9 +57,9 @@ def resolve_schema_refs(schema, spec):
""" """
if not isinstance(schema, dict): if not isinstance(schema, dict):
return schema return schema
result = {} result = {}
for key, value in schema.items(): for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'): if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference # Handle reference
@ -67,7 +67,7 @@ def resolve_schema_refs(schema, spec):
ref_value = spec ref_value = spec
for path_part in ref_path: for path_part in ref_path:
ref_value = ref_value.get(path_part, {}) ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema # Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec) ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value) result.update(ref_value)
@ -83,7 +83,7 @@ def resolve_schema_refs(schema, spec):
else: else:
# Pass through other values # Pass through other values
result[key] = value result[key] = value
return result return result
def validate_response( def validate_response(
@ -97,16 +97,16 @@ def validate_response(
Validate a response against the OpenAPI schema Validate a response against the OpenAPI schema
""" """
schema = get_endpoint_schema(spec, path, method, status_code) schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None: if schema is None:
return { return {
'valid': False, 'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"] 'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
} }
# Resolve any $ref in the schema # Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec) resolved_schema = resolve_schema_refs(schema, spec)
try: try:
import jsonschema import jsonschema
jsonschema.validate(instance=response_data, schema=resolved_schema) jsonschema.validate(instance=response_data, schema=resolved_schema)
@ -116,14 +116,14 @@ def validate_response(
path = ".".join(str(p) for p in e.path) if e.path else "root" path = ".".join(str(p) for p in e.path) if e.path else "root"
instance = e.instance if not isinstance(e.instance, dict) else "..." instance = e.instance if not isinstance(e.instance, dict) else "..."
schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown" schema_path = ".".join(str(p) for p in e.schema_path) if e.schema_path else "unknown"
detailed_error = ( detailed_error = (
f"Validation error at path: {path}\n" f"Validation error at path: {path}\n"
f"Schema path: {schema_path}\n" f"Schema path: {schema_path}\n"
f"Error message: {e.message}\n" f"Error message: {e.message}\n"
f"Failed instance: {instance}\n" f"Failed instance: {instance}\n"
) )
return {'valid': False, 'errors': [detailed_error]} return {'valid': False, 'errors': [detailed_error]}
# Setup logging # Setup logging
@ -139,15 +139,15 @@ logger = logging.getLogger(__name__)
("/embeddings", "get") ("/embeddings", "get")
]) ])
def test_response_schema_validation( def test_response_schema_validation(
require_server, require_server,
api_client, api_client,
api_spec: Dict[str, Any], api_spec: Dict[str, Any],
endpoint_path: str, endpoint_path: str,
method: str method: str
): ):
""" """
Test that API responses match the defined schema Test that API responses match the defined schema
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
@ -156,47 +156,47 @@ def test_response_schema_validation(
method: HTTP method to test method: HTTP method to test
""" """
url = api_client.get_url(endpoint_path) # type: ignore url = api_client.get_url(endpoint_path) # type: ignore
# Skip if no schema defined # Skip if no schema defined
schema = get_endpoint_schema(api_spec, endpoint_path, method) schema = get_endpoint_schema(api_spec, endpoint_path, method)
if not schema: if not schema:
pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}") pytest.skip(f"No schema defined for {method.upper()} {endpoint_path}")
try: try:
if method.lower() == "get": if method.lower() == "get":
response = api_client.get(url) response = api_client.get(url)
else: else:
pytest.skip(f"Method {method} not implemented for automated testing") pytest.skip(f"Method {method} not implemented for automated testing")
return return
# Skip if response is not 200 # Skip if response is not 200
if response.status_code != 200: if response.status_code != 200:
pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}") pytest.skip(f"Endpoint {endpoint_path} returned status {response.status_code}")
return return
# Skip if response is not JSON # Skip if response is not JSON
try: try:
response_data = response.json() response_data = response.json()
except ValueError: except ValueError:
pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON") pytest.skip(f"Endpoint {endpoint_path} did not return valid JSON")
return return
# Validate the response # Validate the response
validation_result = validate_response( validation_result = validate_response(
response_data, response_data,
api_spec, api_spec,
endpoint_path, endpoint_path,
method method
) )
if validation_result['valid']: if validation_result['valid']:
logger.info(f"Response from {method.upper()} {endpoint_path} matches schema") logger.info(f"Response from {method.upper()} {endpoint_path} matches schema")
else: else:
for error in validation_result['errors']: for error in validation_result['errors']:
logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}") logger.error(f"Validation error for {method.upper()} {endpoint_path}: {error}")
assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema" assert validation_result['valid'], f"Response from {method.upper()} {endpoint_path} does not match schema"
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to {endpoint_path} failed: {str(e)}") pytest.fail(f"Request to {endpoint_path} failed: {str(e)}")
@ -204,67 +204,67 @@ def test_response_schema_validation(
def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]): def test_system_stats_response(require_server, api_client, api_spec: Dict[str, Any]):
""" """
Test the system_stats endpoint response in detail Test the system_stats endpoint response in detail
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
url = api_client.get_url("/system_stats") # type: ignore url = api_client.get_url("/system_stats") # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
assert response.status_code == 200, "Failed to get system stats" assert response.status_code == 200, "Failed to get system stats"
# Parse response # Parse response
stats = response.json() stats = response.json()
# Validate high-level structure # Validate high-level structure
assert 'system' in stats, "Response missing 'system' field" assert 'system' in stats, "Response missing 'system' field"
assert 'devices' in stats, "Response missing 'devices' field" assert 'devices' in stats, "Response missing 'devices' field"
# Validate system fields # Validate system fields
system = stats['system'] system = stats['system']
assert 'os' in system, "System missing 'os' field" assert 'os' in system, "System missing 'os' field"
assert 'ram_total' in system, "System missing 'ram_total' field" assert 'ram_total' in system, "System missing 'ram_total' field"
assert 'ram_free' in system, "System missing 'ram_free' field" assert 'ram_free' in system, "System missing 'ram_free' field"
assert 'comfyui_version' in system, "System missing 'comfyui_version' field" assert 'comfyui_version' in system, "System missing 'comfyui_version' field"
# Validate devices fields # Validate devices fields
devices = stats['devices'] devices = stats['devices']
assert isinstance(devices, list), "Devices should be a list" assert isinstance(devices, list), "Devices should be a list"
if devices: if devices:
device = devices[0] device = devices[0]
assert 'name' in device, "Device missing 'name' field" assert 'name' in device, "Device missing 'name' field"
assert 'type' in device, "Device missing 'type' field" assert 'type' in device, "Device missing 'type' field"
assert 'vram_total' in device, "Device missing 'vram_total' field" assert 'vram_total' in device, "Device missing 'vram_total' field"
assert 'vram_free' in device, "Device missing 'vram_free' field" assert 'vram_free' in device, "Device missing 'vram_free' field"
# Perform schema validation # Perform schema validation
validation_result = validate_response( validation_result = validate_response(
stats, stats,
api_spec, api_spec,
"/system_stats", "/system_stats",
"get" "get"
) )
# Print detailed error if validation fails # Print detailed error if validation fails
if not validation_result['valid']: if not validation_result['valid']:
for error in validation_result['errors']: for error in validation_result['errors']:
logger.error(f"Validation error for /system_stats: {error}") logger.error(f"Validation error for /system_stats: {error}")
# Print schema details for debugging # Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/system_stats", "get") schema = get_endpoint_schema(api_spec, "/system_stats", "get")
if schema: if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print sample of the response # Print sample of the response
logger.error(f"Response:\n{json.dumps(stats, indent=2)}") logger.error(f"Response:\n{json.dumps(stats, indent=2)}")
assert validation_result['valid'], "System stats response does not match schema" assert validation_result['valid'], "System stats response does not match schema"
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to /system_stats failed: {str(e)}") pytest.fail(f"Request to /system_stats failed: {str(e)}")
@ -272,53 +272,53 @@ def test_system_stats_response(require_server, api_client, api_spec: Dict[str, A
def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]): def test_models_listing_response(require_server, api_client, api_spec: Dict[str, Any]):
""" """
Test the models endpoint response Test the models endpoint response
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
url = api_client.get_url("/models") # type: ignore url = api_client.get_url("/models") # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
assert response.status_code == 200, "Failed to get models" assert response.status_code == 200, "Failed to get models"
# Parse response # Parse response
models = response.json() models = response.json()
# Validate it's a list # Validate it's a list
assert isinstance(models, list), "Models response should be a list" assert isinstance(models, list), "Models response should be a list"
# Each item should be a string # Each item should be a string
for model in models: for model in models:
assert isinstance(model, str), "Each model type should be a string" assert isinstance(model, str), "Each model type should be a string"
# Perform schema validation # Perform schema validation
validation_result = validate_response( validation_result = validate_response(
models, models,
api_spec, api_spec,
"/models", "/models",
"get" "get"
) )
# Print detailed error if validation fails # Print detailed error if validation fails
if not validation_result['valid']: if not validation_result['valid']:
for error in validation_result['errors']: for error in validation_result['errors']:
logger.error(f"Validation error for /models: {error}") logger.error(f"Validation error for /models: {error}")
# Print schema details for debugging # Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/models", "get") schema = get_endpoint_schema(api_spec, "/models", "get")
if schema: if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response # Print response
sample_models = models[:5] if isinstance(models, list) else models sample_models = models[:5] if isinstance(models, list) else models
logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}") logger.error(f"Models response:\n{json.dumps(sample_models, indent=2)}")
assert validation_result['valid'], "Models response does not match schema" assert validation_result['valid'], "Models response does not match schema"
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to /models failed: {str(e)}") pytest.fail(f"Request to /models failed: {str(e)}")
@ -326,60 +326,60 @@ def test_models_listing_response(require_server, api_client, api_spec: Dict[str,
def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]): def test_object_info_response(require_server, api_client, api_spec: Dict[str, Any]):
""" """
Test the object_info endpoint response Test the object_info endpoint response
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
url = api_client.get_url("/object_info") # type: ignore url = api_client.get_url("/object_info") # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
assert response.status_code == 200, "Failed to get object info" assert response.status_code == 200, "Failed to get object info"
# Parse response # Parse response
objects = response.json() objects = response.json()
# Validate it's an object # Validate it's an object
assert isinstance(objects, dict), "Object info response should be an object" assert isinstance(objects, dict), "Object info response should be an object"
# Check if we have any objects # Check if we have any objects
if objects: if objects:
# Get the first object # Get the first object
first_obj_name = next(iter(objects.keys())) first_obj_name = next(iter(objects.keys()))
first_obj = objects[first_obj_name] first_obj = objects[first_obj_name]
# Validate first object has required fields # Validate first object has required fields
assert 'input' in first_obj, "Object missing 'input' field" assert 'input' in first_obj, "Object missing 'input' field"
assert 'output' in first_obj, "Object missing 'output' field" assert 'output' in first_obj, "Object missing 'output' field"
assert 'name' in first_obj, "Object missing 'name' field" assert 'name' in first_obj, "Object missing 'name' field"
# Perform schema validation # Perform schema validation
validation_result = validate_response( validation_result = validate_response(
objects, objects,
api_spec, api_spec,
"/object_info", "/object_info",
"get" "get"
) )
# Print detailed error if validation fails # Print detailed error if validation fails
if not validation_result['valid']: if not validation_result['valid']:
for error in validation_result['errors']: for error in validation_result['errors']:
logger.error(f"Validation error for /object_info: {error}") logger.error(f"Validation error for /object_info: {error}")
# Print schema details for debugging # Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/object_info", "get") schema = get_endpoint_schema(api_spec, "/object_info", "get")
if schema: if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Also print a small sample of the response # Also print a small sample of the response
sample = dict(list(objects.items())[:1]) if objects else {} sample = dict(list(objects.items())[:1]) if objects else {}
logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}") logger.error(f"Sample response:\n{json.dumps(sample, indent=2)}")
assert validation_result['valid'], "Object info response does not match schema" assert validation_result['valid'], "Object info response does not match schema"
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to /object_info failed: {str(e)}") pytest.fail(f"Request to /object_info failed: {str(e)}")
except (KeyError, StopIteration) as e: except (KeyError, StopIteration) as e:
@ -389,52 +389,52 @@ def test_object_info_response(require_server, api_client, api_spec: Dict[str, An
def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]): def test_queue_response(require_server, api_client, api_spec: Dict[str, Any]):
""" """
Test the queue endpoint response Test the queue endpoint response
Args: Args:
require_server: Fixture that skips if server is not available require_server: Fixture that skips if server is not available
api_client: API client fixture api_client: API client fixture
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
url = api_client.get_url("/queue") # type: ignore url = api_client.get_url("/queue") # type: ignore
try: try:
response = api_client.get(url) response = api_client.get(url)
assert response.status_code == 200, "Failed to get queue" assert response.status_code == 200, "Failed to get queue"
# Parse response # Parse response
queue = response.json() queue = response.json()
# Validate structure # Validate structure
assert 'queue_running' in queue, "Queue missing 'queue_running' field" assert 'queue_running' in queue, "Queue missing 'queue_running' field"
assert 'queue_pending' in queue, "Queue missing 'queue_pending' field" assert 'queue_pending' in queue, "Queue missing 'queue_pending' field"
# Each should be a list # Each should be a list
assert isinstance(queue['queue_running'], list), "queue_running should be a list" assert isinstance(queue['queue_running'], list), "queue_running should be a list"
assert isinstance(queue['queue_pending'], list), "queue_pending should be a list" assert isinstance(queue['queue_pending'], list), "queue_pending should be a list"
# Perform schema validation # Perform schema validation
validation_result = validate_response( validation_result = validate_response(
queue, queue,
api_spec, api_spec,
"/queue", "/queue",
"get" "get"
) )
# Print detailed error if validation fails # Print detailed error if validation fails
if not validation_result['valid']: if not validation_result['valid']:
for error in validation_result['errors']: for error in validation_result['errors']:
logger.error(f"Validation error for /queue: {error}") logger.error(f"Validation error for /queue: {error}")
# Print schema details for debugging # Print schema details for debugging
schema = get_endpoint_schema(api_spec, "/queue", "get") schema = get_endpoint_schema(api_spec, "/queue", "get")
if schema: if schema:
logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}") logger.error(f"Schema structure:\n{json.dumps(schema, indent=2)}")
# Print response # Print response
logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}") logger.error(f"Queue response:\n{json.dumps(queue, indent=2)}")
assert validation_result['valid'], "Queue response does not match schema" assert validation_result['valid'], "Queue response does not match schema"
except requests.RequestException as e: except requests.RequestException as e:
pytest.fail(f"Request to /queue failed: {str(e)}") pytest.fail(f"Request to /queue failed: {str(e)}")

View File

@ -10,7 +10,7 @@ from typing import Dict, Any
def test_openapi_spec_is_valid(api_spec: Dict[str, Any]): def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
""" """
Test that the OpenAPI specification is valid Test that the OpenAPI specification is valid
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -23,7 +23,7 @@ def test_openapi_spec_is_valid(api_spec: Dict[str, Any]):
def test_spec_has_info(api_spec: Dict[str, Any]): def test_spec_has_info(api_spec: Dict[str, Any]):
""" """
Test that the OpenAPI spec has the required info section Test that the OpenAPI spec has the required info section
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -35,7 +35,7 @@ def test_spec_has_info(api_spec: Dict[str, Any]):
def test_spec_has_paths(api_spec: Dict[str, Any]): def test_spec_has_paths(api_spec: Dict[str, Any]):
""" """
Test that the OpenAPI spec has paths defined Test that the OpenAPI spec has paths defined
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -46,7 +46,7 @@ def test_spec_has_paths(api_spec: Dict[str, Any]):
def test_spec_has_components(api_spec: Dict[str, Any]): def test_spec_has_components(api_spec: Dict[str, Any]):
""" """
Test that the OpenAPI spec has components defined Test that the OpenAPI spec has components defined
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -57,7 +57,7 @@ def test_spec_has_components(api_spec: Dict[str, Any]):
def test_workflow_endpoints_exist(api_spec: Dict[str, Any]): def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
""" """
Test that core workflow endpoints are defined Test that core workflow endpoints are defined
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -69,7 +69,7 @@ def test_workflow_endpoints_exist(api_spec: Dict[str, Any]):
def test_image_endpoints_exist(api_spec: Dict[str, Any]): def test_image_endpoints_exist(api_spec: Dict[str, Any]):
""" """
Test that core image endpoints are defined Test that core image endpoints are defined
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -80,7 +80,7 @@ def test_image_endpoints_exist(api_spec: Dict[str, Any]):
def test_model_endpoints_exist(api_spec: Dict[str, Any]): def test_model_endpoints_exist(api_spec: Dict[str, Any]):
""" """
Test that core model endpoints are defined Test that core model endpoints are defined
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
@ -91,18 +91,18 @@ def test_model_endpoints_exist(api_spec: Dict[str, Any]):
def test_operation_ids_are_unique(api_spec: Dict[str, Any]): def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
""" """
Test that all operationIds are unique Test that all operationIds are unique
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
operation_ids = [] operation_ids = []
for path, path_item in api_spec['paths'].items(): for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']: if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' in operation: if 'operationId' in operation:
operation_ids.append(operation['operationId']) operation_ids.append(operation['operationId'])
# Check for duplicates # Check for duplicates
duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1]) duplicates = set([op_id for op_id in operation_ids if operation_ids.count(op_id) > 1])
assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}" assert len(duplicates) == 0, f"Found duplicate operationIds: {duplicates}"
@ -111,34 +111,34 @@ def test_operation_ids_are_unique(api_spec: Dict[str, Any]):
def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]): def test_all_endpoints_have_operation_ids(api_spec: Dict[str, Any]):
""" """
Test that all endpoints have operationIds Test that all endpoints have operationIds
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
missing = [] missing = []
for path, path_item in api_spec['paths'].items(): for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']: if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'operationId' not in operation: if 'operationId' not in operation:
missing.append(f"{method.upper()} {path}") missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without operationIds: {missing}" assert len(missing) == 0, f"Found endpoints without operationIds: {missing}"
def test_all_endpoints_have_tags(api_spec: Dict[str, Any]): def test_all_endpoints_have_tags(api_spec: Dict[str, Any]):
""" """
Test that all endpoints have tags Test that all endpoints have tags
Args: Args:
api_spec: Loaded OpenAPI spec api_spec: Loaded OpenAPI spec
""" """
missing = [] missing = []
for path, path_item in api_spec['paths'].items(): for path, path_item in api_spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method in ['get', 'post', 'put', 'delete', 'patch']: if method in ['get', 'post', 'put', 'delete', 'patch']:
if 'tags' not in operation or not operation['tags']: if 'tags' not in operation or not operation['tags']:
missing.append(f"{method.upper()} {path}") missing.append(f"{method.upper()} {path}")
assert len(missing) == 0, f"Found endpoints without tags: {missing}" assert len(missing) == 0, f"Found endpoints without tags: {missing}"

View File

@ -1,111 +1,109 @@
""" """
Utilities for working with OpenAPI schemas Utilities for working with OpenAPI schemas
""" """
import json
import os
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
def extract_required_parameters( def extract_required_parameters(
spec: Dict[str, Any], spec: Dict[str, Any],
path: str, path: str,
method: str method: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
""" """
Extract required parameters for a specific endpoint Extract required parameters for a specific endpoint
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt') path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post') method: HTTP method (e.g., 'get', 'post')
Returns: Returns:
Tuple of (path_params, query_params) containing required parameters Tuple of (path_params, query_params) containing required parameters
""" """
method = method.lower() method = method.lower()
path_params = [] path_params = []
query_params = [] query_params = []
# Handle path not found # Handle path not found
if path not in spec['paths']: if path not in spec['paths']:
return path_params, query_params return path_params, query_params
# Handle method not found # Handle method not found
if method not in spec['paths'][path]: if method not in spec['paths'][path]:
return path_params, query_params return path_params, query_params
# Get parameters # Get parameters
params = spec['paths'][path][method].get('parameters', []) params = spec['paths'][path][method].get('parameters', [])
for param in params: for param in params:
if param.get('required', False): if param.get('required', False):
if param.get('in') == 'path': if param.get('in') == 'path':
path_params.append(param) path_params.append(param)
elif param.get('in') == 'query': elif param.get('in') == 'query':
query_params.append(param) query_params.append(param)
return path_params, query_params return path_params, query_params
def get_request_body_schema( def get_request_body_schema(
spec: Dict[str, Any], spec: Dict[str, Any],
path: str, path: str,
method: str method: str
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Get request body schema for a specific endpoint Get request body schema for a specific endpoint
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt') path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post') method: HTTP method (e.g., 'get', 'post')
Returns: Returns:
Request body schema or None if not found Request body schema or None if not found
""" """
method = method.lower() method = method.lower()
# Handle path not found # Handle path not found
if path not in spec['paths']: if path not in spec['paths']:
return None return None
# Handle method not found # Handle method not found
if method not in spec['paths'][path]: if method not in spec['paths'][path]:
return None return None
# Handle no request body # Handle no request body
request_body = spec['paths'][path][method].get('requestBody', {}) request_body = spec['paths'][path][method].get('requestBody', {})
if not request_body or 'content' not in request_body: if not request_body or 'content' not in request_body:
return None return None
# Get schema from first content type # Get schema from first content type
content_types = request_body['content'] content_types = request_body['content']
first_content_type = next(iter(content_types)) first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]: if 'schema' not in content_types[first_content_type]:
return None return None
return content_types[first_content_type]['schema'] return content_types[first_content_type]['schema']
def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]: def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, Any]]:
""" """
Extract all endpoints with a specific tag Extract all endpoints with a specific tag
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
tag: Tag to filter by tag: Tag to filter by
Returns: Returns:
List of endpoint details List of endpoint details
""" """
endpoints = [] endpoints = []
for path, path_item in spec['paths'].items(): for path, path_item in spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue continue
if tag in operation.get('tags', []): if tag in operation.get('tags', []):
endpoints.append({ endpoints.append({
'path': path, 'path': path,
@ -113,47 +111,47 @@ def extract_endpoints_by_tag(spec: Dict[str, Any], tag: str) -> List[Dict[str, A
'operation_id': operation.get('operationId', ''), 'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '') 'summary': operation.get('summary', '')
}) })
return endpoints return endpoints
def get_all_tags(spec: Dict[str, Any]) -> Set[str]: def get_all_tags(spec: Dict[str, Any]) -> Set[str]:
""" """
Get all tags used in the API spec Get all tags used in the API spec
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
Returns: Returns:
Set of tag names Set of tag names
""" """
tags = set() tags = set()
for path_item in spec['paths'].values(): for path_item in spec['paths'].values():
for operation in path_item.values(): for operation in path_item.values():
if isinstance(operation, dict) and 'tags' in operation: if isinstance(operation, dict) and 'tags' in operation:
tags.update(operation['tags']) tags.update(operation['tags'])
return tags return tags
def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]: def get_schema_examples(spec: Dict[str, Any]) -> Dict[str, Any]:
""" """
Extract all examples from component schemas Extract all examples from component schemas
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
Returns: Returns:
Dict mapping schema names to examples Dict mapping schema names to examples
""" """
examples = {} examples = {}
if 'components' not in spec or 'schemas' not in spec['components']: if 'components' not in spec or 'schemas' not in spec['components']:
return examples return examples
for name, schema in spec['components']['schemas'].items(): for name, schema in spec['components']['schemas'].items():
if 'example' in schema: if 'example' in schema:
examples[name] = schema['example'] examples[name] = schema['example']
return examples return examples

View File

@ -1,8 +1,6 @@
""" """
Utilities for API response validation against OpenAPI spec Utilities for API response validation against OpenAPI spec
""" """
import json
import os
import yaml import yaml
import jsonschema import jsonschema
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@ -11,10 +9,10 @@ from typing import Any, Dict, List, Optional, Union
def load_openapi_spec(spec_path: str) -> Dict[str, Any]: def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
""" """
Load the OpenAPI specification from a YAML file Load the OpenAPI specification from a YAML file
Args: Args:
spec_path: Path to the OpenAPI specification file spec_path: Path to the OpenAPI specification file
Returns: Returns:
Dict containing the parsed OpenAPI spec Dict containing the parsed OpenAPI spec
""" """
@ -23,68 +21,68 @@ def load_openapi_spec(spec_path: str) -> Dict[str, Any]:
def get_endpoint_schema( def get_endpoint_schema(
spec: Dict[str, Any], spec: Dict[str, Any],
path: str, path: str,
method: str, method: str,
status_code: str = '200' status_code: str = '200'
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Extract response schema for a specific endpoint from OpenAPI spec Extract response schema for a specific endpoint from OpenAPI spec
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt') path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post') method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to get schema for status_code: HTTP status code to get schema for
Returns: Returns:
Schema dict or None if not found Schema dict or None if not found
""" """
method = method.lower() method = method.lower()
# Handle path not found # Handle path not found
if path not in spec['paths']: if path not in spec['paths']:
return None return None
# Handle method not found # Handle method not found
if method not in spec['paths'][path]: if method not in spec['paths'][path]:
return None return None
# Handle status code not found # Handle status code not found
responses = spec['paths'][path][method].get('responses', {}) responses = spec['paths'][path][method].get('responses', {})
if status_code not in responses: if status_code not in responses:
return None return None
# Handle no content defined # Handle no content defined
if 'content' not in responses[status_code]: if 'content' not in responses[status_code]:
return None return None
# Get schema from first content type # Get schema from first content type
content_types = responses[status_code]['content'] content_types = responses[status_code]['content']
first_content_type = next(iter(content_types)) first_content_type = next(iter(content_types))
if 'schema' not in content_types[first_content_type]: if 'schema' not in content_types[first_content_type]:
return None return None
return content_types[first_content_type]['schema'] return content_types[first_content_type]['schema']
def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]: def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[str, Any]:
""" """
Resolve $ref references in a schema Resolve $ref references in a schema
Args: Args:
schema: Schema that may contain references schema: Schema that may contain references
spec: Full OpenAPI spec with component definitions spec: Full OpenAPI spec with component definitions
Returns: Returns:
Schema with references resolved Schema with references resolved
""" """
if not isinstance(schema, dict): if not isinstance(schema, dict):
return schema return schema
result = {} result = {}
for key, value in schema.items(): for key, value in schema.items():
if key == '$ref' and isinstance(value, str) and value.startswith('#/'): if key == '$ref' and isinstance(value, str) and value.startswith('#/'):
# Handle reference # Handle reference
@ -92,7 +90,7 @@ def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[st
ref_value = spec ref_value = spec
for path_part in ref_path: for path_part in ref_path:
ref_value = ref_value.get(path_part, {}) ref_value = ref_value.get(path_part, {})
# Recursively resolve any refs in the referenced schema # Recursively resolve any refs in the referenced schema
ref_value = resolve_schema_refs(ref_value, spec) ref_value = resolve_schema_refs(ref_value, spec)
result.update(ref_value) result.update(ref_value)
@ -108,7 +106,7 @@ def resolve_schema_refs(schema: Dict[str, Any], spec: Dict[str, Any]) -> Dict[st
else: else:
# Pass through other values # Pass through other values
result[key] = value result[key] = value
return result return result
@ -121,30 +119,30 @@ def validate_response(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Validate a response against the OpenAPI schema Validate a response against the OpenAPI schema
Args: Args:
response_data: Response data to validate response_data: Response data to validate
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
path: API path (e.g., '/prompt') path: API path (e.g., '/prompt')
method: HTTP method (e.g., 'get', 'post') method: HTTP method (e.g., 'get', 'post')
status_code: HTTP status code to validate against status_code: HTTP status code to validate against
Returns: Returns:
Dict with validation result containing: Dict with validation result containing:
- valid: bool indicating if validation passed - valid: bool indicating if validation passed
- errors: List of validation errors if any - errors: List of validation errors if any
""" """
schema = get_endpoint_schema(spec, path, method, status_code) schema = get_endpoint_schema(spec, path, method, status_code)
if schema is None: if schema is None:
return { return {
'valid': False, 'valid': False,
'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"] 'errors': [f"No schema found for {method.upper()} {path} with status {status_code}"]
} }
# Resolve any $ref in the schema # Resolve any $ref in the schema
resolved_schema = resolve_schema_refs(schema, spec) resolved_schema = resolve_schema_refs(schema, spec)
try: try:
jsonschema.validate(instance=response_data, schema=resolved_schema) jsonschema.validate(instance=response_data, schema=resolved_schema)
return {'valid': True, 'errors': []} return {'valid': True, 'errors': []}
@ -155,20 +153,20 @@ def validate_response(
def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]: def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
""" """
Extract all endpoints from an OpenAPI spec Extract all endpoints from an OpenAPI spec
Args: Args:
spec: Parsed OpenAPI specification spec: Parsed OpenAPI specification
Returns: Returns:
List of dicts with path, method, and tags for each endpoint List of dicts with path, method, and tags for each endpoint
""" """
endpoints = [] endpoints = []
for path, path_item in spec['paths'].items(): for path, path_item in spec['paths'].items():
for method, operation in path_item.items(): for method, operation in path_item.items():
if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']: if method.lower() not in ['get', 'post', 'put', 'delete', 'patch']:
continue continue
endpoints.append({ endpoints.append({
'path': path, 'path': path,
'method': method.lower(), 'method': method.lower(),
@ -176,5 +174,5 @@ def get_all_endpoints(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
'operation_id': operation.get('operationId', ''), 'operation_id': operation.get('operationId', ''),
'summary': operation.get('summary', '') 'summary': operation.get('summary', '')
}) })
return endpoints return endpoints