mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
More progress on brainstorming code for asset management for models
This commit is contained in:
@@ -26,9 +26,10 @@ class ReturnedAssetABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class ModelReturnedAsset(ReturnedAssetABC):
|
class ModelReturnedAsset(ReturnedAssetABC):
|
||||||
def __init__(self, model: dict[str, str] | tuple[dict[str, str], dict[str, str]]):
|
def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None):
|
||||||
super().__init__("model")
|
super().__init__("model")
|
||||||
self.model = model
|
self.state_dict = state_dict
|
||||||
|
self.metadata = metadata
|
||||||
|
|
||||||
|
|
||||||
class AssetResolverABC(ABC):
|
class AssetResolverABC(ABC):
|
||||||
@@ -38,24 +39,30 @@ class AssetResolverABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class LocalAssetResolver(AssetResolverABC):
|
class LocalAssetResolver(AssetResolverABC):
|
||||||
def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC:
|
def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC:
|
||||||
# currently only supports models - make sure models is in the tags
|
# currently only supports models - make sure models is in the tags
|
||||||
if "models" not in asset_info.tags:
|
if "models" not in asset_info.tags:
|
||||||
return None
|
return None
|
||||||
# TODO: if hash exists, call model processor to try to get info about model:
|
# TODO: if hash exists, call model processor to try to get info about model:
|
||||||
if asset_info.hash:
|
if asset_info.hash:
|
||||||
...
|
try:
|
||||||
|
from app.model_processor import model_processor
|
||||||
|
model_db = model_processor.retrieve_model_by_hash(asset_info.hash)
|
||||||
|
full_path = model_db.path
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Could not get model by hash with error: {e}")
|
||||||
|
# the good ol' bread and butter - folder_paths's keys as tags
|
||||||
|
folder_keys = folder_paths.folder_names_and_paths.keys()
|
||||||
|
parent_paths = []
|
||||||
|
for tag in asset_info.tags:
|
||||||
|
if tag in folder_keys:
|
||||||
|
parent_paths.append(tag)
|
||||||
# if subdir metadata and name exists, use that as the model name going forward
|
# if subdir metadata and name exists, use that as the model name going forward
|
||||||
if "subdir" in asset_info.metadata and asset_info.name:
|
if "subdir" in asset_info.metadata and asset_info.name:
|
||||||
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
|
# if no matching parent paths, then something went wrong and should return None
|
||||||
# the good ol' bread and butter - folder_paths's keys as tags
|
|
||||||
folder_keys = folder_paths.folder_names_and_paths.keys()
|
|
||||||
parent_paths = []
|
|
||||||
for tag in asset_info.tags:
|
|
||||||
if tag in folder_keys:
|
|
||||||
parent_paths.append(tag)
|
|
||||||
if len(parent_paths) == 0:
|
if len(parent_paths) == 0:
|
||||||
return None
|
return None
|
||||||
|
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
|
||||||
# now we have the parent keys, we can try to get the local path
|
# now we have the parent keys, we can try to get the local path
|
||||||
chosen_parent = None
|
chosen_parent = None
|
||||||
full_path = None
|
full_path = None
|
||||||
@@ -64,27 +71,40 @@ class LocalAssetResolver(AssetResolverABC):
|
|||||||
if full_path:
|
if full_path:
|
||||||
chosen_parent = parent_path
|
chosen_parent = parent_path
|
||||||
break
|
break
|
||||||
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
|
if full_path is not None:
|
||||||
# we know the path, so load the model and return it
|
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
|
||||||
model = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=asset_info.metadata.get("return_metadata", False))
|
# we know the path, so load the model and return it
|
||||||
return ModelReturnedAsset(model)
|
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||||
# TODO: if name exists, try to find model by name in all subdirs of parent paths
|
# TODO: handle caching
|
||||||
|
return ModelReturnedAsset(state_dict, metadata)
|
||||||
|
# if just name exists, try to find model by name in all subdirs of parent paths
|
||||||
|
# TODO: this behavior should be configurable by user
|
||||||
if asset_info.name:
|
if asset_info.name:
|
||||||
...
|
for parent_path in parent_paths:
|
||||||
# TODO: if download_url metadata exists, download the model and load it
|
filelist = folder_paths.get_filename_list(parent_path)
|
||||||
|
for file in filelist:
|
||||||
|
if os.path.basename(file) == asset_info.name:
|
||||||
|
full_path = folder_paths.get_full_path(parent_path, file)
|
||||||
|
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||||
|
# TODO: handle caching
|
||||||
|
return ModelReturnedAsset(state_dict, metadata)
|
||||||
|
# TODO: if download_url metadata exists, download the model and load it; this should be configurable by user
|
||||||
if asset_info.metadata.get("download_url", None):
|
if asset_info.metadata.get("download_url", None):
|
||||||
...
|
...
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
resolvers: list[AssetResolverABC] = []
|
resolvers: list[AssetResolverABC] = []
|
||||||
|
resolvers.append(LocalAssetResolver())
|
||||||
|
|
||||||
|
|
||||||
def resolve(asset_info: AssetInfo) -> Any:
|
def resolve(asset_info: AssetInfo) -> Any:
|
||||||
global resolvers
|
global resolvers
|
||||||
for resolver in resolvers:
|
for resolver in resolvers:
|
||||||
try:
|
try:
|
||||||
return resolver.resolve(asset_info)
|
to_return = resolver.resolve(asset_info)
|
||||||
|
if to_return is not None:
|
||||||
|
return resolver.resolve(asset_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error resolving asset {asset_info.hash}: {e}")
|
logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}")
|
||||||
return None
|
return None
|
||||||
|
56
comfy_extras/nodes_assets_test.py
Normal file
56
comfy_extras/nodes_assets_test.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
import comfy.asset_management
|
||||||
|
import comfy.sd
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class AssetTestNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="AssetTestNode",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
io.Clip.Output(),
|
||||||
|
io.Vae.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, ckpt_name: str):
|
||||||
|
hash = None
|
||||||
|
# lets get the full path just so we can retrieve the hash from db, if exists
|
||||||
|
try:
|
||||||
|
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||||
|
if full_path is None:
|
||||||
|
raise Exception(f"Model {ckpt_name} not found")
|
||||||
|
from app.model_processor import model_processor
|
||||||
|
hash = model_processor.retrieve_hash(full_path)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Could not get model by hash with error: {e}")
|
||||||
|
subdir, name = os.path.split(ckpt_name)
|
||||||
|
asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir})
|
||||||
|
asset = comfy.asset_management.resolve(asset_info)
|
||||||
|
# /\ the stuff above should happen in execution code instead of inside the node
|
||||||
|
# \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?)
|
||||||
|
if asset is None:
|
||||||
|
raise Exception(f"Model {asset_info.name} not found")
|
||||||
|
assert isinstance(asset, comfy.asset_management.ModelReturnedAsset)
|
||||||
|
out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata)
|
||||||
|
return io.NodeOutput(out[0], out[1], out[2])
|
||||||
|
|
||||||
|
|
||||||
|
class AssetTestExtension(ComfyExtension):
|
||||||
|
@classmethod
|
||||||
|
async def get_node_list(cls):
|
||||||
|
return [AssetTestNode]
|
||||||
|
|
||||||
|
|
||||||
|
def comfy_entrypoint():
|
||||||
|
return AssetTestExtension()
|
Reference in New Issue
Block a user