Skip to content

Commit

Permalink
Use ORM object to query and update session row
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Sep 27, 2024
1 parent f7e2172 commit 59be0d4
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from graphene.types.datetime import DateTime as GQLDateTime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from ai.backend.common.types import ClusterMode, SessionId, SessionResult
Expand Down Expand Up @@ -52,6 +53,7 @@
get_permission_ctx,
)
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .kernel import KernelConnection, KernelNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -553,18 +555,21 @@ async def mutate_and_get_payload(
f"The priority value {input["priority"]!r} is out of range: "
f"[{SESSION_PRIORITY_MIN}, {SESSION_PRIORITY_MAX}]."
)
async with graph_ctx.db.begin_session() as db_sess:
data: dict[str, Any] = {}
set_if_set(input, data, "name")
set_if_set(input, data, "priority")
query = (
sa.update(SessionRow)
.where(SessionRow.id == session_id)
.values(data)
.returning(SessionRow)
)
result = await db_sess.execute(query)
session_row = result.fetchone()
data: dict[str, Any] = {}
set_if_set(input, data, "name")
set_if_set(input, data, "priority")

async def _update(db_sess: AsyncSession) -> SessionRow:
query = sa.select(SessionRow).where(SessionRow.id == session_id)
session_row = cast(SessionRow, await db_sess.scalar(query))
if "name" in input:
session_row.name = input["name"]
if "priority" in input:
session_row.priority = input["priority"]
return session_row

async with graph_ctx.db.connect() as db_conn:
session_row = await execute_with_txn_retry(_update, graph_ctx.db.begin_session, db_conn)
return cls(
ComputeSessionNode.from_row(graph_ctx, session_row),
input.get("client_mutation_id"),
Expand Down

0 comments on commit 59be0d4

Please sign in to comment.