mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Because `client_id` was missing from the `progress_state` message, it was being sent to all connected sessions. In practice this meant that anyone whose graph contained the same nodes would see other users' progress updates. Also added a test to prevent a recurrence, and moved the tests around to make CI easier to hook up.
234 lines
9.2 KiB
Python
234 lines
9.2 KiB
Python
"""Test that progress updates are properly isolated between WebSocket clients."""
|
|
|
|
import json
|
|
import pytest
|
|
import time
|
|
import threading
|
|
import uuid
|
|
import websocket
|
|
from typing import List, Dict, Any
|
|
from comfy_execution.graph_utils import GraphBuilder
|
|
from tests.execution.test_execution import ComfyClient
|
|
|
|
|
|
class ProgressTracker:
    """Collect the ``progress_state`` messages seen by one WebSocket client.

    Messages are kept in arrival order; every access goes through
    ``self.lock`` so a listener thread can append while another thread reads.
    """

    def __init__(self, client_id: str):
        # Id of the client this tracker belongs to (kept for reference).
        self.client_id = client_id
        # Raw message dicts, in the order they were added.
        self.progress_messages: List[Dict[str, Any]] = []
        self.lock = threading.Lock()

    @staticmethod
    def _prompt_id_of(message: Dict[str, Any]) -> Any:
        """Return ``message['data']['prompt_id']``, or None when either key is absent."""
        return message.get('data', {}).get('prompt_id')

    def add_message(self, message: Dict[str, Any]):
        """Record *message* while holding the lock."""
        with self.lock:
            self.progress_messages.append(message)

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Return, in arrival order, every recorded message whose prompt id equals *prompt_id*."""
        with self.lock:
            matching = list(
                filter(lambda m: self._prompt_id_of(m) == prompt_id, self.progress_messages)
            )
        return matching

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Report whether any recorded message names a prompt other than *own_prompt_id*.

        Messages whose prompt id is missing or empty are not counted.
        """
        with self.lock:
            return any(
                pid and pid != own_prompt_id
                for pid in map(self._prompt_id_of, self.progress_messages)
            )
|
|
|
|
|
|
class IsolatedClient(ComfyClient):
    """ComfyClient that records every text WebSocket message it receives.

    ``progress_state`` messages are additionally routed into a
    ``ProgressTracker`` so tests can check which prompts this client was
    told about.
    """

    def __init__(self):
        super().__init__()
        # Created in connect(), once the client id is known.
        self.progress_tracker = None
        # Every decoded text message, in arrival order.
        self.all_messages: List[Dict[str, Any]] = []
        # Exception that ended the most recent listen_for_messages() call
        # early, or None if that call ran for its full duration.
        self.listen_error = None

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect to the server and attach a fresh progress tracker.

        Args:
            listen: Server host.
            port: Server port.
            client_id: WebSocket client id; a random UUID4 string is
                generated when omitted.
        """
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Receive WebSocket messages for roughly *duration* seconds.

        Text frames are JSON-decoded and appended to ``all_messages``;
        ``progress_state`` messages are also handed to the progress tracker.
        Non-text frames are ignored.

        Listening is best-effort: any error other than a receive timeout
        stops the loop and is stored in ``listen_error`` instead of being
        raised, so a listener thread never dies mid-test. The socket's
        previous timeout is restored before returning.
        """
        self.listen_error = None
        previous_timeout = self.ws.gettimeout()
        # Short receive timeout so the deadline is re-checked regularly.
        self.ws.settimeout(0.5)
        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + duration
        try:
            while time.monotonic() < deadline:
                try:
                    out = self.ws.recv()
                except websocket.WebSocketTimeoutException:
                    continue
                if not isinstance(out, str):
                    continue
                message = json.loads(out)
                self.all_messages.append(message)

                # Track progress_state messages
                if message.get('type') == 'progress_state':
                    self.progress_tracker.add_message(message)
        except Exception as exc:
            # Keep the failure for inspection rather than silently losing it.
            self.listen_error = exc
        finally:
            # Without this, the 0.5s timeout would leak into any later
            # blocking recv() on this client.
            try:
                self.ws.settimeout(previous_timeout)
            except OSError:
                pass  # underlying socket already closed; nothing to restore
