From 1aa089e0b6c21185a930377f0255f19e2e5202ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Aug 2025 17:59:16 -0700 Subject: [PATCH] More progress on brainstorming code for asset management for models --- comfy/asset_management.py | 60 ++++++++++++++++++++----------- comfy_extras/nodes_assets_test.py | 56 +++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 97 insertions(+), 20 deletions(-) create mode 100644 comfy_extras/nodes_assets_test.py diff --git a/comfy/asset_management.py b/comfy/asset_management.py index 6bdb18f45..e47996320 100644 --- a/comfy/asset_management.py +++ b/comfy/asset_management.py @@ -26,9 +26,10 @@ class ReturnedAssetABC(ABC): class ModelReturnedAsset(ReturnedAssetABC): - def __init__(self, model: dict[str, str] | tuple[dict[str, str], dict[str, str]]): + def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None): super().__init__("model") - self.model = model + self.state_dict = state_dict + self.metadata = metadata class AssetResolverABC(ABC): @@ -38,24 +39,30 @@ class AssetResolverABC(ABC): class LocalAssetResolver(AssetResolverABC): - def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: + def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC: # currently only supports models - make sure models is in the tags if "models" not in asset_info.tags: return None # TODO: if hash exists, call model processor to try to get info about model: if asset_info.hash: - ... + try: + from app.model_processor import model_processor + model_db = model_processor.retrieve_model_by_hash(asset_info.hash) + full_path = model_db.path + except Exception as e: + logging.error(f"Could not get model by hash with error: {e}") + # the good ol' bread and butter - folder_paths's keys as tags + folder_keys = folder_paths.folder_names_and_paths.keys() + parent_paths = [] + for tag in asset_info.tags: + if tag in folder_keys: + parent_paths.append(tag) # if subdir metadata and name exists, use that as the model name going forward if "subdir" in asset_info.metadata and asset_info.name: - relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) - # the good ol' bread and butter - folder_paths's keys as tags - folder_keys = folder_paths.folder_names_and_paths.keys() - parent_paths = [] - for tag in asset_info.tags: - if tag in folder_keys: - parent_paths.append(tag) + # if no matching parent paths, then something went wrong and should return None if len(parent_paths) == 0: return None + relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) # now we have the parent keys, we can try to get the local path chosen_parent = None full_path = None @@ -64,27 +71,40 @@ class LocalAssetResolver(AssetResolverABC): if full_path: chosen_parent = parent_path break - logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") - # we know the path, so load the model and return it - model = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=asset_info.metadata.get("return_metadata", False)) - return ModelReturnedAsset(model) - # TODO: if name exists, try to find model by name in all subdirs of parent paths + if full_path is not None: + logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") + # we know the path, so load the model and return it + state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) + # TODO: handle caching + return ModelReturnedAsset(state_dict, metadata) + # if just name exists, try to find model by name in all subdirs of parent paths + # TODO: this behavior should be configurable by user if asset_info.name: - ... - # TODO: if download_url metadata exists, download the model and load it + for parent_path in parent_paths: + filelist = folder_paths.get_filename_list(parent_path) + for file in filelist: + if os.path.basename(file) == asset_info.name: + full_path = folder_paths.get_full_path(parent_path, file) + state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) + # TODO: handle caching + return ModelReturnedAsset(state_dict, metadata) + # TODO: if download_url metadata exists, download the model and load it; this should be configurable by user if asset_info.metadata.get("download_url", None): ... return None resolvers: list[AssetResolverABC] = [] +resolvers.append(LocalAssetResolver()) def resolve(asset_info: AssetInfo) -> Any: global resolvers for resolver in resolvers: try: - return resolver.resolve(asset_info) + to_return = resolver.resolve(asset_info) + if to_return is not None: + return resolver.resolve(asset_info) except Exception as e: - logging.error(f"Error resolving asset {asset_info.hash}: {e}") + logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}") return None diff --git a/comfy_extras/nodes_assets_test.py b/comfy_extras/nodes_assets_test.py new file mode 100644 index 000000000..5172cd628 --- /dev/null +++ b/comfy_extras/nodes_assets_test.py @@ -0,0 +1,56 @@ +from comfy_api.latest import io, ComfyExtension +import comfy.asset_management +import comfy.sd +import folder_paths +import logging +import os + + +class AssetTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AssetTestNode", + is_experimental=True, + inputs=[ + io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")), + ], + outputs=[ + io.Model.Output(), + io.Clip.Output(), + io.Vae.Output(), + ], + ) + + @classmethod + def execute(cls, ckpt_name: str): + hash = None + # lets get the full path just so we can retrieve the hash from db, if exists + try: + full_path = folder_paths.get_full_path("checkpoints", ckpt_name) + if full_path is None: + raise Exception(f"Model {ckpt_name} not found") + from app.model_processor import model_processor + hash = model_processor.retrieve_hash(full_path) + except Exception as e: + logging.error(f"Could not get model by hash with error: {e}") + subdir, name = os.path.split(ckpt_name) + asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir}) + asset = comfy.asset_management.resolve(asset_info) + # /\ the stuff above should happen in execution code instead of inside the node + # \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?) + if asset is None: + raise Exception(f"Model {asset_info.name} not found") + assert isinstance(asset, comfy.asset_management.ModelReturnedAsset) + out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata) + return io.NodeOutput(out[0], out[1], out[2]) + + +class AssetTestExtension(ComfyExtension): + @classmethod + async def get_node_list(cls): + return [AssetTestNode] + + +def comfy_entrypoint(): + return AssetTestExtension() diff --git a/nodes.py b/nodes.py index 9448f9c1b..0f1f4e937 100644 --- a/nodes.py +++ b/nodes.py @@ -2320,6 +2320,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_assets_test.py", ] import_failed = []