diff --git a/tests/manager/models/test_utils.py b/tests/manager/models/test_utils.py index c50ba1ee48..61116ff8d0 100644 --- a/tests/manager/models/test_utils.py +++ b/tests/manager/models/test_utils.py @@ -1,8 +1,17 @@ +from datetime import datetime + import pytest import sqlalchemy as sa +from dateutil.tz import tzutc +from sqlalchemy.dialects import postgresql as pgsql -from ai.backend.manager.models import KernelRow, SessionRow, kernels -from ai.backend.manager.models.utils import agg_to_array, agg_to_str +from ai.backend.manager.models import KernelRow, SessionRow, kernels, metadata +from ai.backend.manager.models.utils import ( + ExtendedAsyncSAEngine, + agg_to_array, + agg_to_str, + sql_json_merge, +) @pytest.mark.asyncio @@ -133,3 +142,240 @@ async def test_agg_to_array(session_info) -> None: (kernels.c.tag == test_data2) & (kernels.c.session_id == session_id) ) ) + + +@pytest.fixture +async def dummy_kernels(database_engine: ExtendedAsyncSAEngine): + # dummy_kernels, designed solely for testing sql_json_merge, only includes the status_history column, unlike legacy kernels table. + dummy_kernels = sa.Table( + "dummy_kernels", + metadata, + sa.Column("id", sa.Integer(), primary_key=True, default=1), + sa.Column( + "status_history", pgsql.JSONB(), nullable=True, default=sa.null() + ), # JSONB column for testing + extend_existing=True, + ) + + async with database_engine.begin() as db_sess: + await db_sess.run_sync(metadata.create_all) + await db_sess.execute(dummy_kernels.insert()) # insert fixture data for testing + await db_sess.commit() + + yield dummy_kernels + + async with database_engine.begin() as db_sess: + await db_sess.run_sync(dummy_kernels.drop) + await db_sess.commit() + + +@pytest.mark.asyncio +async def test_sql_json_merge__deeper_object( + dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine +): + async with database_engine.begin() as db_sess: + timestamp = datetime.now(tzutc()).isoformat() + expected = { + "kernel": { + "session": { + "PENDING": timestamp, + "PREPARING": timestamp, + }, + }, + } + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + ("kernel", "session"), + { + "PENDING": timestamp, + "PREPARING": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar() + assert result == expected + + +@pytest.mark.asyncio +async def test_sql_json_merge__append_values( + dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine +): + async with database_engine.begin() as db_sess: + timestamp = datetime.now(tzutc()).isoformat() + expected = { + "kernel": { + "session": { + "PENDING": timestamp, + "PREPARING": timestamp, + "TERMINATED": timestamp, + "TERMINATING": timestamp, + }, + }, + } + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + ("kernel", "session"), + { + "PENDING": timestamp, + "PREPARING": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + ("kernel", "session"), + { + "TERMINATING": timestamp, + "TERMINATED": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + + result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar() + assert result == expected + + +@pytest.mark.asyncio +async def test_sql_json_merge__kernel_status_history( + dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine +): + async with database_engine.begin() as db_sess: + timestamp = datetime.now(tzutc()).isoformat() + expected = { + "PENDING": timestamp, + "PREPARING": timestamp, + "TERMINATING": timestamp, + "TERMINATED": timestamp, + } + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + (), + { + "PENDING": timestamp, + "PREPARING": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + (), + { + "TERMINATING": timestamp, + "TERMINATED": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + + result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar() + assert result == expected + + +@pytest.mark.asyncio +async def test_sql_json_merge__mixed_formats( + dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine +): + async with database_engine.begin() as db_sess: + timestamp = datetime.now(tzutc()).isoformat() + expected = { + "PENDING": timestamp, + "kernel": { + "PREPARING": timestamp, + }, + } + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + (), + { + "PENDING": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + ("kernel",), + { + "PREPARING": timestamp, + }, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + + result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar() + assert result == expected + + +@pytest.mark.asyncio +async def test_sql_json_merge__json_serializable_types( + dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine +): + async with database_engine.begin() as db_sess: + expected = { + "boolean": True, + "integer": 10101010, + "float": 1010.1010, + "string": "10101010", + # "bytes": b"10101010", + "list": [ + 10101010, + "10101010", + ], + "dict": { + "10101010": 10101010, + }, + } + query = ( + dummy_kernels.update() + .values({ + "status_history": sql_json_merge( + dummy_kernels.c.status_history, + (), + expected, + ), + }) + .where(dummy_kernels.c.id == 1) + ) + await db_sess.execute(query) + result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar() + assert result == expected