|
|
|
|
|
|
@pytest.mark.execution
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    @pytest.fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Run a ComfyUI server subprocess for the lifetime of this test class.

        The server runs with the same interpreter as pytest, in CPU mode, on
        the host/port given by ``args_pytest``. On teardown it is killed and
        reaped.
        """
        import subprocess
        import sys

        pargs = [
            # sys.executable: use the interpreter (and venv) running pytest,
            # not whatever 'python' happens to be first on PATH.
            sys.executable, 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        try:
            yield
        finally:
            p.kill()
            # Reap the child so no zombie process outlives the test class.
            p.wait()

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Create an IsolatedClient and connect it, retrying while the server starts.

        Sleeps 4 seconds before each of up to 5 connection attempts.

        Raises:
            ConnectionRefusedError: if every attempt is refused; the last
                refusal is chained as the cause.
        """
        client = IsolatedClient()
        n_tries = 5
        last_error = None
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen, port, client_id)
                return client
            except ConnectionRefusedError as e:
                last_error = e
                print(e)  # noqa: T201
                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") from last_error

    @staticmethod
    def _close_clients(*clients):
        """Close the WebSocket of each connected client, skipping None entries."""
        for client in clients:
            ws = getattr(client, 'ws', None)
            if ws is not None:
                ws.close()

    def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Create two separate clients with unique IDs
        client_a_id = "client_a_" + str(uuid.uuid4())
        client_b_id = "client_b_" + str(uuid.uuid4())

        # Bound before the try so the finally clause never hits an unbound
        # name (masking the real error) when connecting fails.
        client_a = None
        client_b = None
        try:
            # Connect both clients with retries
            client_a = self.start_client_with_retry(listen, port, client_a_id)
            client_b = self.start_client_with_retry(listen, port, client_b_id)

            # Create simple workflows for both clients
            graph_a = GraphBuilder(prefix="client_a")
            image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
            graph_a.node("PreviewImage", images=image_a.out(0))

            graph_b = GraphBuilder(prefix="client_b")
            image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
            graph_b.node("PreviewImage", images=image_b.out(0))

            # Submit workflows from both clients
            prompt_a = graph_a.finalize()
            prompt_b = graph_b.finalize()

            response_a = client_a.queue_prompt(prompt_a)
            prompt_id_a = response_a['prompt_id']

            response_b = client_b.queue_prompt(prompt_b)
            prompt_id_b = response_b['prompt_id']

            # Listen on both clients concurrently for the same fixed window,
            # so messages leaked after one prompt finishes are still caught.
            listeners = [
                threading.Thread(target=client.listen_for_messages, kwargs={"duration": 10.0})
                for client in (client_a, client_b)
            ]
            for thread in listeners:
                thread.start()
            for thread in listeners:
                thread.join()

            # Verify isolation
            # Client A should only receive progress for prompt_id_a
            assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
                f"Client A received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_a}, but got messages for multiple prompts."

            # Client B should only receive progress for prompt_id_b
            assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
                f"Client B received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_b}, but got messages for multiple prompts."

            # Verify each client received their own progress updates
            client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
            client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)

            assert len(client_a_messages) > 0, \
                "Client A did not receive any progress updates for its own workflow"
            assert len(client_b_messages) > 0, \
                "Client B did not receive any progress updates for its own workflow"

            # Ensure no cross-contamination
            client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
            client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)

            assert len(client_a_other) == 0, \
                f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
            assert len(client_b_other) == 0, \
                f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"

        finally:
            # Clean up connections
            self._close_clients(client_a, client_b)

    def test_progress_with_missing_client_id(self, args_pytest):
        """Connect without passing a client_id and check that own progress still arrives.

        IsolatedClient.connect generates a random id when none is given.
        NOTE(review): ComfyClient.queue_prompt presumably still sends that
        generated id, so this does not exercise a request that truly lacks
        a client_id — confirm against ComfyClient.
        """
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Bound before the try so the finally clause is safe if connecting fails.
        client = None
        try:
            # Connect client with retries
            client = self.start_client_with_retry(listen, port)

            # Create a simple workflow
            graph = GraphBuilder(prefix="test_missing_id")
            image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
            graph.node("PreviewImage", images=image.out(0))

            # Submit workflow
            prompt = graph.finalize()
            response = client.queue_prompt(prompt)
            prompt_id = response['prompt_id']

            # Listen for messages
            client.listen_for_messages(duration=5.0)

            # Should still receive progress updates for own workflow
            messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
            assert len(messages) > 0, \
                "Client did not receive progress updates even though it initiated the workflow"

        finally:
            self._close_clients(client)
|
|
|