diff --git a/alembic_db/versions/565b08122d00_init.py b/alembic_db/versions/e9c714da8d57_init.py similarity index 59% rename from alembic_db/versions/565b08122d00_init.py rename to alembic_db/versions/e9c714da8d57_init.py index 9a8a51fb..1a296104 100644 --- a/alembic_db/versions/565b08122d00_init.py +++ b/alembic_db/versions/e9c714da8d57_init.py @@ -1,8 +1,8 @@ """init -Revision ID: 565b08122d00 +Revision ID: e9c714da8d57 Revises: -Create Date: 2025-05-29 19:15:56.230322 +Create Date: 2025-05-30 20:14:33.772039 """ from typing import Sequence, Union @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision: str = '565b08122d00' +revision: str = 'e9c714da8d57' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -20,15 +20,23 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### op.create_table('model', sa.Column('type', sa.Text(), nullable=False), sa.Column('path', sa.Text(), nullable=False), + sa.Column('file_name', sa.Text(), nullable=True), + sa.Column('file_size', sa.Integer(), nullable=True), sa.Column('hash', sa.Text(), nullable=True), + sa.Column('hash_algorithm', sa.Text(), nullable=True), + sa.Column('source_url', sa.Text(), nullable=True), sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), sa.PrimaryKeyConstraint('type', 'path') ) + # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### op.drop_table('model') + # ### end Alembic commands ### diff --git a/app/database/db.py b/app/database/db.py index d17fa4f1..45bcfcba 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -1,29 +1,50 @@ import logging import os import shutil +from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from comfy.cli_args import args +_DB_AVAILABLE = False Session = None -def can_create_session(): - return Session is not None - - try: - import alembic - import sqlalchemy -except ImportError as e: - logging.error(get_missing_requirements_message()) - raise e + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker -from alembic import command -from alembic.config import Config -from alembic.runtime.migration import MigrationContext -from alembic.script import ScriptDirectory -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker + _DB_AVAILABLE = True +except ImportError as e: + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} + +{get_missing_requirements_message()} + +This error is happening because ComfyUI now uses a local sqlite database. 
+------------------------------------------------------------------------
+""".strip()
+    )
+
+
+def dependencies_available():
+    """
+    Temporary function to check if the dependencies are available.
+    """
+    return _DB_AVAILABLE
+
+
+def can_create_session():
+    """
+    Temporary function to check if the database is available to create a session.
+    During the initial release there may be environmental issues (or missing dependencies) that prevent the database from being created.
+    """
+    return dependencies_available() and Session is not None
 
 
 def get_alembic_config():
@@ -49,6 +70,8 @@ def get_db_path():
 def init_db():
     db_url = args.database_url
     logging.debug(f"Database URL: {db_url}")
+    db_path = get_db_path()
+    db_exists = os.path.exists(db_path)
 
     config = get_alembic_config()
 
@@ -64,9 +87,8 @@ def init_db():
 
         if current_rev != target_rev:
             # Backup the database pre upgrade
-            db_path = get_db_path()
             backup_path = db_path + ".bkp"
-            if os.path.exists(db_path):
+            if db_exists:
                 shutil.copy(db_path, backup_path)
             else:
                 backup_path = None
diff --git a/app/database/models.py b/app/database/models.py
index d2c1e042..b0225c41 100644
--- a/app/database/models.py
+++ b/app/database/models.py
@@ -1,5 +1,6 @@
 from sqlalchemy import (
     Column,
+    Integer,
     Text,
     DateTime,
 )
@@ -20,7 +21,7 @@ def to_dict(obj):
 
 class Model(Base):
     """
     SQLAlchemy model representing a model file in the system.
 
     This class defines the database schema for storing information about model files,
     including their type, path, hash, and when they were added to the system.
 
@@ -28,7 +29,11 @@ class Model(Base):
     Attributes:
         type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
         path (Text): The file path of the model relative to the type folder (primary key)
-        hash (Text): A sha256 hash of the model file
+        file_name (Text): The name of the model file
+        file_size (Integer): The size of the model file in bytes
+        hash (Text): A hash of the model file
+        hash_algorithm (Text): The algorithm used to generate the hash
+        source_url (Text): The URL the model file was downloaded from
         date_added (DateTime): Timestamp of when the model was added to the system
     """
 
@@ -36,7 +41,11 @@ class Model(Base):
 
     type = Column(Text, primary_key=True)
     path = Column(Text, primary_key=True)
+    file_name = Column(Text)
+    file_size = Column(Integer)
     hash = Column(Text)
+    hash_algorithm = Column(Text)
+    source_url = Column(Text)
     date_added = Column(DateTime, server_default=func.now())
 
     def to_dict(self):
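For reference, the new columns are ordinary SQLAlchemy attributes, so callers can query them directly. A minimal sketch (illustrative only, not part of the diff; assumes `init_db()` has already run):

```python
from app.database.db import create_session
from app.database.models import Model

# List every checkpoint that has a recorded blake3 hash, newest first.
with create_session() as session:
    rows = (
        session.query(Model)
        .filter(Model.type == "checkpoints", Model.hash_algorithm == "blake3")
        .order_by(Model.date_added.desc())
        .all()
    )
    for row in rows:
        print(row.path, row.file_size, row.hash, row.source_url)
```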
"""Check if a file exists.""" + return os.path.exists(path) + + def _get_file_size(self, path): + """Get file size.""" + return os.path.getsize(path) + + def _get_hasher(self): + return blake3.blake3() + def _hash_file(self, model_path): try: - h = hashlib.sha256() + hasher = self._get_hasher() with open(model_path, "rb", buffering=0) as f: b = bytearray(128 * 1024) mv = memoryview(b) while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest() + hasher.update(mv[:n]) + return hasher.hexdigest() except Exception as e: logging.error(f"Error hashing file {model_path}: {str(e)}") return None @@ -46,9 +64,21 @@ class ModelProcessor: .filter(Model.path == model_relative_path) .first() ) + + def _ensure_source_url(self, session, model, source_url): + if model.source_url is None: + model.source_url = source_url + session.commit() def _update_database( - self, session, model_type, model_relative_path, model_hash, model=None + self, + session, + model_type, + model_path, + model_relative_path, + model_hash, + model, + source_url, ): try: if not model: @@ -60,10 +90,16 @@ class ModelProcessor: model = Model( path=model_relative_path, type=model_type, + file_name=os.path.basename(model_path), ) session.add(model) + model.file_size = self._get_file_size(model_path) model.hash = model_hash + if model_hash: + model.hash_algorithm = "blake3" + model.source_url = source_url + session.commit() return model except Exception as e: @@ -71,36 +107,97 @@ class ModelProcessor: f"Error updating database for {model_relative_path}: {str(e)}" ) - def process_file(self, model_path): + def process_file(self, model_path, source_url=None, model_hash=None): + """ + Process a model file and update the database with metadata. + If the file already exists and matches the database, it will not be processed again. + Returns the model object or if an error occurs, returns None. 
@@ -46,9 +64,21 @@ class ModelProcessor:
             .filter(Model.path == model_relative_path)
             .first()
         )
+
+    def _ensure_source_url(self, session, model, source_url):
+        if model.source_url is None:
+            model.source_url = source_url
+            session.commit()
 
     def _update_database(
-        self, session, model_type, model_relative_path, model_hash, model=None
+        self,
+        session,
+        model_type,
+        model_path,
+        model_relative_path,
+        model_hash,
+        model,
+        source_url,
     ):
         try:
             if not model:
@@ -60,10 +90,16 @@ class ModelProcessor:
                 model = Model(
                     path=model_relative_path,
                     type=model_type,
+                    file_name=os.path.basename(model_path),
                 )
                 session.add(model)
+
+            model.file_size = self._get_file_size(model_path)
             model.hash = model_hash
+            if model_hash:
+                model.hash_algorithm = "blake3"
+            model.source_url = source_url
+
             session.commit()
             return model
         except Exception as e:
@@ -71,36 +107,97 @@ class ModelProcessor:
                 f"Error updating database for {model_relative_path}: {str(e)}"
             )
 
-    def process_file(self, model_path):
+    def process_file(self, model_path, source_url=None, model_hash=None):
+        """
+        Process a model file and update the database with metadata.
+        If the file already exists and matches the database, it will not be processed again.
+        Returns the model object, or None if an error occurs.
+        """
         try:
+            if not can_create_session():
+                return
+
             result = self._validate_path(model_path)
             if not result:
                 return
             model_type, model_relative_path = result
 
             with create_session() as session:
+                session.expire_on_commit = False
+
                 existing_model = self._get_existing_model(
                     session, model_type, model_relative_path
                 )
-                if existing_model and existing_model.hash:
-                    # File exists with hash, no need to process
+                if (
+                    existing_model
+                    and existing_model.hash
+                    and existing_model.file_size == self._get_file_size(model_path)
+                ):
+                    # File exists with hash and same size, no need to process
+                    self._ensure_source_url(session, existing_model, source_url)
                     return existing_model
 
-                start_time = time.time()
-                logging.info(f"Hashing model {model_relative_path}")
-                model_hash = self._hash_file(model_path)
-                if not model_hash:
-                    return
-                logging.info(
-                    f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
-                )
+                if model_hash:
+                    model_hash = model_hash.lower()
+                    logging.info(f"Using provided hash: {model_hash}")
+                else:
+                    start_time = time.time()
+                    logging.info(f"Hashing model {model_relative_path}")
+                    model_hash = self._hash_file(model_path)
+                    if not model_hash:
+                        return
+                    logging.info(
+                        f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
+                    )
 
-                return self._update_database(session, model_type, model_relative_path, model_hash)
+                return self._update_database(
+                    session,
+                    model_type,
+                    model_path,
+                    model_relative_path,
+                    model_hash,
+                    existing_model,
+                    source_url,
+                )
         except Exception as e:
             logging.error(f"Error processing model file {model_path}: {str(e)}")
+            return None
+
+    def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
+        """
+        Retrieve a model file from the database by hash and optionally by model type.
+        Returns the model object or None if the model doesn't exist or an error occurs.
+        """
+        try:
+            # Initialize before any early return so the finally block below
+            # never sees an unbound variable.
+            dispose_session = False
+
+            if not can_create_session():
+                return
+
+            if session is None:
+                session = create_session()
+                dispose_session = True
+
+            model = session.query(Model).filter(Model.hash == model_hash)
+            if model_type is not None:
+                model = model.filter(Model.type == model_type)
+            return model.first()
+        except Exception as e:
+            logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
+            return None
+        finally:
+            if dispose_session:
+                session.close()
 
     def retrieve_hash(self, model_path, model_type=None):
+        """
+        Retrieve the hash of a model file from the database.
+        Returns the hash or None if the model doesn't exist or an error occurs.
+ """ try: + if not can_create_session(): + return + if model_type is not None: result = self._validate_path(model_path) if not result: @@ -118,5 +215,117 @@ class ModelProcessor: logging.error(f"Error retrieving hash for {model_path}: {str(e)}") return None + def _validate_file_extension(self, file_name): + """Validate that the file extension is supported.""" + extension = os.path.splitext(file_name)[1] + if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"): + raise ValueError(f"Unsupported unsafe file for download: {file_name}") + + def _check_existing_file(self, model_type, file_name, expected_hash): + """Check if file exists and has correct hash.""" + destination_path = get_full_path(model_type, file_name, allow_missing=True) + if self._file_exists(destination_path): + model = self.process_file(destination_path) + if model and (expected_hash is None or model.hash == expected_hash): + logging.debug( + f"File {destination_path} already exists in the database and has the correct hash or no hash was provided." + ) + return destination_path + else: + raise ValueError( + f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again." + ) + return None + + def _check_existing_file_by_hash(self, hash, type, url): + """Check if a file with the given hash exists in the database and on disk.""" + hash = hash.lower() + with create_session() as session: + model = self.retrieve_model_by_hash(hash, type, session) + if model: + existing_path = get_full_path(type, model.path) + if existing_path: + logging.debug( + f"File {model.path} already exists in the database at {existing_path}" + ) + self._ensure_source_url(session, model, url) + return existing_path + else: + logging.debug( + f"File {model.path} exists in the database but not on disk" + ) + return None + + def _download_file(self, url, destination_path, hasher): + """Download a file and update the hasher with its contents.""" + response = requests.get(url, stream=True) + logging.info(f"Downloading {url} to {destination_path}") + + with open(destination_path, "wb") as f: + total_size = int(response.headers.get("content-length", 0)) + if total_size > 0: + pbar = comfy.utils.ProgressBar(total_size) + else: + pbar = None + with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: + for chunk in response.iter_content(chunk_size=128 * 1024): + if chunk: + f.write(chunk) + hasher.update(chunk) + progress_bar.update(len(chunk)) + if pbar: + pbar.update(len(chunk)) + + def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path): + """Verify that the downloaded file has the expected hash.""" + if expected_hash is not None and calculated_hash != expected_hash: + self._remove_file(destination_path) + raise ValueError( + f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}" + ) + + def _remove_file(self, file_path): + """Remove a file from disk.""" + os.remove(file_path) + + def ensure_downloaded(self, type, url, desired_file_name, hash=None): + """ + Ensure a model file is downloaded and has the correct hash. + Returns the path to the downloaded file. + """ + logging.debug( + f"Ensuring {type} file is downloaded. 
URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
+        )
+
+        # Validate file extension
+        self._validate_file_extension(desired_file_name)
+
+        # Check if file exists with correct hash
+        if hash:
+            existing_path = self._check_existing_file_by_hash(hash, type, url)
+            if existing_path:
+                return existing_path
+
+        # Check if file exists locally
+        destination_path = get_full_path(type, desired_file_name, allow_missing=True)
+        existing_path = self._check_existing_file(type, desired_file_name, hash)
+        if existing_path:
+            return existing_path
+
+        # Download the file
+        hasher = self._get_hasher()
+        self._download_file(url, destination_path, hasher)
+
+        # Verify hash
+        calculated_hash = hasher.hexdigest()
+        self._verify_downloaded_hash(calculated_hash, hash, destination_path)
+
+        # Update database
+        self.process_file(destination_path, url, calculated_hash)
+
+        # TODO: Notify frontend to reload models
+
+        return destination_path
+
 
 model_processor = ModelProcessor()
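Taken together, `ensure_downloaded` gives callers a single entry point for hash-verified downloads. A hypothetical caller sketch (the URL, file name, and digest below are made up for illustration):

```python
from app.model_processor import model_processor

# Download into models/checkpoints/ unless a file with this blake3 digest is
# already recorded in the database, then store its metadata and source URL.
path = model_processor.ensure_downloaded(
    "checkpoints",
    "https://example.com/some-model.safetensors",  # made-up URL
    "some-model.safetensors",                      # made-up file name
    hash="a" * 64,                                 # made-up blake3 hex digest
)

# Later, the same file can be located by content hash alone.
model = model_processor.retrieve_model_by_hash("a" * 64, "checkpoints")
```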
diff --git a/comfy/utils.py b/comfy/utils.py
index 547ce9fc..7768f363 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -20,7 +20,6 @@ import torch
 import math
 import struct
-from app.model_processor import model_processor
 import comfy.checkpoint_pickle
 import safetensors.torch
 import numpy as np
@@ -50,13 +49,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
 else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
 
+def is_html_file(file_path):
+    with open(file_path, "rb") as f:
+        content = f.read(100)
+        return b"<!DOCTYPE html>" in content or b"<html" in content
+
@@ ... @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         if len(e.args) > 0:
             message = e.args[0]
             if "HeaderTooLarge" in message:
@@ -92,6 +96,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
             sd = pl_sd
         else:
             sd = pl_sd
+
+    try:
+        from app.model_processor import model_processor
+        model_processor.process_file(ckpt)
+    except Exception as e:
+        logging.error(f"Error processing file {ckpt}: {e}")
+
     return (sd, metadata) if return_metadata else sd
 
 def save_torch_file(sd, ckpt, metadata=None):
diff --git a/folder_paths.py b/folder_paths.py
index 452409bf..5b5554a3 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -275,7 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
 
-def get_full_path(folder_name: str, filename: str) -> str | None:
+def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
     global folder_names_and_paths
     folder_name = map_legacy(folder_name)
     if folder_name not in folder_names_and_paths:
@@ -288,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
         return full_path
     elif os.path.islink(full_path):
         logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
+    elif allow_missing:
+        return full_path
 
     return None
 
diff --git a/main.py b/main.py
index d6f8193c..17581a42 100644
--- a/main.py
+++ b/main.py
@@ -238,10 +238,11 @@ def cleanup_temp():
 
 def setup_database():
     try:
-        from app.database.db import init_db
-        init_db()
+        from app.database.db import init_db, dependencies_available
+        if dependencies_available():
+            init_db()
     except Exception as e:
-        logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
+        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this, as the database will be required in future: {e}")
 
 def start_comfyui(asyncio_loop=None):
     """
diff --git a/requirements.txt b/requirements.txt
index ea51f24a..1ae6de3e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,7 @@ tqdm
 psutil
 alembic
 SQLAlchemy
+blake3
 
 #non essential dependencies:
 kornia>=0.7.1
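With `blake3` joining the required dependencies, database features still stay optional at runtime: all call sites are gated behind the checks added in `app/database/db.py`. A minimal sketch of that guard pattern (the helper name here is hypothetical):

```python
from app.database.db import can_create_session

def record_model_if_possible(path):
    # Hypothetical helper: skip all database work when dependencies are
    # missing or the database could not be initialized.
    if not can_create_session():
        return None
    from app.model_processor import model_processor
    return model_processor.process_file(path)
```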
diff --git a/tests-unit/app_test/model_processor_test.py b/tests-unit/app_test/model_processor_test.py
new file mode 100644
index 00000000..d1e43d37
--- /dev/null
+++ b/tests-unit/app_test/model_processor_test.py
@@ -0,0 +1,253 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from app.model_processor import ModelProcessor
+from app.database.models import Model, Base
+import os
+
+# Test data constants
+TEST_MODEL_TYPE = "checkpoints"
+TEST_URL = "http://example.com/model.safetensors"
+TEST_FILE_NAME = "model.safetensors"
+TEST_EXPECTED_HASH = "abc123"
+TEST_DESTINATION_PATH = "/path/to/model.safetensors"
+
+
+def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
+    """Helper to create a test model in the database."""
+    model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
+    session.add(model)
+    session.commit()
+    return model
+
+
+def setup_mock_hash_calculation(model_processor, hash_value):
+    """Helper to setup hash calculation mocks."""
+    mock_hash = MagicMock()
+    mock_hash.hexdigest.return_value = hash_value
+    return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
+
+
+def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
+    """Helper to verify model exists in database with correct attributes."""
+    db_model = session.query(Model).filter_by(path=file_name).first()
+    assert db_model is not None
+    if expected_hash:
+        assert db_model.hash == expected_hash
+    if expected_type:
+        assert db_model.type == expected_type
+    return db_model
+
+
+@pytest.fixture
+def db_engine():
+    # Configure in-memory database
+    engine = create_engine("sqlite:///:memory:")
+    Base.metadata.create_all(engine)
+    yield engine
+    Base.metadata.drop_all(engine)
+
+
+@pytest.fixture
+def db_session(db_engine):
+    Session = sessionmaker(bind=db_engine)
+    session = Session()
+    yield session
+    session.close()
+
+
+@pytest.fixture
+def mock_get_relative_path():
+    with patch("app.model_processor.get_relative_path") as mock:
+        mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
+        yield mock
+
+
+@pytest.fixture
+def mock_get_full_path():
+    with patch("app.model_processor.get_full_path") as mock:
+        mock.return_value = TEST_DESTINATION_PATH
+        yield mock
+
+
+@pytest.fixture
+def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
+    with patch("app.model_processor.create_session", return_value=db_session):
+        with patch("app.model_processor.can_create_session", return_value=True):
+            processor = ModelProcessor()
+            # Setup test state
+            processor.removed_files = []
+            processor.downloaded_files = []
+            processor.file_exists = {}
+
+            def mock_download_file(url, destination_path, hasher):
+                processor.downloaded_files.append((url, destination_path))
+                processor.file_exists[destination_path] = True
+                # Simulate writing some data to the file
+                test_data = b"test data"
+                hasher.update(test_data)
+
+            def mock_remove_file(file_path):
+                processor.removed_files.append(file_path)
+                if file_path in processor.file_exists:
+                    del processor.file_exists[file_path]
+
+            # Setup common patches
+            file_exists_patch = patch.object(
+                processor,
+                "_file_exists",
+                side_effect=lambda path: processor.file_exists.get(path, False),
+            )
+            file_size_patch = patch.object(
+                processor,
+                "_get_file_size",
+                side_effect=lambda path: (
+                    1000 if processor.file_exists.get(path, False) else 0
+                ),
+            )
+            download_file_patch = patch.object(
+                processor, "_download_file", side_effect=mock_download_file
+            )
+            remove_file_patch = patch.object(
+                processor, "_remove_file", side_effect=mock_remove_file
+            )
+
+            with (
+                file_exists_patch,
+                file_size_patch,
+                download_file_patch,
+                remove_file_patch,
+            ):
+                yield processor
+
+
+def test_ensure_downloaded_invalid_extension(model_processor):
+    # Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
+    with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
+        model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
+
+
+def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
+    # Ensure that a file with the same hash but from a different source is not downloaded again
+    SOURCE_URL = "https://example.com/other.sft"
+    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
+    model_processor.file_exists[TEST_DESTINATION_PATH] = True
+
+    result = model_processor.ensure_downloaded(
+        TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
+    )
+
+    assert result == TEST_DESTINATION_PATH
+    model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
+    assert model.source_url == SOURCE_URL  # Ensure the source URL is not overwritten
+
+
+def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
+    # Ensure that a file with a different hash raises an error
+    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
+    model_processor.file_exists[TEST_DESTINATION_PATH] = True
+
+    with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
+        model_processor.ensure_downloaded(
+            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
+        )
+
+
+def test_ensure_downloaded_new_file(model_processor, db_session):
+    # Ensure that a new file is downloaded
+    model_processor.file_exists[TEST_DESTINATION_PATH] = False
+
+    with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
+        result = model_processor.ensure_downloaded(
+            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
+        )
+
+    assert result == TEST_DESTINATION_PATH
+    assert len(model_processor.downloaded_files) == 1
+    assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
+    assert model_processor.file_exists[TEST_DESTINATION_PATH]
+    verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
+
+
+def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
+    # Ensure that a download that results in a different hash raises an error
+    model_processor.file_exists[TEST_DESTINATION_PATH] = False
+
+    with setup_mock_hash_calculation(model_processor, "different_hash"):
+        with pytest.raises(
+            ValueError,
+            match="Downloaded file hash .* does not match expected hash .*",
+        ):
+            model_processor.ensure_downloaded(
+                TEST_MODEL_TYPE,
+                TEST_URL,
+                TEST_FILE_NAME,
+                TEST_EXPECTED_HASH,
+            )
+
+    assert len(model_processor.removed_files) == 1
+    assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
assert TEST_DESTINATION_PATH not in model_processor.file_exists + assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None + + +def test_process_file_without_hash(model_processor, db_session): + # Test processing file without provided hash + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH): + result = model_processor.process_file(TEST_DESTINATION_PATH) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + + +def test_retrieve_model_by_hash(model_processor, db_session): + # Test retrieving model by hash + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + + +def test_retrieve_model_by_hash_and_type(model_processor, db_session): + # Test retrieving model by hash and type + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE) + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + assert result.type == TEST_MODEL_TYPE + + +def test_retrieve_hash(model_processor, db_session): + # Test retrieving hash for existing model + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + with patch.object( + model_processor, + "_validate_path", + return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME), + ): + result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE) + assert result == TEST_EXPECTED_HASH + + +def test_validate_file_extension_valid_extensions(model_processor): + # Test all valid file extensions + valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"] + for ext in valid_extensions: + model_processor._validate_file_extension(f"test{ext}") # Should not raise + + +def test_process_file_existing_without_source_url(model_processor, db_session): + # Test processing an existing file that needs its source URL updated + model_processor.file_exists[TEST_DESTINATION_PATH] = True + + create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH) + result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL) + + assert result is not None + assert result.hash == TEST_EXPECTED_HASH + assert result.source_url == TEST_URL + + db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() + assert db_model.source_url == TEST_URL
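The suite stubs out all filesystem, network, and hashing side effects, so it runs entirely against an in-memory SQLite database. A quick way to run just this module programmatically (sketch; equivalent to invoking pytest on the file from the repo root):

```python
import sys
import pytest

# Same as: pytest -q tests-unit/app_test/model_processor_test.py
sys.exit(pytest.main(["-q", "tests-unit/app_test/model_processor_test.py"]))
```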