Add additional db model metadata fields and model downloading function

This commit is contained in:
pythongosssss 2025-06-01 13:34:26 +01:00
parent 1cb3c98947
commit 9da6aca0d0
9 changed files with 566 additions and 50 deletions

View File

@ -1,8 +1,8 @@
"""init
Revision ID: 565b08122d00
Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-29 19:15:56.230322
Create Date: 2025-05-30 20:14:33.772039
"""
from typing import Sequence, Union
@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '565b08122d00'
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -20,15 +20,23 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Upgrade schema by creating the ``model`` table.

    The table is keyed on the composite ``(type, path)`` pair; every other
    column is nullable.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    model_columns = [
        sa.Column('type', sa.Text(), nullable=False),
        sa.Column('path', sa.Text(), nullable=False),
        sa.Column('file_name', sa.Text(), nullable=True),
        sa.Column('file_size', sa.Integer(), nullable=True),
        sa.Column('hash', sa.Text(), nullable=True),
        sa.Column('hash_algorithm', sa.Text(), nullable=True),
        sa.Column('source_url', sa.Text(), nullable=True),
        sa.Column(
            'date_added',
            sa.DateTime(),
            server_default=sa.text('(CURRENT_TIMESTAMP)'),
            nullable=True,
        ),
    ]
    primary_key = sa.PrimaryKeyConstraint('type', 'path')
    op.create_table('model', *model_columns, primary_key)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade schema by dropping the ``model`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = 'model'
    op.drop_table(table_name)
    # ### end Alembic commands ###

View File

@ -1,23 +1,15 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
def can_create_session():
return Session is not None
try:
import alembic
import sqlalchemy
except ImportError as e:
logging.error(get_missing_requirements_message())
raise e
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
@ -25,6 +17,35 @@ from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
    """
    Temporary helper reporting whether the database dependencies were imported.

    Returns the module-level ``_DB_AVAILABLE`` flag.
    """
    available = _DB_AVAILABLE
    return available
def can_create_session():
    """
    Temporary helper reporting whether a database session can be created.

    During initial release there may be environmental issues (or missing
    dependencies) that prevent the database from being created, so both the
    dependency flag and the module-level ``Session`` factory are checked.
    """
    deps_ok = dependencies_available()
    if not deps_ok:
        # Hand back the falsy dependency result unchanged.
        return deps_ok
    return Session is not None
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
@ -49,6 +70,8 @@ def get_db_path():
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
@ -64,9 +87,8 @@ def init_db():
if current_rev != target_rev:
# Backup the database pre upgrade
db_path = get_db_path()
backup_path = db_path + ".bkp"
if os.path.exists(db_path):
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None

View File

@ -1,5 +1,6 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
)
@ -20,7 +21,7 @@ def to_dict(obj):
class Model(Base):
"""
SQLAlchemy model representing a model file in the system.
sqlalchemy model representing a model file in the system.
This class defines the database schema for storing information about model files,
including their type, path, hash, and when they were added to the system.
@ -28,7 +29,11 @@ class Model(Base):
Attributes:
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
path (Text): The file path of the model relative to the type folder (primary key)
hash (Text): A sha256 hash of the model file
file_name (Text): The name of the model file
file_size (Integer): The size of the model file in bytes
hash (Text): A hash of the model file
hash_algorithm (Text): The algorithm used to generate the hash
source_url (Text): The URL of the model file
date_added (DateTime): Timestamp of when the model was added to the system
"""
@ -36,7 +41,11 @@ class Model(Base):
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
file_name = Column(Text)
file_size = Column(Integer)
hash = Column(Text)
hash_algorithm = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
def to_dict(self):

View File

@ -1,16 +1,23 @@
import hashlib
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
from app.database.db import create_session
from folder_paths import get_relative_path
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not os.path.exists(model_path):
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
@ -26,15 +33,26 @@ class ModelProcessor:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
    """Create a fresh blake3 hasher object."""
    hasher = blake3.blake3()
    return hasher
def _hash_file(self, model_path):
try:
h = hashlib.sha256()
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
@ -47,8 +65,20 @@ class ModelProcessor:
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self, session, model_type, model_relative_path, model_hash, model=None
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
@ -60,10 +90,16 @@ class ModelProcessor:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
@ -71,21 +107,40 @@ class ModelProcessor:
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path):
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
Returns the model object or if an error occurs, returns None.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if existing_model and existing_model.hash:
# File exists with hash, no need to process
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
@ -95,12 +150,54 @@ class ModelProcessor:
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(session, model_type, model_relative_path, model_hash)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
    """
    Retrieve a model file from the database by hash and optionally by model type.

    Args:
        model_hash: Hash value to look up.
        model_type: Optional model type used to narrow the lookup.
        session: Optional existing session to query with. When omitted, a
            session is created for this call and closed before returning.

    Returns the model object or None if the model doesn't exist, the database
    is unavailable, or an error occurs.
    """
    # Must be bound before the try block: the finally clause reads it on
    # every exit path, including the early return below.
    owns_session = False
    try:
        if not can_create_session():
            return None

        if session is None:
            session = create_session()
            owns_session = True

        query = session.query(Model).filter(Model.hash == model_hash)
        if model_type is not None:
            query = query.filter(Model.type == model_type)
        return query.first()
    except Exception as e:
        logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
        return None
    finally:
        # Only close sessions this call created; a caller-supplied session
        # stays open for the caller to manage.
        if owns_session:
            session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
@ -118,5 +215,117 @@ class ModelProcessor:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
    """Check if file exists and has correct hash.

    Returns:
        The destination path when the file exists and either no hash was
        expected or the recorded hash matches; None when the file is absent.

    Raises:
        ValueError: If the file exists but could not be processed, or its
            hash differs from ``expected_hash``.
    """
    destination_path = get_full_path(model_type, file_name, allow_missing=True)
    if not self._file_exists(destination_path):
        return None

    # process_file lower-cases the hashes it stores, so normalise the expected
    # value too; otherwise an upper-case hash is reported as a mismatch.
    if expected_hash is not None:
        expected_hash = expected_hash.lower()

    model = self.process_file(destination_path)
    if model and (expected_hash is None or model.hash == expected_hash):
        logging.debug(
            f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
        )
        return destination_path

    raise ValueError(
        f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
    )
def _check_existing_file_by_hash(self, hash, type, url):
    """Look up a model by hash and return its on-disk path if it is present.

    The hash is lower-cased before the lookup. When a matching record exists
    and its path resolves, the source URL is filled in if missing and the
    resolved path is returned; otherwise None is returned.
    """
    normalized_hash = hash.lower()
    with create_session() as session:
        model = self.retrieve_model_by_hash(normalized_hash, type, session)
        if not model:
            return None

        existing_path = get_full_path(type, model.path)
        if not existing_path:
            logging.debug(
                f"File {model.path} exists in the database but not on disk"
            )
            return None

        logging.debug(
            f"File {model.path} already exists in the database at {existing_path}"
        )
        self._ensure_source_url(session, model, url)
        return existing_path
def _download_file(self, url, destination_path, hasher):
    """Download a file and update the hasher with its contents.

    Args:
        url: Address to download from.
        destination_path: Local path the response body is written to.
        hasher: Object whose ``update`` method is fed every chunk written.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            check happens before the destination is opened, so an error page
            is never saved as a model file.
        Exception: Any error raised while streaming is re-raised after the
            partially written file has been removed.
    """
    logging.info(f"Downloading {url} to {destination_path}")
    # The context manager releases the connection on every exit path.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        pbar = comfy.utils.ProgressBar(total_size) if total_size > 0 else None

        try:
            with open(destination_path, "wb") as f:
                # total=None tells tqdm the size is unknown instead of
                # rendering a zero-length bar.
                with tqdm(total=total_size or None, unit="B", unit_scale=True) as progress_bar:
                    for chunk in response.iter_content(chunk_size=128 * 1024):
                        if not chunk:
                            continue
                        f.write(chunk)
                        hasher.update(chunk)
                        progress_bar.update(len(chunk))
                        if pbar:
                            pbar.update(len(chunk))
        except Exception:
            # A truncated file left behind would be picked up as an existing
            # model on the next attempt, so remove it before propagating.
            if self._file_exists(destination_path):
                self._remove_file(destination_path)
            raise
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
    """
    Ensure a model file is downloaded and has the correct hash.

    Args:
        type: Model type (folder name) the file belongs to.
        url: Address to download the file from when it is not present.
        desired_file_name: File name to store the download under.
        hash: Optional expected hash. Compared case-insensitively; an empty
            value is treated the same as no hash.

    Returns the path to the downloaded file.

    Raises:
        ValueError: If the extension is not allowed, the destination path
            cannot be resolved, an existing file has a different hash, or
            the downloaded content does not match ``hash``.
    """
    logging.debug(
        f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
    )

    # Validate file extension
    self._validate_file_extension(desired_file_name)

    # Normalise once so every later comparison sees the same value: stored
    # and calculated hashes are lower-case, and an empty hash means "none".
    hash = hash.lower() if hash else None

    # Check if file exists with correct hash
    if hash:
        existing_path = self._check_existing_file_by_hash(hash, type, url)
        if existing_path:
            return existing_path

    # Check if file exists locally
    destination_path = get_full_path(type, desired_file_name, allow_missing=True)
    if destination_path is None:
        # Fail with a clear message rather than a TypeError further down.
        raise ValueError(
            f"Unable to determine a destination path for {type} file '{desired_file_name}'"
        )
    existing_path = self._check_existing_file(type, desired_file_name, hash)
    if existing_path:
        return existing_path

    # Download the file
    hasher = self._get_hasher()
    self._download_file(url, destination_path, hasher)

    # Verify hash
    calculated_hash = hasher.hexdigest()
    self._verify_downloaded_hash(calculated_hash, hash, destination_path)

    # Update database
    self.process_file(destination_path, url, calculated_hash)

    # TODO: Notify frontend to reload models

    return destination_path
model_processor = ModelProcessor()

View File

@ -20,7 +20,6 @@
import torch
import math
import struct
from app.model_processor import model_processor
import comfy.checkpoint_pickle
import safetensors.torch
import numpy as np
@ -50,13 +49,16 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def is_html_file(file_path):
    """Return True if the first 100 bytes of the file look like an HTML document.

    HTML markup is case-insensitive, so the leading bytes are lower-cased
    before looking for a doctype declaration or an opening ``<html`` tag.
    """
    with open(file_path, "rb") as f:
        head = f.read(100).lower()
    return b"<!doctype html>" in head or b"<html" in head
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
metadata = None
model_processor.process_file(ckpt)
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
@ -66,6 +68,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if return_metadata:
metadata = f.metadata()
except Exception as e:
if is_html_file(ckpt):
raise ValueError("{}\n\nFile path: {}\n\nThe requested file is an HTML document not a safetensors file. Please re-download the file, not the web page.".format(e, ckpt))
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
@ -92,6 +96,13 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
try:
from app.model_processor import model_processor
model_processor.process_file(ckpt)
except Exception as e:
logging.error(f"Error processing file {ckpt}: {e}")
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@ -275,7 +275,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
def get_full_path(folder_name: str, filename: str, allow_missing: bool = False) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@ -288,6 +288,8 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
return full_path
elif os.path.islink(full_path):
logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))
elif allow_missing:
return full_path
return None

View File

@ -238,10 +238,11 @@ def cleanup_temp():
def setup_database():
try:
from app.database.db import init_db
from app.database.db import init_db, dependencies_available
if dependencies_available():
init_db()
except Exception as e:
logging.error(f"Failed to initialize database. Please report this error as in future the database will be required: {e}")
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
def start_comfyui(asyncio_loop=None):
"""

View File

@ -20,6 +20,7 @@ tqdm
psutil
alembic
SQLAlchemy
blake3
#non essential dependencies:
kornia>=0.7.1

View File

@ -0,0 +1,253 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.model_processor import ModelProcessor
from app.database.models import Model, Base
import os
# Test data constants
TEST_MODEL_TYPE = "checkpoints"
TEST_URL = "http://example.com/model.safetensors"
TEST_FILE_NAME = "model.safetensors"
TEST_EXPECTED_HASH = "abc123"
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
    """Insert and commit a Model row for a test, returning the created instance."""
    row = Model(
        path=file_name,
        type=model_type,
        hash=hash_value,
        file_size=file_size,
        source_url=source_url,
    )
    session.add(row)
    session.commit()
    return row
def setup_mock_hash_calculation(model_processor, hash_value):
    """Return a patcher that makes ``_get_hasher`` yield a hasher reporting ``hash_value``."""
    fake_hasher = MagicMock(**{"hexdigest.return_value": hash_value})
    return patch.object(model_processor, "_get_hasher", return_value=fake_hasher)
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
    """Assert a Model row with the given path exists and return it.

    The hash and type are only checked when truthy expected values are given.
    """
    row = session.query(Model).filter_by(path=file_name).first()
    assert row is not None
    if expected_hash:
        assert row.hash == expected_hash
    if expected_type:
        assert row.type == expected_type
    return row
@pytest.fixture
def db_engine():
    """Provide an in-memory SQLite engine with the schema created, dropping it afterwards."""
    in_memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(in_memory_engine)
    yield in_memory_engine
    Base.metadata.drop_all(in_memory_engine)
@pytest.fixture
def db_session(db_engine):
    """Provide a session bound to the test engine and close it after the test."""
    session_factory = sessionmaker(bind=db_engine)
    active_session = session_factory()
    yield active_session
    active_session.close()
@pytest.fixture
def mock_get_relative_path():
    """Patch ``get_relative_path`` to report the test model type and the path's base name."""

    def fake_relative_path(path):
        return (TEST_MODEL_TYPE, os.path.basename(path))

    with patch("app.model_processor.get_relative_path") as patched:
        patched.side_effect = fake_relative_path
        yield patched
@pytest.fixture
def mock_get_full_path():
    """Patch ``get_full_path`` so it always returns the test destination path."""
    with patch(
        "app.model_processor.get_full_path", return_value=TEST_DESTINATION_PATH
    ) as patched:
        yield patched
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
    """Provide a ModelProcessor wired to the test session with file I/O faked.

    The processor carries three bookkeeping attributes the tests inspect:
    ``removed_files``, ``downloaded_files`` and ``file_exists`` (a mapping of
    path to presence flag standing in for the filesystem).
    """
    session_patch = patch("app.model_processor.create_session", return_value=db_session)
    availability_patch = patch("app.model_processor.can_create_session", return_value=True)

    with session_patch:
        with availability_patch:
            processor = ModelProcessor()

            # Setup test state
            processor.removed_files = []
            processor.downloaded_files = []
            processor.file_exists = {}

            def fake_exists(path):
                return processor.file_exists.get(path, False)

            def fake_size(path):
                if processor.file_exists.get(path, False):
                    return 1000
                return 0

            def fake_download(url, destination_path, hasher):
                processor.downloaded_files.append((url, destination_path))
                processor.file_exists[destination_path] = True
                # Simulate writing some data to the file
                payload = b"test data"
                hasher.update(payload)

            def fake_remove(file_path):
                processor.removed_files.append(file_path)
                if file_path in processor.file_exists:
                    del processor.file_exists[file_path]

            # Setup common patches
            exists_patch = patch.object(processor, "_file_exists", side_effect=fake_exists)
            size_patch = patch.object(processor, "_get_file_size", side_effect=fake_size)
            download_patch = patch.object(processor, "_download_file", side_effect=fake_download)
            remove_patch = patch.object(processor, "_remove_file", side_effect=fake_remove)

            with exists_patch, size_patch, download_patch, remove_patch:
                yield processor
def test_ensure_downloaded_invalid_extension(model_processor):
    """An unsupported file extension must raise, preventing unsafe file downloads."""
    unsafe_name = "model.exe"
    with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
        model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, unsafe_name)
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
    """A file with the same hash but from a different source is not downloaded again."""
    original_source = "https://example.com/other.sft"
    create_test_model(
        db_session,
        TEST_FILE_NAME,
        TEST_MODEL_TYPE,
        TEST_EXPECTED_HASH,
        source_url=original_source,
    )
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    returned_path = model_processor.ensure_downloaded(
        TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
    )

    assert returned_path == TEST_DESTINATION_PATH
    stored = verify_model_in_db(
        db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE
    )
    # Ensure the source URL is not overwritten
    assert stored.source_url == original_source
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
    """An existing file recorded with a different hash must raise an error."""
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    expected_message = "File .* exists with hash .* but expected .*"
    with pytest.raises(ValueError, match=expected_message):
        model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )
def test_ensure_downloaded_new_file(model_processor, db_session):
    """A file that is not present yet is downloaded and recorded in the database."""
    model_processor.file_exists[TEST_DESTINATION_PATH] = False

    with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
        returned_path = model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )

    assert returned_path == TEST_DESTINATION_PATH
    downloads = model_processor.downloaded_files
    assert len(downloads) == 1
    assert downloads[0] == (TEST_URL, TEST_DESTINATION_PATH)
    assert model_processor.file_exists[TEST_DESTINATION_PATH]
    verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
    """A download whose hash differs from the expected one raises and is cleaned up."""
    model_processor.file_exists[TEST_DESTINATION_PATH] = False
    mismatch_message = "Downloaded file hash .* does not match expected hash .*"

    with setup_mock_hash_calculation(model_processor, "different_hash"):
        with pytest.raises(ValueError, match=mismatch_message):
            model_processor.ensure_downloaded(
                TEST_MODEL_TYPE,
                TEST_URL,
                TEST_FILE_NAME,
                TEST_EXPECTED_HASH,
            )

    removed = model_processor.removed_files
    assert len(removed) == 1
    assert removed[0] == TEST_DESTINATION_PATH
    assert TEST_DESTINATION_PATH not in model_processor.file_exists
    leftover = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
    assert leftover is None
def test_process_file_without_hash(model_processor, db_session):
    """Processing a file with no provided hash stores the hash computed by ``_hash_file``."""
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    hash_patch = patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH)
    with hash_patch:
        processed = model_processor.process_file(TEST_DESTINATION_PATH)

    assert processed is not None
    assert processed.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash(model_processor, db_session):
    """A stored model can be looked up by its hash alone."""
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)

    found = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)

    assert found is not None
    assert found.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
    """A stored model can be looked up by hash together with its model type."""
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)

    found = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)

    assert found is not None
    assert found.hash == TEST_EXPECTED_HASH
    assert found.type == TEST_MODEL_TYPE
def test_retrieve_hash(model_processor, db_session):
    """The stored hash is returned for a model that already exists in the database."""
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)

    validated = (TEST_MODEL_TYPE, TEST_FILE_NAME)
    with patch.object(model_processor, "_validate_path", return_value=validated):
        stored_hash = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)

    assert stored_hash == TEST_EXPECTED_HASH
def test_validate_file_extension_valid_extensions(model_processor):
    """Every supported extension passes validation without raising."""
    for allowed in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
        candidate = f"test{allowed}"
        model_processor._validate_file_extension(candidate)  # Should not raise
def test_process_file_existing_without_source_url(model_processor, db_session):
    """Processing an existing record that has no source URL fills the URL in."""
    model_processor.file_exists[TEST_DESTINATION_PATH] = True
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)

    processed = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)

    assert processed is not None
    assert processed.hash == TEST_EXPECTED_HASH
    assert processed.source_url == TEST_URL

    stored = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
    assert stored.source_url == TEST_URL