Consolidate DownloadStatus and DownloadModelResult

This commit is contained in:
Robin Huang
2024-08-07 16:44:56 -07:00
parent a6d8a93fa1
commit c1d78d6890
3 changed files with 47 additions and 42 deletions

View File

@@ -16,32 +16,23 @@ class DownloadStatusType(Enum):
ERROR = "error"
@dataclass
class DownloadStatus():
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str):
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
@dataclass
class DownloadModelResult():
status: str
message: str
already_existed: bool
def __init__(self, status: DownloadStatusType, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.message = message
self.already_existed = already_existed
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
@@ -55,15 +46,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelResult: The result of the download operation.
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelResult(
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)
@@ -74,16 +66,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
return existing_file
try:
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
@@ -99,15 +91,23 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path
async def check_file_exists(file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelResult]:
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"{model_name} already exists", True)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]], relative_path: str, interval: float = 1.0) -> DownloadModelResult:
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
@@ -116,7 +116,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}")
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
last_update_time = time.time()
@@ -136,20 +136,23 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: s
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.COMPLETED, f"Successfully downloaded {model_name}", False)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception, model_name: str, progress_callback: Callable[[str, DownloadStatus], Any], relative_path: str) -> DownloadModelResult:
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""