mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
refactor, adding tests
This commit is contained in:
401
tests-unit/app_test/entities_test.py
Normal file
401
tests-unit/app_test/entities_test.py
Normal file
@@ -0,0 +1,401 @@
|
||||
from comfy.cli_args import args
|
||||
|
||||
args.memory_database = True # force in-memory database for testing
|
||||
|
||||
from typing import Callable, Optional
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
from aiohttp import web
|
||||
from app.database.entities import (
|
||||
column,
|
||||
table,
|
||||
Column,
|
||||
GetEntity,
|
||||
GetEntityById,
|
||||
CreateEntity,
|
||||
UpsertEntity,
|
||||
)
|
||||
from app.database.db import db
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def create_table(entity):
|
||||
# reset db
|
||||
db.close()
|
||||
|
||||
cols: list[Column] = entity.__columns__
|
||||
# Create tables as temporary so when we close the db, the tables are dropped for next test
|
||||
sql = f"CREATE TEMPORARY TABLE {entity.__table_name__} ( "
|
||||
for col_name, col in cols.items():
|
||||
type = None
|
||||
if col.type == int:
|
||||
type = "INTEGER"
|
||||
elif col.type == str:
|
||||
type = "TEXT"
|
||||
|
||||
sql += f"{col_name} {type}"
|
||||
if col.required:
|
||||
sql += " NOT NULL"
|
||||
sql += ", "
|
||||
|
||||
sql += f"PRIMARY KEY ({', '.join(entity.__key_columns__)})"
|
||||
sql += ")"
|
||||
db.execute(sql)
|
||||
|
||||
|
||||
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
|
||||
with patch.object(db, "execute", wraps=db.execute) as mock:
|
||||
response = await method()
|
||||
assert mock.call_args[0][0] == expected_sql
|
||||
assert mock.call_args[0][1:] == expected_args
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_entity():
|
||||
@table("getable_entity")
|
||||
class GetableEntity(GetEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return GetableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_by_id_entity():
|
||||
@table("getable_by_id_entity")
|
||||
class GetableByIdEntity(GetEntityById):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
|
||||
return GetableByIdEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def getable_by_id_composite_entity():
|
||||
@table("getable_by_id_composite_entity")
|
||||
class GetableByIdCompositeEntity(GetEntityById):
|
||||
id1: str = column(str, required=True, key=True)
|
||||
id2: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
|
||||
return GetableByIdCompositeEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def creatable_entity():
|
||||
@table("creatable_entity")
|
||||
class CreatableEntity(CreateEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
reqd: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return CreatableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upsertable_entity():
|
||||
@table("upsertable_entity")
|
||||
class UpsertableEntity(UpsertEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
test: str = column(str, required=True)
|
||||
reqd: str = column(str, required=True)
|
||||
nullable: Optional[str] = column(str)
|
||||
|
||||
return UpsertableEntity
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def entity(request):
|
||||
value = request.getfixturevalue(request.param)
|
||||
create_table(value)
|
||||
return value
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(aiohttp_client, app):
|
||||
return await aiohttp_client(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(entity):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
entity.register_route(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_empty_response(client):
|
||||
expected_sql = "SELECT * FROM getable_entity"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_data(client):
|
||||
# seed db
|
||||
db.execute(
|
||||
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2')"
|
||||
)
|
||||
|
||||
expected_sql = "SELECT * FROM getable_entity"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity"), expected_sql, expected_args
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1", "nullable": None},
|
||||
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_top_parameter(client):
|
||||
# seed with 3 rows
|
||||
db.execute(
|
||||
"INSERT INTO getable_entity (id, test, nullable) VALUES (1, 'test1', NULL), (2, 'test2', 'test2'), (3, 'test3', 'test3')"
|
||||
)
|
||||
|
||||
expected_sql = "SELECT * FROM getable_entity LIMIT 2"
|
||||
expected_args = ()
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_entity?top=2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1", "nullable": None},
|
||||
{"id": 2, "test": "test2", "nullable": "test2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_entity"], indirect=True)
|
||||
async def test_get_model_with_invalid_top_parameter(client):
|
||||
response = await client.get("/db/getable_entity?top=hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid top parameter",
|
||||
"field": "top",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||
async def test_get_model_by_id_empty_response(client):
|
||||
# seed db
|
||||
db.execute("INSERT INTO getable_by_id_entity (id, test) VALUES (1, 'test1')")
|
||||
|
||||
expected_sql = "SELECT * FROM getable_by_id_entity WHERE id = ?"
|
||||
expected_args = (1,)
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_by_id_entity/1"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id": 1, "test": "test1"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_entity"], indirect=True)
|
||||
async def test_get_model_by_id_with_invalid_id(client):
|
||||
response = await client.get("/db/getable_by_id_entity/hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||
async def test_get_model_by_id_composite(client):
|
||||
# seed db
|
||||
db.execute(
|
||||
"INSERT INTO getable_by_id_composite_entity (id1, id2, test) VALUES ('one', 2, 'test')"
|
||||
)
|
||||
|
||||
expected_sql = (
|
||||
"SELECT * FROM getable_by_id_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||
)
|
||||
expected_args = ("one", 2)
|
||||
response = await wrap_db(
|
||||
lambda: client.get("/db/getable_by_id_composite_entity/one/2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == [
|
||||
{"id1": "one", "id2": 2, "test": "test"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["getable_by_id_composite_entity"], indirect=True)
|
||||
async def test_get_model_by_id_composite_with_invalid_id(client):
|
||||
response = await client.get("/db/getable_by_id_composite_entity/hello/hello")
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id2",
|
||||
"value": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model(client):
|
||||
expected_sql = (
|
||||
"INSERT INTO creatable_entity (id, test, reqd) VALUES (?, ?, ?) RETURNING *"
|
||||
)
|
||||
expected_args = (1, "test1", "reqd1")
|
||||
response = await wrap_db(
|
||||
lambda: client.post(
|
||||
"/db/creatable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||
),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
"nullable": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_missing_required_field(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity", json={"id": 1, "test": "test1"}
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Missing field",
|
||||
"field": "reqd",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_missing_key_field(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"test": "test1", "reqd": "reqd1"}, # Missing 'id' which is a key
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Missing field",
|
||||
"field": "id",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_key_data(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={
|
||||
"id": "not_an_integer",
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
}, # id should be int
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "not_an_integer",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_data(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"id": "aaa", "test": "123", "reqd": "reqd1"}, # id should be int
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "id",
|
||||
"value": "aaa",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_type(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={
|
||||
"id": 1,
|
||||
"test": ["invalid_array"],
|
||||
"reqd": "reqd1",
|
||||
}, # test should be string
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Invalid value",
|
||||
"field": "test",
|
||||
"value": ["invalid_array"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["creatable_entity"], indirect=True)
|
||||
async def test_create_model_invalid_field_name(client):
|
||||
response = await client.post(
|
||||
"/db/creatable_entity",
|
||||
json={"id": 1, "test": "test1", "reqd": "reqd1", "nonexistent_field": "value"},
|
||||
)
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Unknown field",
|
||||
"field": "nonexistent_field",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["upsertable_entity"], indirect=True)
|
||||
async def test_upsert_model(client):
|
||||
expected_sql = (
|
||||
"INSERT INTO upsertable_entity (id, test, reqd) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT (id) DO UPDATE SET test = excluded.test, reqd = excluded.reqd "
|
||||
"RETURNING *"
|
||||
)
|
||||
expected_args = (1, "test1", "reqd1")
|
||||
response = await wrap_db(
|
||||
lambda: client.put(
|
||||
"/db/upsertable_entity", json={"id": 1, "test": "test1", "reqd": "reqd1"}
|
||||
),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"test": "test1",
|
||||
"reqd": "reqd1",
|
||||
"nullable": None,
|
||||
}
|
Reference in New Issue
Block a user