diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index 904c87b20..68cecd0d9 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -4,8 +4,9 @@ from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict from comfy.comfy_types.node_typing import IO - +# used for type hinting import torch +from comfy.model_patcher import ModelPatcher class FolderType(str, Enum): @@ -213,10 +214,13 @@ class ComfyTypeIO(ComfyType): ... -class NodeState: +class NodeState(ABC): def __init__(self, node_id: str): self.node_id = node_id + @abstractmethod + def pop(self, key: str): + pass class NodeStateLocal(NodeState): def __init__(self, node_id: str): @@ -227,14 +231,29 @@ class NodeStateLocal(NodeState): local_state = type(self).__getattribute__(self, "local_state") if key in local_state: return local_state[key] - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") + return None + # raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") def __setattr__(self, key: str, value: Any): if key in ['node_id', 'local_state']: super().__setattr__(key, value) else: self.local_state[key] = value + + def __setitem__(self, key: str, value: Any): + self.local_state[key] = value + + def __getitem__(self, key: str): + return self.local_state[key] + + def __contains__(self, key: str): + return key in self.local_state + + def __delitem__(self, key: str): + del self.local_state[key] + def pop(self, key: str): + return self.local_state.pop(key) @comfytype(io_type=IO.BOOLEAN) class Boolean: @@ -441,7 +460,7 @@ class Vae(ComfyTypeIO): @comfytype(io_type=IO.MODEL) class Model(ComfyTypeIO): - Type = Any + Type = ModelPatcher @comfytype(io_type=IO.CLIP_VISION) class ClipVision(ComfyTypeIO): @@ -677,63 +696,6 @@ class SchemaV3: is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" -# class SchemaV3Class: -# def __init__(self, -# node_id: str, -# node_name: str, -# category: str, -# inputs: list[InputV3], -# outputs: list[OutputV3]=None, -# hidden: list[Hidden]=None, -# description: str="", -# is_input_list: bool = False, -# is_output_node: bool=False, -# is_deprecated: bool=False, -# is_experimental: bool=False, -# is_api_node: bool=False, -# ): -# self.node_id = node_id -# """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes.""" -# self.node_name = node_name -# """Display name of node.""" -# self.category = category -# """The category of the node, as per the "Add Node" menu.""" -# self.inputs = inputs -# self.outputs = outputs -# self.hidden = hidden -# self.description = description -# """Node description, shown as a tooltip when hovering over the node.""" -# self.is_input_list = is_input_list -# """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. - -# All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. - -# From the docs: - -# A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``. - -# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing -# """ -# self.is_output_node = is_output_node -# """Flags this node as an output node, causing any inputs it requires to be executed. - -# If a node is not connected to any output nodes, that node will not be executed. Usage:: - -# OUTPUT_NODE = True - -# From the docs: - -# By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is. - -# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node -# """ -# self.is_deprecated = is_deprecated -# """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" -# self.is_experimental = is_experimental -# """Flags a node as experimental, informing users that it may change or not work as expected.""" -# self.is_api_node = is_api_node -# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" - class Serializer: def __init_subclass__(cls, io_type: IO | str, **kwargs): diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index fe0eeb8f0..47ab473a9 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -1,7 +1,10 @@ import torch from comfy_api.v3 import io import logging - +import folder_paths +import comfy.utils +import comfy.sd +from typing import Any @io.comfytype(io_type="XYZ") class XYZ: @@ -29,8 +32,8 @@ class V3TestNode(io.ComfyNodeV3): return io.SchemaV3( node_id="V3_01_TestNode1", display_name="V3 Test Node", - description="This is a funky V3 node test.", category="v3 nodes", + description="This is a funky V3 node test.", inputs=[ io.Image.Input("image", display_name="new_image"), XYZ.Input("xyz", optional=True), @@ -75,7 +78,10 @@ class V3TestNode(io.ComfyNodeV3): zzz = cls.hidden.prompt cls.state.my_str = "LOLJK" expected_int = 123 - cls.state.my_int = expected_int + if "thing" not in cls.state: + cls.state["thing"] = "hahaha" + yyy = cls.state["thing"] + del cls.state["thing"] if cls.state.my_int is None: cls.state.my_int = expected_int else: @@ -90,6 +96,71 @@ class V3TestNode(io.ComfyNodeV3): return io.NodeOutput(some_int, image) +class V3LoraLoader(io.ComfyNodeV3): + class State(io.NodeState): + loaded_lora: tuple[str, Any] | None = None + state: State + + @classmethod + def DEFINE_SCHEMA(cls): + return io.SchemaV3( + node_id="V3_LoraLoader", + display_name="V3 LoRA Loader", + category="v3 nodes", + description="LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together.", + inputs=[ + io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."), + io.Clip.Input("clip", tooltip="The CLIP model the LoRA will be applied to."), + io.Combo.Input( + "lora_name", + options=folder_paths.get_filename_list("loras"), + tooltip="The name of the LoRA." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + step=0.01, + tooltip="How strongly to modify the diffusion model. This value can be negative." + ), + io.Float.Input( + "strength_clip", + default=1.0, + min=-100.0, + max=100.0, + step=0.01, + tooltip="How strongly to modify the CLIP model. This value can be negative." + ), + ], + outputs=[ + io.Model.Output("model_out"), + io.Clip.Output("clip_out"), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, clip: io.Clip.Type, lora_name: str, strength_model: float, strength_clip: float, **kwargs): + if strength_model == 0 and strength_clip == 0: + return io.NodeOutput(model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if cls.state.loaded_lora is not None: + if cls.state.loaded_lora[0] == lora_path: + lora = cls.state.loaded_lora[1] + else: + cls.state.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + cls.state.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + return io.NodeOutput(model_lora, clip_lora) + + NODES_LIST: list[io.ComfyNodeV3] = [ V3TestNode, + V3LoraLoader, ]