Validate that model subdirectory cannot contain relative paths.

This commit is contained in:
Robin Huang
2024-08-07 12:45:07 -07:00
parent 9632dded9e
commit 59933489bf
4 changed files with 81 additions and 16 deletions

View File

@@ -3,7 +3,8 @@ import os
import traceback
import logging
from folder_paths import models_dir
from typing import Callable, Any, Optional, Awaitable
import re
from typing import Callable, Any, Optional, Awaitable, Tuple
from enum import Enum
import time
from dataclasses import dataclass
@@ -36,13 +37,38 @@ class DownloadModelResult():
self.message = message
self.already_existed = already_existed
async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
model_sub_directory: str,
progress_callback: Callable[[str, DownloadStatus], Awaitable[Any]]) -> DownloadModelResult:
"""
Download a model file from a given URL into the models directory.
file_path, relative_path = create_model_path(model_name, model_directory, models_dir)
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
Returns:
DownloadModelResult: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelResult(
DownloadStatusType.ERROR,
"Invalid model subdirectory",
False
)
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file:
return existing_file
@@ -51,9 +77,10 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
status = DownloadStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}")
await progress_callback(relative_path, status)
response = await make_request(model_url)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
@@ -61,15 +88,11 @@ async def download_model(make_request: Callable[[str], Awaitable[aiohttp.ClientR
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path)
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)
async def make_http_request(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
return await session.get(url)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> Tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)
@@ -126,4 +149,25 @@ async def handle_download_error(e: Exception, model_name: str, progress_callback
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadStatus(DownloadStatusType.ERROR, 0, error_message)
await progress_callback(relative_path, status)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
return DownloadModelResult(DownloadStatusType.ERROR, error_message, False)
def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe.
Args:
model_subdirectory (str): The subdirectory for the specific model type.
Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False
if '..' in model_subdirectory or '/' in model_subdirectory:
return False
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False
return True