From 4b08314257b3cdf1514c4d1035591813f7e0e29a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 1 Feb 2023 22:33:10 -0500 Subject: [PATCH] Add more features to the backend queue code. The queue can now be queried, entries can be deleted and prompts easily queued to the front of the queue. Just need to expose it in the UI next. --- main.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index bc0af3dd..e0035fdc 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ import sys import copy import json import threading -import queue +import heapq import traceback if '--dont-upcast-attention' in sys.argv: @@ -148,6 +148,7 @@ class PromptExecutor: to_execute += [(0, x)] while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] @@ -266,10 +267,63 @@ def validate_prompt(prompt): def prompt_worker(q): e = PromptExecutor() while True: - item = q.get() + item, item_id = q.get() e.execute(item[-2], item[-1]) - q.task_done() + q.task_done(item_id) +class PromptQueue: + def __init__(self): + self.mutex = threading.RLock() + self.not_empty = threading.Condition(self.mutex) + self.task_counter = 0 + self.queue = [] + self.currently_running = {} + + def put(self, item): + with self.mutex: + heapq.heappush(self.queue, item) + self.not_empty.notify() + + def get(self): + with self.not_empty: + while len(self.queue) == 0: + self.not_empty.wait() + item = heapq.heappop(self.queue) + i = self.task_counter + self.currently_running[i] = copy.deepcopy(item) + self.task_counter += 1 + return (item, i) + + def task_done(self, item_id): + with self.mutex: + self.currently_running.pop(item_id) + + def get_current_queue(self): + with self.mutex: + out = [] + for x in self.currently_running.values(): + out += [x] + return (out, copy.deepcopy(self.queue)) + + def get_tasks_remaining(self): + with self.mutex: + return len(self.queue) + len(self.currently_running) + + def wipe_queue(self): + with self.mutex: + self.queue = [] + + def delete_queue_item(self, function): + with self.mutex: + for x in range(len(self.queue)): + if function(self.queue[x]): + if len(self.queue) == 1: + self.wipe_queue() + else: + self.queue.pop(x) + heapq.heapify(self.queue) + return True + return False from http.server import BaseHTTPRequestHandler, HTTPServer @@ -285,9 +339,16 @@ class PromptServer(BaseHTTPRequestHandler): self._set_headers(ct='application/json') prompt_info = {} exec_info = {} - exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks + exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining() prompt_info['exec_info'] = exec_info self.wfile.write(json.dumps(prompt_info).encode('utf-8')) + elif self.path == "/queue": + self._set_headers(ct='application/json') + queue_info = {} + current_queue = self.server.prompt_queue.get_current_queue() + queue_info['queue_running'] = current_queue[0] + queue_info['queue_pending'] = current_queue[1] + self.wfile.write(json.dumps(queue_info).encode('utf-8')) elif self.path == "/object_info": self._set_headers(ct='application/json') out = {} @@ -325,12 +386,16 @@ class PromptServer(BaseHTTPRequestHandler): out_string = "" if self.path == "/prompt": print("got prompt") - self.data_string = self.rfile.read(int(self.headers['Content-Length'])) - json_data = json.loads(self.data_string) + data_string = self.rfile.read(int(self.headers['Content-Length'])) + json_data = json.loads(data_string) if "number" in json_data: number = float(json_data['number']) else: number = self.server.number + if "front" in json_data: + if json_data['front']: + number = -number + self.server.number += 1 if "prompt" in json_data: prompt = json_data["prompt"] @@ -344,6 +409,18 @@ class PromptServer(BaseHTTPRequestHandler): resp_code = 400 out_string = valid[1] print("invalid prompt:", valid[1]) + elif self.path == "/queue": + data_string = self.rfile.read(int(self.headers['Content-Length'])) + json_data = json.loads(data_string) + if "clear" in json_data: + if json_data["clear"]: + self.server.prompt_queue.wipe_queue() + if "delete" in json_data: + to_delete = json_data['delete'] + for id_to_delete in to_delete: + delete_func = lambda a: a[1] == int(id_to_delete) + self.server.prompt_queue.delete_queue_item(delete_func) + self._set_headers(code=resp_code) self.end_headers() self.wfile.write(out_string.encode('utf8')) @@ -366,7 +443,7 @@ def run(prompt_queue, address='', port=8188): if __name__ == "__main__": - q = queue.PriorityQueue() + q = PromptQueue() threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() run(q, address='127.0.0.1', port=8188)