diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index 16b2165ee..c35e26374 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadStatus, DownloadModelResult, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index d4eb11731..003fcb71f 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -16,32 +16,23 @@ class DownloadStatusType(Enum): ERROR = "error" @dataclass -class DownloadStatus(): +class DownloadModelStatus(): status: str progress_percentage: float message: str + already_existed: bool = False - def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str): + def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): self.status = status.value # Store the string value of the Enum self.progress_percentage = progress_percentage self.message = message - -@dataclass -class DownloadModelResult(): - status: str - message: str - already_existed: bool - - def __init__(self, status: DownloadStatusType, message: str, already_existed: bool): - self.status = status.value # Store the string value of the Enum - self.message = message self.already_existed = already_existed async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, model_url: str, model_sub_directory: str, - progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult: + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. @@ -55,15 +46,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht model_sub_directory (str): The subdirectory within the main models directory where the model should be saved (e.g., 'checkpoints', 'loras', etc.). - progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]): + progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. Returns: - DownloadModelResult: The result of the download operation. + DownloadModelStatus: The result of the download operation. """ if not validate_model_subdirectory(model_sub_directory): - return DownloadModelResult( + return DownloadModelStatus( DownloadStatusType.ERROR, + 0, "Invalid model subdirectory", False ) @@ -74,16 +66,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht return existing_file try: - status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}") + status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) await progress_callback(relative_path, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) - status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) + return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path) @@ -99,15 +91,23 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path -async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]: +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): - status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists") + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True) + return status return None -async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult: +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str, + interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) downloaded = 0 @@ -116,7 +116,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s async def update_progress(): nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 - status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}") + status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) await progress_callback(relative_path, status) last_update_time = time.time() @@ -136,20 +136,23 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s await update_progress() logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") - status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}") + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False) + return status except Exception as e: logging.error(f"Error in track_download_progress: {e}") logging.error(traceback.format_exc()) return await handle_download_error(e, model_name, progress_callback, relative_path) -async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult: +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any], + relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" - status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message) + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) - return DownloadModelResult(DownloadStatusType.ERROR, error_message, False) + return status def validate_model_subdirectory(model_subdirectory: str) -> bool: """ diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index abf8e3f7d..26dd94d4c 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -4,7 +4,7 @@ from aiohttp import ClientResponse import itertools import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatus, DownloadModelResult, DownloadStatusType +from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus class AsyncIteratorMock: """ @@ -73,7 +73,7 @@ async def test_download_model_success(): ) # Assert the result - assert isinstance(result, DownloadModelResult) + assert isinstance(result, DownloadModelStatus) assert result.message == 'Successfully downloaded model.bin' assert result.status == 'completed' assert result.already_existed is False @@ -84,13 +84,13 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin") + DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False) ) # Check final call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin") + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False) ) # Verify file writing @@ -123,7 +123,7 @@ async def test_download_model_url_request_failure(): ) # Assert the expected behavior - assert isinstance(result, DownloadModelResult) + assert isinstance(result, DownloadModelStatus) assert result.status == 'error' assert result.message == 'Failed to download model.safetensors. Status code: 404' assert result.already_existed is False @@ -131,18 +131,20 @@ async def test_download_model_url_request_failure(): # Check that progress_callback was called with the correct arguments mock_progress_callback.assert_any_call( 'mock_directory/model.safetensors', - DownloadStatus( + DownloadModelStatus( status=DownloadStatusType.PENDING, progress_percentage=0, - message='Starting download of model.safetensors' + message='Starting download of model.safetensors', + already_existed=False ) ) mock_progress_callback.assert_called_with( 'mock_directory/model.safetensors', - DownloadStatus( + DownloadModelStatus( status=DownloadStatusType.ERROR, progress_percentage=0, - message='Failed to download model.safetensors. Status code: 404' + message='Failed to download model.safetensors. Status code: 404', + already_existed=False ) ) @@ -165,7 +167,7 @@ async def test_download_model_invalid_model_subdirectory(): ) # Assert the result - assert isinstance(result, DownloadModelResult) + assert isinstance(result, DownloadModelStatus) assert result.message == 'Invalid model subdirectory' assert result.status == 'error' assert result.already_existed is False @@ -202,7 +204,7 @@ async def test_check_file_exists_when_file_exists(tmp_path): mock_callback.assert_called_once_with( "test/existing_model.bin", - DownloadStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists") + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True) ) @pytest.mark.asyncio @@ -235,7 +237,7 @@ async def test_track_download_progress_no_content_length(): # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( 'models/model.bin', - DownloadStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin") + DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False) ) @pytest.mark.asyncio