Add additional db model metadata fields and model downloading function

This commit is contained in:
pythongosssss
2025-06-01 13:34:26 +01:00
parent 1cb3c98947
commit 9da6aca0d0
9 changed files with 566 additions and 50 deletions

View File

@@ -1,29 +1,50 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args
_DB_AVAILABLE = False
Session = None
def can_create_session():
return Session is not None
try:
import alembic
import sqlalchemy
except ImportError as e:
logging.error(get_missing_requirements_message())
raise e
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
except ImportError as e:
log_startup_warning(
f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
)
def dependencies_available():
"""
Temporary function to check if the dependencies are available
"""
return _DB_AVAILABLE
def can_create_session():
"""
Temporary function to check if the database is available to create a session
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
"""
return dependencies_available() and Session is not None
def get_alembic_config():
@@ -49,6 +70,8 @@ def get_db_path():
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
config = get_alembic_config()
@@ -64,9 +87,8 @@ def init_db():
if current_rev != target_rev:
# Backup the database pre upgrade
db_path = get_db_path()
backup_path = db_path + ".bkp"
if os.path.exists(db_path):
if db_exists:
shutil.copy(db_path, backup_path)
else:
backup_path = None

View File

@@ -1,5 +1,6 @@
from sqlalchemy import (
Column,
Integer,
Text,
DateTime,
)
@@ -20,7 +21,7 @@ def to_dict(obj):
class Model(Base):
"""
SQLAlchemy model representing a model file in the system.
sqlalchemy model representing a model file in the system.
This class defines the database schema for storing information about model files,
including their type, path, hash, and when they were added to the system.
@@ -28,7 +29,11 @@ class Model(Base):
Attributes:
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
path (Text): The file path of the model relative to the type folder (primary key)
hash (Text): A sha256 hash of the model file
file_name (Text): The name of the model file
file_size (Integer): The size of the model file in bytes
hash (Text): A hash of the model file
hash_algorithm (Text): The algorithm used to generate the hash
source_url (Text): The URL of the model file
date_added (DateTime): Timestamp of when the model was added to the system
"""
@@ -36,7 +41,11 @@ class Model(Base):
type = Column(Text, primary_key=True)
path = Column(Text, primary_key=True)
file_name = Column(Text)
file_size = Column(Integer)
hash = Column(Text)
hash_algorithm = Column(Text)
source_url = Column(Text)
date_added = Column(DateTime, server_default=func.now())
def to_dict(self):

View File

@@ -1,16 +1,23 @@
import hashlib
import os
import logging
import time
from app.database.models import Model
from app.database.db import create_session
from folder_paths import get_relative_path
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not os.path.exists(model_path):
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
@@ -26,15 +33,26 @@ class ModelProcessor:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
h = hashlib.sha256()
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
@@ -46,9 +64,21 @@ class ModelProcessor:
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self, session, model_type, model_relative_path, model_hash, model=None
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
@@ -60,10 +90,16 @@ class ModelProcessor:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
@@ -71,36 +107,97 @@ class ModelProcessor:
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path):
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
Returns the model object or if an error occurs, returns None.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if existing_model and existing_model.hash:
# File exists with hash, no need to process
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(session, model_type, model_relative_path, model_hash)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
Returns the model object or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash or None if the model doesnt exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
@@ -118,5 +215,117 @@ class ModelProcessor:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()