diff --git a/packages/api-server/api_server/repositories/tasks.py b/packages/api-server/api_server/repositories/tasks.py
index a930352f2..c994ea0a4 100644
--- a/packages/api-server/api_server/repositories/tasks.py
+++ b/packages/api-server/api_server/repositories/tasks.py
@@ -1,3 +1,4 @@
+import asyncio
 import sys
 from datetime import datetime
 from typing import Dict, List, Optional, Sequence, Tuple, cast
@@ -36,6 +37,7 @@ def __init__(
     ):
         self.user = user
         self.logger = logger
+        self.save_task_state_mutex = asyncio.Lock()
 
     def parse_pickup(self, task_request: TaskRequest) -> Optional[str]:
         # patrol
@@ -163,32 +165,46 @@ async def query_task_requests(self, task_ids: List[str]) -> List[DbTaskRequest]:
             raise HTTPException(422, str(e)) from e
 
     async def save_task_state(self, task_state: TaskState) -> None:
-        task_state_dict = {
-            "data": task_state.json(),
-            "category": task_state.category.__root__ if task_state.category else None,
-            "assigned_to": task_state.assigned_to.name
-            if task_state.assigned_to
-            else None,
-            "unix_millis_start_time": task_state.unix_millis_start_time
-            and datetime.fromtimestamp(task_state.unix_millis_start_time / 1000),
-            "unix_millis_finish_time": task_state.unix_millis_finish_time
-            and datetime.fromtimestamp(task_state.unix_millis_finish_time / 1000),
-            "status": task_state.status if task_state.status else None,
-            "unix_millis_request_time": task_state.booking.unix_millis_request_time
-            and datetime.fromtimestamp(
-                task_state.booking.unix_millis_request_time / 1000
-            ),
-            "requester": task_state.booking.requester
-            if task_state.booking.requester
-            else None,
-        }
+        # FIXME: If the task dispatcher is also provided websocket access to
+        # the API server, when a new task is dispatched via the API server,
+        # there may be a race condition where both the ROS 2 task response and
+        # the task dispatcher websocket update may attempt to create a new
+        # task state model with the same task ID. This has unfortunately not
+        # been reproducible locally, only in the production environment, which
+        # uses Postgres instead of sqlite. This may be fixed upstream in the
+        # DB or ORM; this mutex can be removed once those libraries have been
+        # updated and verified to fix the issue.
+        async with self.save_task_state_mutex:
+            task_state_dict = {
+                "data": task_state.json(),
+                "category": task_state.category.__root__
+                if task_state.category
+                else None,
+                "assigned_to": task_state.assigned_to.name
+                if task_state.assigned_to
+                else None,
+                "unix_millis_start_time": task_state.unix_millis_start_time
+                and datetime.fromtimestamp(task_state.unix_millis_start_time / 1000),
+                "unix_millis_finish_time": task_state.unix_millis_finish_time
+                and datetime.fromtimestamp(task_state.unix_millis_finish_time / 1000),
+                "status": task_state.status if task_state.status else None,
+                "unix_millis_request_time": task_state.booking.unix_millis_request_time
+                and datetime.fromtimestamp(
+                    task_state.booking.unix_millis_request_time / 1000
+                ),
+                "requester": task_state.booking.requester
+                if task_state.booking.requester
+                else None,
+            }
 
-        if task_state.unix_millis_warn_time is not None:
-            task_state_dict["unix_millis_warn_time"] = datetime.fromtimestamp(
-                task_state.unix_millis_warn_time / 1000
-            )
+            if task_state.unix_millis_warn_time is not None:
+                task_state_dict["unix_millis_warn_time"] = datetime.fromtimestamp(
+                    task_state.unix_millis_warn_time / 1000
+                )
 
-        await ttm.TaskState.update_or_create(task_state_dict, id_=task_state.booking.id)
+            await ttm.TaskState.update_or_create(
+                task_state_dict, id_=task_state.booking.id
+            )
 
     async def query_task_states(
         self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
diff --git a/packages/api-server/api_server/routes/internal.py b/packages/api-server/api_server/routes/internal.py
index 57fb808a0..0e6b94226 100644
--- a/packages/api-server/api_server/routes/internal.py
+++ b/packages/api-server/api_server/routes/internal.py
@@ -83,12 +83,12 @@ async def process_msg(
     logger: LoggerAdapter,
 ) -> None:
     if "type" not in msg:
-        logger.warn(msg)
-        logger.warn("Ignoring message, 'type' must include in msg field")
+        logger.warning(msg)
+        logger.warning("Ignoring message, 'type' must be included in msg")
         return
     payload_type: str = msg["type"]
     if not isinstance(payload_type, str):
-        logger.warn("error processing message, 'type' must be a string")
+        logger.warning("error processing message, 'type' must be a string")
         return
     logger.debug(msg)
 
@@ -152,4 +152,4 @@ async def rmf_gateway(
             await process_msg(msg, fleet_repo, task_repo, alert_repo, logger)
     except (WebSocketDisconnect, ConnectionClosed):
         connection_manager.disconnect(websocket)
-        logger.warn("Client websocket disconnected")
+        logger.warning("Client websocket disconnected")
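Background on the workaround, for reviewers: an update_or_create of this kind is typically a check-then-write, so two coroutines saving a state with the same booking ID can both miss the existing row and both attempt an insert. The sketch below is a minimal, self-contained illustration of that interleaving and of the asyncio.Lock pattern used in this patch; it is not part of the change, and FakeTable, save, and save_mutex are hypothetical stand-ins for the ORM model and repository, not API server code.

import asyncio


class FakeTable:
    """Hypothetical in-memory stand-in for an ORM model with a unique ID."""

    rows: dict = {}

    @classmethod
    async def update_or_create(cls, data: dict, id_: str) -> None:
        exists = id_ in cls.rows  # 1. check
        await asyncio.sleep(0)  # yield point: another coroutine may run here
        if exists:
            cls.rows[id_].update(data)  # 2a. update existing row
        else:
            if id_ in cls.rows:
                # A second writer slipped in between the check and the insert.
                raise RuntimeError(f"duplicate insert for id {id_}")
            cls.rows[id_] = dict(data)  # 2b. create new row


save_mutex = asyncio.Lock()


async def save(id_: str, data: dict, use_mutex: bool) -> None:
    if use_mutex:
        # Serializes the whole check-then-write, mirroring the patch above.
        async with save_mutex:
            await FakeTable.update_or_create(data, id_=id_)
    else:
        await FakeTable.update_or_create(data, id_=id_)


async def main() -> None:
    # Without the mutex, two concurrent writers with the same ID both take
    # the create path and the second one hits the duplicate-ID error.
    try:
        await asyncio.gather(save("t1", {"a": 1}, False), save("t1", {"a": 2}, False))
    except RuntimeError as exc:
        print("race without mutex:", exc)

    FakeTable.rows.clear()
    # With the mutex, the writers run one after the other, so the second
    # call sees the row created by the first and updates it instead.
    await asyncio.gather(save("t2", {"a": 1}, True), save("t2", {"a": 2}, True))
    print("with mutex:", FakeTable.rows)


asyncio.run(main())

Note that an asyncio.Lock only serializes writers within a single event loop and process; as the FIXME states, the underlying fix belongs upstream in the database or ORM layer, at which point the mutex can be dropped.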