Skip to content

Commit

Permalink
Hammer/save state race condition (#928)
Browse files Browse the repository at this point in the history
* Using a deque to ensure we don't spam alerts when the adapter stalls

Signed-off-by: Aaron Chong <[email protected]>

* Adding a mutex for saving task state to prevent race condition for db saving operation

Signed-off-by: Aaron Chong <[email protected]>

* Adding fixme

Signed-off-by: Aaron Chong <[email protected]>

* Fix deprecated logger.warn

Signed-off-by: Aaron Chong <[email protected]>

* Reverting deque usage

Signed-off-by: Aaron Chong <[email protected]>

* Use asyncio.Lock, added more notes about the issue

Signed-off-by: Aaron Chong <[email protected]>

---------

Signed-off-by: Aaron Chong <[email protected]>
  • Loading branch information
aaronchongth authored May 6, 2024
1 parent 78574d5 commit 1f14af3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
64 changes: 40 additions & 24 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import sys
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, cast
Expand Down Expand Up @@ -36,6 +37,7 @@ def __init__(
):
self.user = user
self.logger = logger
self.save_task_state_mutex = asyncio.Lock()

def parse_pickup(self, task_request: TaskRequest) -> Optional[str]:
# patrol
Expand Down Expand Up @@ -163,32 +165,46 @@ async def query_task_requests(self, task_ids: List[str]) -> List[DbTaskRequest]:
raise HTTPException(422, str(e)) from e

async def save_task_state(self, task_state: TaskState) -> None:
task_state_dict = {
"data": task_state.json(),
"category": task_state.category.__root__ if task_state.category else None,
"assigned_to": task_state.assigned_to.name
if task_state.assigned_to
else None,
"unix_millis_start_time": task_state.unix_millis_start_time
and datetime.fromtimestamp(task_state.unix_millis_start_time / 1000),
"unix_millis_finish_time": task_state.unix_millis_finish_time
and datetime.fromtimestamp(task_state.unix_millis_finish_time / 1000),
"status": task_state.status if task_state.status else None,
"unix_millis_request_time": task_state.booking.unix_millis_request_time
and datetime.fromtimestamp(
task_state.booking.unix_millis_request_time / 1000
),
"requester": task_state.booking.requester
if task_state.booking.requester
else None,
}
# FIXME: If the task dispatcher is also provided websocket access to
# the API server, when a new task is dispatched via the API server,
# there may be a race condition where both the ROS 2 task response and
# task dispatcher websocket update may attempt to create a new task
# state model with the same task ID. This has unfortunately not been
# reproducible locally, only in the production environment, which uses
# Postgres instead of SQLite. This may be fixed upstream in the DB or ORM;
# this mutex can be removed once these libraries have been updated and
# verified to fix the issue.
async with self.save_task_state_mutex:
task_state_dict = {
"data": task_state.json(),
"category": task_state.category.__root__
if task_state.category
else None,
"assigned_to": task_state.assigned_to.name
if task_state.assigned_to
else None,
"unix_millis_start_time": task_state.unix_millis_start_time
and datetime.fromtimestamp(task_state.unix_millis_start_time / 1000),
"unix_millis_finish_time": task_state.unix_millis_finish_time
and datetime.fromtimestamp(task_state.unix_millis_finish_time / 1000),
"status": task_state.status if task_state.status else None,
"unix_millis_request_time": task_state.booking.unix_millis_request_time
and datetime.fromtimestamp(
task_state.booking.unix_millis_request_time / 1000
),
"requester": task_state.booking.requester
if task_state.booking.requester
else None,
}

if task_state.unix_millis_warn_time is not None:
task_state_dict["unix_millis_warn_time"] = datetime.fromtimestamp(
task_state.unix_millis_warn_time / 1000
)
if task_state.unix_millis_warn_time is not None:
task_state_dict["unix_millis_warn_time"] = datetime.fromtimestamp(
task_state.unix_millis_warn_time / 1000
)

await ttm.TaskState.update_or_create(task_state_dict, id_=task_state.booking.id)
await ttm.TaskState.update_or_create(
task_state_dict, id_=task_state.booking.id
)

async def query_task_states(
self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
Expand Down
8 changes: 4 additions & 4 deletions packages/api-server/api_server/routes/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ async def process_msg(
logger: LoggerAdapter,
) -> None:
if "type" not in msg:
logger.warn(msg)
logger.warn("Ignoring message, 'type' must include in msg field")
logger.warning(msg)
logger.warning("Ignoring message, 'type' must include in msg field")
return
payload_type: str = msg["type"]
if not isinstance(payload_type, str):
logger.warn("error processing message, 'type' must be a string")
logger.warning("error processing message, 'type' must be a string")
return
logger.debug(msg)

Expand Down Expand Up @@ -152,4 +152,4 @@ async def rmf_gateway(
await process_msg(msg, fleet_repo, task_repo, alert_repo, logger)
except (WebSocketDisconnect, ConnectionClosed):
connection_manager.disconnect(websocket)
logger.warn("Client websocket disconnected")
logger.warning("Client websocket disconnected")

0 comments on commit 1f14af3

Please sign in to comment.