mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-15 14:09:28 +00:00
Add tests for delete & update
This commit is contained in:
@@ -8,6 +8,7 @@ import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
from aiohttp import web
|
||||
from app.database.entities import (
|
||||
DeleteEntity,
|
||||
column,
|
||||
table,
|
||||
Column,
|
||||
@@ -15,6 +16,7 @@ from app.database.entities import (
|
||||
GetEntityById,
|
||||
CreateEntity,
|
||||
UpsertEntity,
|
||||
UpdateEntity,
|
||||
)
|
||||
from app.database.db import db
|
||||
|
||||
@@ -25,9 +27,9 @@ def create_table(entity):
|
||||
# reset db
|
||||
db.close()
|
||||
|
||||
cols: list[Column] = entity.__columns__
|
||||
cols: list[Column] = entity._columns
|
||||
# Create tables as temporary so when we close the db, the tables are dropped for next test
|
||||
sql = f"CREATE TEMPORARY TABLE {entity.__table_name__} ( "
|
||||
sql = f"CREATE TEMPORARY TABLE {entity._table_name} ( "
|
||||
for col_name, col in cols.items():
|
||||
type = None
|
||||
if col.type == int:
|
||||
@@ -40,7 +42,7 @@ def create_table(entity):
|
||||
sql += " NOT NULL"
|
||||
sql += ", "
|
||||
|
||||
sql += f"PRIMARY KEY ({', '.join(entity.__key_columns__)})"
|
||||
sql += f"PRIMARY KEY ({', '.join(entity._key_columns)})"
|
||||
sql += ")"
|
||||
db.execute(sql)
|
||||
|
||||
@@ -48,6 +50,7 @@ def create_table(entity):
|
||||
async def wrap_db(method: Callable, expected_sql: str, expected_args: list):
|
||||
with patch.object(db, "execute", wraps=db.execute) as mock:
|
||||
response = await method()
|
||||
assert mock.call_count == 1
|
||||
assert mock.call_args[0][0] == expected_sql
|
||||
assert mock.call_args[0][1:] == expected_args
|
||||
return response
|
||||
@@ -109,6 +112,35 @@ def upsertable_entity():
|
||||
return UpsertableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def updateable_entity():
|
||||
@table("updateable_entity")
|
||||
class UpdateableEntity(UpdateEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
reqd: str = column(str, required=True)
|
||||
|
||||
return UpdateableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deletable_entity():
|
||||
@table("deletable_entity")
|
||||
class DeletableEntity(DeleteEntity):
|
||||
id: int = column(int, required=True, key=True)
|
||||
|
||||
return DeletableEntity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deletable_composite_entity():
|
||||
@table("deletable_composite_entity")
|
||||
class DeletableCompositeEntity(DeleteEntity):
|
||||
id1: str = column(str, required=True, key=True)
|
||||
id2: int = column(int, required=True, key=True)
|
||||
|
||||
return DeletableCompositeEntity
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def entity(request):
|
||||
value = request.getfixturevalue(request.param)
|
||||
@@ -399,3 +431,83 @@ async def test_upsert_model(client):
|
||||
"reqd": "reqd1",
|
||||
"nullable": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model(client):
|
||||
# seed db
|
||||
db.execute("INSERT INTO updateable_entity (id, reqd) VALUES (1, 'test1')")
|
||||
|
||||
expected_sql = "UPDATE updateable_entity SET reqd = ? WHERE id = ? RETURNING *"
|
||||
expected_args = ("updated_test", 1)
|
||||
response = await wrap_db(
|
||||
lambda: client.patch("/db/updateable_entity/1", json={"reqd": "updated_test"}),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert await response.json() == {
|
||||
"id": 1,
|
||||
"reqd": "updated_test",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_null_required_field(client):
|
||||
response = await client.patch("/db/updateable_entity/1", json={"reqd": None})
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Required field",
|
||||
"field": "reqd",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_invalid_field(client):
|
||||
response = await client.patch("/db/updateable_entity/1", json={"hello": "world"})
|
||||
|
||||
assert response.status == 400
|
||||
assert await response.json() == {
|
||||
"message": "Unknown field",
|
||||
"field": "hello",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["updateable_entity"], indirect=True)
|
||||
async def test_update_model_reject_missing_record(client):
|
||||
response = await client.patch(
|
||||
"/db/updateable_entity/1", json={"reqd": "updated_test"}
|
||||
)
|
||||
|
||||
assert response.status == 404
|
||||
assert await response.json() == {
|
||||
"message": "Failed to update entity",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["deletable_entity"], indirect=True)
|
||||
async def test_delete_model(client):
|
||||
expected_sql = "DELETE FROM deletable_entity WHERE id = ?"
|
||||
expected_args = (1,)
|
||||
response = await wrap_db(
|
||||
lambda: client.delete("/db/deletable_entity/1"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 204
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity", ["deletable_composite_entity"], indirect=True)
|
||||
async def test_delete_model_composite_key(client):
|
||||
expected_sql = "DELETE FROM deletable_composite_entity WHERE id1 = ? AND id2 = ?"
|
||||
expected_args = ("one", 2)
|
||||
response = await wrap_db(
|
||||
lambda: client.delete("/db/deletable_composite_entity/one/2"),
|
||||
expected_sql,
|
||||
expected_args,
|
||||
)
|
||||
|
||||
assert response.status == 204
|
||||
|
Reference in New Issue
Block a user