ComfyUI/tests-unit/app_test/model_processor_test.py

254 lines
9.7 KiB
Python

import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.model_processor import ModelProcessor
from app.database.models import Model, Base
import os
# Test data constants
TEST_MODEL_TYPE = "checkpoints"
TEST_URL = "http://example.com/model.safetensors"
TEST_FILE_NAME = "model.safetensors"
TEST_EXPECTED_HASH = "abc123"
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
"""Helper to create a test model in the database."""
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
session.add(model)
session.commit()
return model
def setup_mock_hash_calculation(model_processor, hash_value):
"""Helper to setup hash calculation mocks."""
mock_hash = MagicMock()
mock_hash.hexdigest.return_value = hash_value
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
"""Helper to verify model exists in database with correct attributes."""
db_model = session.query(Model).filter_by(path=file_name).first()
assert db_model is not None
if expected_hash:
assert db_model.hash == expected_hash
if expected_type:
assert db_model.type == expected_type
return db_model
@pytest.fixture
def db_engine():
# Configure in-memory database
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(db_engine):
Session = sessionmaker(bind=db_engine)
session = Session()
yield session
session.close()
@pytest.fixture
def mock_get_relative_path():
with patch("app.model_processor.get_relative_path") as mock:
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
yield mock
@pytest.fixture
def mock_get_full_path():
with patch("app.model_processor.get_full_path") as mock:
mock.return_value = TEST_DESTINATION_PATH
yield mock
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
with patch("app.model_processor.create_session", return_value=db_session):
with patch("app.model_processor.can_create_session", return_value=True):
processor = ModelProcessor()
# Setup test state
processor.removed_files = []
processor.downloaded_files = []
processor.file_exists = {}
def mock_download_file(url, destination_path, hasher):
processor.downloaded_files.append((url, destination_path))
processor.file_exists[destination_path] = True
# Simulate writing some data to the file
test_data = b"test data"
hasher.update(test_data)
def mock_remove_file(file_path):
processor.removed_files.append(file_path)
if file_path in processor.file_exists:
del processor.file_exists[file_path]
# Setup common patches
file_exists_patch = patch.object(
processor,
"_file_exists",
side_effect=lambda path: processor.file_exists.get(path, False),
)
file_size_patch = patch.object(
processor,
"_get_file_size",
side_effect=lambda path: (
1000 if processor.file_exists.get(path, False) else 0
),
)
download_file_patch = patch.object(
processor, "_download_file", side_effect=mock_download_file
)
remove_file_patch = patch.object(
processor, "_remove_file", side_effect=mock_remove_file
)
with (
file_exists_patch,
file_size_patch,
download_file_patch,
remove_file_patch,
):
yield processor
def test_ensure_downloaded_invalid_extension(model_processor):
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
# Ensure that a file with the same hash but from a different source is not downloaded again
SOURCE_URL = "https://example.com/other.sft"
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
model_processor.file_exists[TEST_DESTINATION_PATH] = True
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
# Ensure that a file with a different hash raises an error
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
def test_ensure_downloaded_new_file(model_processor, db_session):
# Ensure that a new file is downloaded
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
assert len(model_processor.downloaded_files) == 1
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
assert model_processor.file_exists[TEST_DESTINATION_PATH]
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
# Ensure that download that results in a different hash raises an error
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, "different_hash"):
with pytest.raises(
ValueError,
match="Downloaded file hash .* does not match expected hash .*",
):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE,
TEST_URL,
TEST_FILE_NAME,
TEST_EXPECTED_HASH,
)
assert len(model_processor.removed_files) == 1
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
assert TEST_DESTINATION_PATH not in model_processor.file_exists
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
def test_process_file_without_hash(model_processor, db_session):
# Test processing file without provided hash
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
result = model_processor.process_file(TEST_DESTINATION_PATH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash(model_processor, db_session):
# Test retrieving model by hash
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
# Test retrieving model by hash and type
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.type == TEST_MODEL_TYPE
def test_retrieve_hash(model_processor, db_session):
# Test retrieving hash for existing model
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
with patch.object(
model_processor,
"_validate_path",
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
):
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
assert result == TEST_EXPECTED_HASH
def test_validate_file_extension_valid_extensions(model_processor):
# Test all valid file extensions
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
for ext in valid_extensions:
model_processor._validate_file_extension(f"test{ext}") # Should not raise
def test_process_file_existing_without_source_url(model_processor, db_session):
# Test processing an existing file that needs its source URL updated
model_processor.file_exists[TEST_DESTINATION_PATH] = True
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.source_url == TEST_URL
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
assert db_model.source_url == TEST_URL