diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index d6b72e33..c49ea434 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -501,6 +501,7 @@ class classproperty(object): class ComfyNodeV3(ABC): """Common base class for all V3 nodes.""" + RELATIVE_PYTHON_MODULE = None ############################################# # V1 Backwards Compatibility code #-------------------------------------------- diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 799ab4ae..7021139f 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -46,15 +46,15 @@ class V3TestNode(ComfyNodeV3): -NODES: list[ComfyNodeV3] = [ +NODES_LIST: list[ComfyNodeV3] = [ V3TestNode, ] -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} -for node in NODES: - schema = node.GET_SCHEMA() - NODE_CLASS_MAPPINGS[schema.node_id] = node - if schema.display_name: - NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name +# NODE_CLASS_MAPPINGS = {} +# NODE_DISPLAY_NAME_MAPPINGS = {} +# for node in NODES_LIST: +# schema = node.GET_SCHEMA() +# NODE_CLASS_MAPPINGS[schema.node_id] = node +# if schema.display_name: +# NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name diff --git a/nodes.py b/nodes.py index f932b150..05c5cffc 100644 --- a/nodes.py +++ b/nodes.py @@ -26,6 +26,7 @@ import comfy.sd import comfy.utils import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator +from comfy_api.v3.io import ComfyNodeV3 import comfy.clip_vision @@ -2128,7 +2129,17 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes if os.path.isdir(web_dir): EXTENSION_WEB_DIRS[module_name] = web_dir - if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: + if getattr(module, "NODES_LIST", None) is not None: + for node_cls in module.NODES: + node_cls: ComfyNodeV3 + schema = node_cls.GET_SCHEMA() + if schema.node_id not in ignore: + NODE_CLASS_MAPPINGS[schema.node_id] = node_cls + node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) + if schema.display_name is not None: + NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name + return True + elif hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): if name not in ignore: NODE_CLASS_MAPPINGS[name] = node_cls @@ -2137,7 +2148,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True else: - logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") return False except Exception as e: logging.warning(traceback.format_exc())