mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 05:25:23 +00:00
Add model downloading endpoint.
This commit is contained in:
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
68
tests-unit/prompt_server_test/download_models_test.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import aiohttp
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from model_filemanager import download_model, DownloadStatus, DownloadStatusType
|
||||
|
||||
|
||||
async def async_iterator(chunks):
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
|
||||
# Mock the response
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 200
|
||||
mock_response.headers = {'Content-Length': '100'}
|
||||
mock_response.content.iter_chunked.return_value = async_iterator([b'chunk1', b'chunk2'])
|
||||
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.COMPLETED
|
||||
assert result['message'] == 'Successfully downloaded model.safetensors'
|
||||
assert result['already_existed'] is False
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.COMPLETED, progress_percentage=100, message='Successfully downloaded model.safetensors'))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_failure():
|
||||
# Create a temporary directory for testing
|
||||
model_directory = str(uuid.uuid4())
|
||||
|
||||
# Create a mock session
|
||||
session = AsyncMock(spec=aiohttp.ClientSession)
|
||||
|
||||
# Mock the response with an error status code
|
||||
mock_response = MagicMock(spec=aiohttp.ClientResponse)
|
||||
mock_response.status = 500
|
||||
session.get.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
# Create a mock progress callback
|
||||
progress_callback = AsyncMock()
|
||||
|
||||
# Call the function
|
||||
result = await download_model(session, 'model.safetensors', 'http://example.com/model.safetensors', model_directory, progress_callback)
|
||||
print(result)
|
||||
|
||||
# Assert the expected behavior
|
||||
assert result['status'] == DownloadStatusType.ERROR
|
||||
assert result['message'].strip() == 'Failed to download model.safetensors. Status code: 500'
|
||||
assert result['already_existed'] is False
|
||||
|
||||
relative_path = '/'.join([model_directory, 'model.safetensors'])
|
||||
progress_callback.assert_awaited_with(relative_path, DownloadStatus(status=DownloadStatusType.ERROR, progress_percentage=0, message='Failed to download model.safetensors. Status code: 500'))
|
Reference in New Issue
Block a user