mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 07:07:14 +00:00
254 lines
9.7 KiB
Python
254 lines
9.7 KiB
Python
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from app.model_processor import ModelProcessor
|
|
from app.database.models import Model, Base
|
|
import os
|
|
|
|
# Test data constants
|
|
TEST_MODEL_TYPE = "checkpoints"
|
|
TEST_URL = "http://example.com/model.safetensors"
|
|
TEST_FILE_NAME = "model.safetensors"
|
|
TEST_EXPECTED_HASH = "abc123"
|
|
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
|
|
|
|
|
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
|
"""Helper to create a test model in the database."""
|
|
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
|
session.add(model)
|
|
session.commit()
|
|
return model
|
|
|
|
|
|
def setup_mock_hash_calculation(model_processor, hash_value):
|
|
"""Helper to setup hash calculation mocks."""
|
|
mock_hash = MagicMock()
|
|
mock_hash.hexdigest.return_value = hash_value
|
|
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
|
|
|
|
|
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
|
"""Helper to verify model exists in database with correct attributes."""
|
|
db_model = session.query(Model).filter_by(path=file_name).first()
|
|
assert db_model is not None
|
|
if expected_hash:
|
|
assert db_model.hash == expected_hash
|
|
if expected_type:
|
|
assert db_model.type == expected_type
|
|
return db_model
|
|
|
|
|
|
@pytest.fixture
|
|
def db_engine():
|
|
# Configure in-memory database
|
|
engine = create_engine("sqlite:///:memory:")
|
|
Base.metadata.create_all(engine)
|
|
yield engine
|
|
Base.metadata.drop_all(engine)
|
|
|
|
|
|
@pytest.fixture
|
|
def db_session(db_engine):
|
|
Session = sessionmaker(bind=db_engine)
|
|
session = Session()
|
|
yield session
|
|
session.close()
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_get_relative_path():
|
|
with patch("app.model_processor.get_relative_path") as mock:
|
|
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
|
yield mock
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_get_full_path():
|
|
with patch("app.model_processor.get_full_path") as mock:
|
|
mock.return_value = TEST_DESTINATION_PATH
|
|
yield mock
|
|
|
|
|
|
@pytest.fixture
|
|
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
|
|
with patch("app.model_processor.create_session", return_value=db_session):
|
|
with patch("app.model_processor.can_create_session", return_value=True):
|
|
processor = ModelProcessor()
|
|
# Setup test state
|
|
processor.removed_files = []
|
|
processor.downloaded_files = []
|
|
processor.file_exists = {}
|
|
|
|
def mock_download_file(url, destination_path, hasher):
|
|
processor.downloaded_files.append((url, destination_path))
|
|
processor.file_exists[destination_path] = True
|
|
# Simulate writing some data to the file
|
|
test_data = b"test data"
|
|
hasher.update(test_data)
|
|
|
|
def mock_remove_file(file_path):
|
|
processor.removed_files.append(file_path)
|
|
if file_path in processor.file_exists:
|
|
del processor.file_exists[file_path]
|
|
|
|
# Setup common patches
|
|
file_exists_patch = patch.object(
|
|
processor,
|
|
"_file_exists",
|
|
side_effect=lambda path: processor.file_exists.get(path, False),
|
|
)
|
|
file_size_patch = patch.object(
|
|
processor,
|
|
"_get_file_size",
|
|
side_effect=lambda path: (
|
|
1000 if processor.file_exists.get(path, False) else 0
|
|
),
|
|
)
|
|
download_file_patch = patch.object(
|
|
processor, "_download_file", side_effect=mock_download_file
|
|
)
|
|
remove_file_patch = patch.object(
|
|
processor, "_remove_file", side_effect=mock_remove_file
|
|
)
|
|
|
|
with (
|
|
file_exists_patch,
|
|
file_size_patch,
|
|
download_file_patch,
|
|
remove_file_patch,
|
|
):
|
|
yield processor
|
|
|
|
|
|
def test_ensure_downloaded_invalid_extension(model_processor):
|
|
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
|
|
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
|
|
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
|
|
|
|
|
|
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
|
|
# Ensure that a file with the same hash but from a different source is not downloaded again
|
|
SOURCE_URL = "https://example.com/other.sft"
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
|
|
result = model_processor.ensure_downloaded(
|
|
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
)
|
|
|
|
assert result == TEST_DESTINATION_PATH
|
|
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
|
|
|
|
|
|
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
|
|
# Ensure that a file with a different hash raises an error
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
|
|
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
|
|
model_processor.ensure_downloaded(
|
|
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
)
|
|
|
|
|
|
def test_ensure_downloaded_new_file(model_processor, db_session):
|
|
# Ensure that a new file is downloaded
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
|
|
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
|
|
result = model_processor.ensure_downloaded(
|
|
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
)
|
|
|
|
assert result == TEST_DESTINATION_PATH
|
|
assert len(model_processor.downloaded_files) == 1
|
|
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
|
|
assert model_processor.file_exists[TEST_DESTINATION_PATH]
|
|
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
|
|
|
|
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
|
|
# Ensure that download that results in a different hash raises an error
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
|
|
with setup_mock_hash_calculation(model_processor, "different_hash"):
|
|
with pytest.raises(
|
|
ValueError,
|
|
match="Downloaded file hash .* does not match expected hash .*",
|
|
):
|
|
model_processor.ensure_downloaded(
|
|
TEST_MODEL_TYPE,
|
|
TEST_URL,
|
|
TEST_FILE_NAME,
|
|
TEST_EXPECTED_HASH,
|
|
)
|
|
|
|
assert len(model_processor.removed_files) == 1
|
|
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
|
|
assert TEST_DESTINATION_PATH not in model_processor.file_exists
|
|
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
|
|
|
|
|
|
def test_process_file_without_hash(model_processor, db_session):
|
|
# Test processing file without provided hash
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
|
|
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
|
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
|
assert result is not None
|
|
assert result.hash == TEST_EXPECTED_HASH
|
|
|
|
|
|
def test_retrieve_model_by_hash(model_processor, db_session):
|
|
# Test retrieving model by hash
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
|
assert result is not None
|
|
assert result.hash == TEST_EXPECTED_HASH
|
|
|
|
|
|
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
|
|
# Test retrieving model by hash and type
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
assert result is not None
|
|
assert result.hash == TEST_EXPECTED_HASH
|
|
assert result.type == TEST_MODEL_TYPE
|
|
|
|
|
|
def test_retrieve_hash(model_processor, db_session):
|
|
# Test retrieving hash for existing model
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
with patch.object(
|
|
model_processor,
|
|
"_validate_path",
|
|
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
|
|
):
|
|
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
|
|
assert result == TEST_EXPECTED_HASH
|
|
|
|
|
|
def test_validate_file_extension_valid_extensions(model_processor):
|
|
# Test all valid file extensions
|
|
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
|
|
for ext in valid_extensions:
|
|
model_processor._validate_file_extension(f"test{ext}") # Should not raise
|
|
|
|
|
|
def test_process_file_existing_without_source_url(model_processor, db_session):
|
|
# Test processing an existing file that needs its source URL updated
|
|
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
|
|
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
|
|
|
|
assert result is not None
|
|
assert result.hash == TEST_EXPECTED_HASH
|
|
assert result.source_url == TEST_URL
|
|
|
|
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
|
|
assert db_model.source_url == TEST_URL
|