mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
fix: add cache headers for images (#9560)
This commit is contained in:
1
middleware/__init__.py
Normal file
1
middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Server middleware modules"""
|
52
middleware/cache_middleware.py
Normal file
52
middleware/cache_middleware.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Cache control middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
# Time in seconds
|
||||
ONE_HOUR: int = 3600
|
||||
ONE_DAY: int = 86400
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".ppm",
|
||||
".bmp",
|
||||
".pgm",
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
if not request.path.lower().endswith(IMG_EXTENSIONS):
|
||||
return response
|
||||
|
||||
# Handle image files
|
||||
if response.status == 404:
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
|
||||
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
|
||||
# Success responses and permanent redirects - cache for 1 day
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
|
||||
elif response.status in (302, 303, 307):
|
||||
# Temporary redirects - no cache
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
# Note: 304 Not Modified falls through - no cache headers set
|
||||
|
||||
return response
|
11
server.py
11
server.py
@@ -39,20 +39,15 @@ from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
from middleware.cache_middleware import cache_control
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
|
||||
logging.warning("send error: {}".format(err))
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
response: web.Response = await handler(request)
|
||||
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
|
||||
response.headers.setdefault('Cache-Control', 'no-cache')
|
||||
return response
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def compress_body(request: web.Request, handler):
|
||||
accept_encoding = request.headers.get("Accept-Encoding", "")
|
||||
|
255
tests-unit/server_test/test_cache_control.py
Normal file
255
tests-unit/server_test/test_cache_control.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Tests for server cache control middleware"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
from typing import Dict, Any
|
||||
|
||||
from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS
|
||||
|
||||
pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests
|
||||
|
||||
# Test configuration data
|
||||
CACHE_SCENARIOS = [
|
||||
# Image file scenarios
|
||||
{
|
||||
"name": "image_200_status",
|
||||
"path": "/test.jpg",
|
||||
"status": 200,
|
||||
"expected_cache": f"public, max-age={ONE_DAY}",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "image_404_status",
|
||||
"path": "/missing.jpg",
|
||||
"status": 404,
|
||||
"expected_cache": f"public, max-age={ONE_HOUR}",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# JavaScript/CSS scenarios
|
||||
{
|
||||
"name": "js_no_cache",
|
||||
"path": "/script.js",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "css_no_cache",
|
||||
"path": "/styles.css",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "index_json_no_cache",
|
||||
"path": "/api/index.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# Non-matching files
|
||||
{
|
||||
"name": "html_no_header",
|
||||
"path": "/index.html",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "txt_no_header",
|
||||
"path": "/data.txt",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "api_endpoint_no_header",
|
||||
"path": "/api/endpoint",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
{
|
||||
"name": "pdf_no_header",
|
||||
"path": "/file.pdf",
|
||||
"status": 200,
|
||||
"expected_cache": None,
|
||||
"should_have_header": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Status code scenarios for images
|
||||
IMAGE_STATUS_SCENARIOS = [
|
||||
# Success statuses get long cache
|
||||
{"status": 200, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 201, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 202, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 204, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 206, "expected": f"public, max-age={ONE_DAY}"},
|
||||
# Permanent redirects get long cache
|
||||
{"status": 301, "expected": f"public, max-age={ONE_DAY}"},
|
||||
{"status": 308, "expected": f"public, max-age={ONE_DAY}"},
|
||||
# Temporary redirects get no cache
|
||||
{"status": 302, "expected": "no-cache"},
|
||||
{"status": 303, "expected": "no-cache"},
|
||||
{"status": 307, "expected": "no-cache"},
|
||||
# 404 gets short cache
|
||||
{"status": 404, "expected": f"public, max-age={ONE_HOUR}"},
|
||||
]
|
||||
|
||||
# Case sensitivity test paths
|
||||
CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"]
|
||||
|
||||
# Edge case test paths
|
||||
EDGE_CASE_PATHS = [
|
||||
{
|
||||
"name": "query_strings_ignored",
|
||||
"path": "/image.jpg?v=123&size=large",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
{
|
||||
"name": "multiple_dots_in_path",
|
||||
"path": "/image.min.jpg",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
{
|
||||
"name": "nested_paths_with_images",
|
||||
"path": "/static/images/photo.jpg",
|
||||
"expected": f"public, max-age={ONE_DAY}",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TestCacheControl:
|
||||
"""Test cache control middleware functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def status_handler_factory(self):
|
||||
"""Create a factory for handlers that return specific status codes"""
|
||||
|
||||
def factory(status: int, headers: Dict[str, str] = None):
|
||||
async def handler(request):
|
||||
return web.Response(status=status, headers=headers or {})
|
||||
|
||||
return handler
|
||||
|
||||
return factory
|
||||
|
||||
@pytest.fixture
|
||||
def mock_handler(self, status_handler_factory):
|
||||
"""Create a mock handler that returns a response with 200 status"""
|
||||
return status_handler_factory(200)
|
||||
|
||||
@pytest.fixture
|
||||
def handler_with_existing_cache(self, status_handler_factory):
|
||||
"""Create a handler that returns response with existing Cache-Control header"""
|
||||
return status_handler_factory(200, {"Cache-Control": "max-age=3600"})
|
||||
|
||||
async def assert_cache_header(
|
||||
self,
|
||||
response: web.Response,
|
||||
expected_cache: str = None,
|
||||
should_have_header: bool = True,
|
||||
):
|
||||
"""Helper to assert cache control headers"""
|
||||
if should_have_header:
|
||||
assert "Cache-Control" in response.headers
|
||||
if expected_cache:
|
||||
assert response.headers["Cache-Control"] == expected_cache
|
||||
else:
|
||||
assert "Cache-Control" not in response.headers
|
||||
|
||||
# Parameterized tests
|
||||
@pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"])
|
||||
async def test_cache_control_scenarios(
|
||||
self, scenario: Dict[str, Any], status_handler_factory
|
||||
):
|
||||
"""Test various cache control scenarios"""
|
||||
handler = status_handler_factory(scenario["status"])
|
||||
request = make_mocked_request("GET", scenario["path"])
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == scenario["status"]
|
||||
await self.assert_cache_header(
|
||||
response, scenario["expected_cache"], scenario["should_have_header"]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("ext", IMG_EXTENSIONS)
|
||||
async def test_all_image_extensions(self, ext: str, mock_handler):
|
||||
"""Test all defined image extensions are handled correctly"""
|
||||
request = make_mocked_request("GET", f"/image{ext}")
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert response.status == 200
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}"
|
||||
)
|
||||
async def test_image_status_codes(
|
||||
self, status_scenario: Dict[str, Any], status_handler_factory
|
||||
):
|
||||
"""Test different status codes for image requests"""
|
||||
handler = status_handler_factory(status_scenario["status"])
|
||||
request = make_mocked_request("GET", "/image.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == status_scenario["status"]
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == status_scenario["expected"]
|
||||
|
||||
@pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS)
|
||||
async def test_case_insensitive_image_extension(self, path: str, mock_handler):
|
||||
"""Test that image extensions are matched case-insensitively"""
|
||||
request = make_mocked_request("GET", path)
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
|
||||
|
||||
@pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"])
|
||||
async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler):
|
||||
"""Test edge cases like query strings, nested paths, etc."""
|
||||
request = make_mocked_request("GET", edge_case["path"])
|
||||
response = await cache_control(request, mock_handler)
|
||||
|
||||
assert "Cache-Control" in response.headers
|
||||
assert response.headers["Cache-Control"] == edge_case["expected"]
|
||||
|
||||
# Header preservation tests (special cases not covered by parameterization)
|
||||
async def test_js_preserves_existing_headers(self, handler_with_existing_cache):
|
||||
"""Test that .js files preserve existing Cache-Control headers"""
|
||||
request = make_mocked_request("GET", "/script.js")
|
||||
response = await cache_control(request, handler_with_existing_cache)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "max-age=3600"
|
||||
|
||||
async def test_css_preserves_existing_headers(self, handler_with_existing_cache):
|
||||
"""Test that .css files preserve existing Cache-Control headers"""
|
||||
request = make_mocked_request("GET", "/styles.css")
|
||||
response = await cache_control(request, handler_with_existing_cache)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "max-age=3600"
|
||||
|
||||
async def test_image_preserves_existing_headers(self, status_handler_factory):
|
||||
"""Test that image cache headers preserve existing Cache-Control"""
|
||||
handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"})
|
||||
request = make_mocked_request("GET", "/image.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
# setdefault should preserve existing header
|
||||
assert response.headers["Cache-Control"] == "private, no-cache"
|
||||
|
||||
async def test_304_not_modified_inherits_cache(self, status_handler_factory):
|
||||
"""Test that 304 Not Modified doesn't set cache headers for images"""
|
||||
handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"})
|
||||
request = make_mocked_request("GET", "/not-modified.jpg")
|
||||
response = await cache_control(request, handler)
|
||||
|
||||
assert response.status == 304
|
||||
# Should preserve existing cache header, not override
|
||||
assert response.headers["Cache-Control"] == "max-age=7200"
|
Reference in New Issue
Block a user