diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index 957bae802..bbdd12300 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -220,15 +220,40 @@ class NodeState(ABC): def __init__(self, node_id: str): self.node_id = node_id + @abstractmethod + def get_value(self, key: str): + pass + + @abstractmethod + def set_value(self, key: str, value: Any): + pass + @abstractmethod def pop(self, key: str): pass + @abstractmethod + def __contains__(self, key: str): + pass + + class NodeStateLocal(NodeState): def __init__(self, node_id: str): super().__init__(node_id) self.local_state = {} + def get_value(self, key: str): + return self.local_state.get(key) + + def set_value(self, key: str, value: Any): + self.local_state[key] = value + + def pop(self, key: str): + return self.local_state.pop(key, None) + + def __contains__(self, key: str): + return key in self.local_state + def __getattr__(self, key: str): local_state = type(self).__getattribute__(self, "local_state") if key in local_state: @@ -248,15 +273,9 @@ class NodeStateLocal(NodeState): def __getitem__(self, key: str): return self.local_state[key] - def __contains__(self, key: str): - return key in self.local_state - def __delitem__(self, key: str): del self.local_state[key] - def pop(self, key: str): - return self.local_state.pop(key) - @comfytype(io_type=IO.BOOLEAN) class Boolean: Type = bool @@ -319,7 +338,6 @@ class Float: default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) - self.default = default self.min = min self.max = max self.step = step diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index d6f9fcc76..c65926ac2 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -80,11 +80,15 @@ class V3TestNode(io.ComfyNodeV3): cls.state["thing"] = "hahaha" yyy = cls.state["thing"] del cls.state["thing"] + if cls.state.get_value("int2") is None: + cls.state.set_value("int2", 123) + zzz = cls.state.get_value("int2") + cls.state.pop("int2") if cls.state.my_int is None: cls.state.my_int = expected_int else: if cls.state.my_int != expected_int: - raise Exception(f"Explicit state object did not maintain expected value: {cls.state.my_int} != {expected_int}") + raise Exception(f"Explicit state object did not maintain expected value (__getattr__/__setattr__): {cls.state.my_int} != {expected_int}") #some_int if hasattr(cls, "hahajkunless"): raise Exception("The 'cls' variable leaked instance state between runs!") @@ -158,7 +162,7 @@ class V3LoraLoader(io.ComfyNodeV3): return io.NodeOutput(model_lora, clip_lora) -NODES_LIST: list[io.ComfyNodeV3] = [ +NODES_LIST: list[type[io.ComfyNodeV3]] = [ V3TestNode, V3LoraLoader, ]