diff --git a/packages/api-server/api_server/gateway.py b/packages/api-server/api_server/gateway.py index 26dbf54c5..5751af1df 100644 --- a/packages/api-server/api_server/gateway.py +++ b/packages/api-server/api_server/gateway.py @@ -41,6 +41,7 @@ from .models import ( AlertParameter, AlertRequest, + AlertResponse, BeaconState, BuildingMap, DeliveryAlert, @@ -51,14 +52,7 @@ LiftState, ) from .models.delivery_alerts import action_from_msg, category_from_msg, tier_from_msg -from .repositories import ( - CachedFilesRepository, - LocationAlertFailResponse, - LocationAlertSuccessResponse, - cached_files_repo, - is_final_location_alert_check, - task_id_to_all_locations_success_cache, -) +from .repositories import CachedFilesRepository, cached_files_repo from .rmf_io import alert_events, rmf_events from .ros import ros_node @@ -297,20 +291,6 @@ def convert_fleet_alert(fleet_alert: RmfFleetAlert): def handle_fleet_alert(fleet_alert: AlertRequest): logging.info("Received fleet alert:") logging.info(fleet_alert) - - # Handle request for checking all location completion success for - # this task - is_final_check = is_final_location_alert_check(fleet_alert) - if is_final_check: - successful_so_far = task_id_to_all_locations_success_cache.lookup( - fleet_alert.task_id - ) - if successful_so_far is None or not successful_so_far: - self.respond_to_alert(fleet_alert.id, LocationAlertFailResponse) - else: - self.respond_to_alert(fleet_alert.id, LocationAlertSuccessResponse) - return - alert_events.alert_requests.on_next(fleet_alert) fleet_alert_sub = ros_node().create_subscription( @@ -326,6 +306,31 @@ def handle_fleet_alert(fleet_alert: AlertRequest): ) self._subscriptions.append(fleet_alert_sub) + def convert_fleet_alert_response(fleet_alert_response: RmfFleetAlertResponse): + return AlertResponse( + id=fleet_alert_response.id, + unix_millis_response_time=round(datetime.now().timestamp() * 1000), + response=fleet_alert_response.response, + ) + + def handle_fleet_alert_response(alert_response: AlertResponse): + logging.info("Received alert response:") + logging.info(alert_response) + alert_events.alert_responses.on_next(alert_response) + + fleet_alert_response_sub = ros_node().create_subscription( + RmfFleetAlertResponse, + "fleet_alert_response", + lambda msg: handle_fleet_alert_response(convert_fleet_alert_response(msg)), + rclpy.qos.QoSProfile( + history=rclpy.qos.HistoryPolicy.KEEP_LAST, + depth=10, + reliability=rclpy.qos.ReliabilityPolicy.RELIABLE, + durability=rclpy.qos.DurabilityPolicy.TRANSIENT_LOCAL, + ), + ) + self._subscriptions.append(fleet_alert_response_sub) + def handle_fire_alarm_trigger(fire_alarm_trigger_msg: BoolMsg): if fire_alarm_trigger_msg.data: logging.info("Fire alarm triggered") diff --git a/packages/api-server/api_server/models/alerts.py b/packages/api-server/api_server/models/alerts.py index c5ffd5b2d..b7999b7b7 100644 --- a/packages/api-server/api_server/models/alerts.py +++ b/packages/api-server/api_server/models/alerts.py @@ -20,6 +20,14 @@ class AlertResponse(BaseModel): def from_tortoise(tortoise: ttm.AlertResponse) -> "AlertResponse": return AlertResponse(**tortoise.data) + async def save(self) -> None: + await ttm.AlertResponse.update_or_create( + { + "data": self.json(), + }, + id=self.id, + ) + class AlertRequest(BaseModel): class Tier(str, Enum): diff --git a/packages/api-server/api_server/repositories/__init__.py b/packages/api-server/api_server/repositories/__init__.py index 00583b721..7b0857951 100644 --- a/packages/api-server/api_server/repositories/__init__.py +++ b/packages/api-server/api_server/repositories/__init__.py @@ -1,10 +1,4 @@ -from .alerts import ( - AlertRepository, - LocationAlertFailResponse, - LocationAlertSuccessResponse, - is_final_location_alert_check, - task_id_to_all_locations_success_cache, -) +from .alerts import AlertRepository from .cached_files import CachedFilesRepository, cached_files_repo from .fleets import FleetRepository from .rmf import RmfRepository diff --git a/packages/api-server/api_server/repositories/alerts.py b/packages/api-server/api_server/repositories/alerts.py index be775e050..535423a3b 100644 --- a/packages/api-server/api_server/repositories/alerts.py +++ b/packages/api-server/api_server/repositories/alerts.py @@ -1,5 +1,4 @@ import logging -from collections import deque from datetime import datetime from typing import List, Optional @@ -47,61 +46,6 @@ def get_location_from_location_alert(alert: AlertRequest) -> Optional[str]: return None -def is_final_location_alert_check(alert: AlertRequest) -> bool: - """ - Checks if the alert request requires a check on all location alerts of this - task. - Note: This is an experimental feature and may be subjected to - modifications often. - """ - if ( - alert.task_id is None - or len(alert.alert_parameters) < 1 - or LocationAlertSuccessResponse not in alert.responses_available - or LocationAlertFailResponse not in alert.responses_available - ): - return False - - # Check type - for param in alert.alert_parameters: - if param.name == LocationAlertTypeParameterName: - if param.value == LocationAlertFinalCheckTypeParameterValue: - return True - return False - return False - - -class LRUCache: - def __init__(self, capacity: int): - self._cache = deque(maxlen=capacity) - self._lookup = {} - - def add(self, key, value): - if key in self._lookup: - self._cache.remove(key) - elif len(self._cache) == self._cache.maxlen: - oldest_key = self._cache.popleft() - del self._lookup[oldest_key] - - self._cache.append(key) - self._lookup[key] = value - - def remove(self, key): - if key in self._lookup: - self._cache.remove(key) - del self._lookup[key] - - def lookup(self, key): - if key in self._lookup: - self._cache.remove(key) - self._cache.append(key) - return self._lookup[key] - return None - - -task_id_to_all_locations_success_cache: LRUCache = LRUCache(20) - - class AlertRepository: async def create_new_alert(self, alert: AlertRequest) -> Optional[AlertRequest]: exists = await ttm.AlertRequest.exists(id=alert.id) @@ -180,55 +124,3 @@ async def get_unresponded_alerts(self) -> List[AlertRequest]: alert_response=None, response_expected=True ) return [AlertRequest.from_tortoise(alert) for alert in unresponded_alerts] - - async def create_location_alert_response( - self, - task_id: str, - location: str, - success: bool, - ) -> Optional[AlertResponse]: - """ - Creates an alert response for a location alert of the task. - Note: This is an experimental feature and may be subjected to - modifications often. - """ - alerts = await self.get_alerts_of_task(task_id=task_id, unresponded=True) - if len(alerts) == 0: - logging.error( - f"There are no location alerts awaiting response for task {task_id}" - ) - return None - - for alert in alerts: - location_alert_location = get_location_from_location_alert(alert) - if location_alert_location is None: - continue - - if location_alert_location == location: - response = ( - LocationAlertSuccessResponse - if success - else LocationAlertFailResponse - ) - alert_response_model = await self.create_response(alert.id, response) - if alert_response_model is None: - logging.error( - f"Failed to create response {response} to alert with ID {alert.id}" - ) - return None - - # Cache if all locations of this task has been successful so far - cache = task_id_to_all_locations_success_cache.lookup(task_id) - if cache is None: - task_id_to_all_locations_success_cache.add(task_id, success) - else: - task_id_to_all_locations_success_cache.add( - task_id, cache and success - ) - - return alert_response_model - - logging.error( - f"Task {task_id} is not awaiting completion of location {location}" - ) - return None diff --git a/packages/api-server/api_server/rmf_io/book_keeper.py b/packages/api-server/api_server/rmf_io/book_keeper.py index 89c42ff9d..40765e582 100644 --- a/packages/api-server/api_server/rmf_io/book_keeper.py +++ b/packages/api-server/api_server/rmf_io/book_keeper.py @@ -8,6 +8,7 @@ from api_server.models import ( AlertRequest, + AlertResponse, BeaconState, BuildingMap, DispenserHealth, @@ -21,9 +22,7 @@ LiftState, ) from api_server.models.health import BaseBasicHealth -from api_server.repositories import ( - AlertRepository, # , is_final_location_alert_check, LocationAlertFailResponse, LocationAlertSuccessResponse -) +from api_server.repositories import AlertRepository from .events import AlertEvents, RmfEvents @@ -62,6 +61,7 @@ async def start(self): self._record_ingestor_state() self._record_ingestor_health() self._record_alert_request() + self._record_alert_response() async def stop(self): for sub in self._subscriptions: @@ -205,3 +205,14 @@ async def update(alert_request: AlertRequest): lambda x: self._create_task(update(x)) ) ) + + def _record_alert_response(self): + async def update(alert_response: AlertResponse): + await alert_response.save() + logging.debug(json.dumps(alert_response.dict())) + + self._subscriptions.append( + self.alert_events.alert_responses.subscribe( + lambda x: self._create_task(update(x)) + ) + ) diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index 46c65aac8..2a8030c91 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -14,12 +14,11 @@ start_time_between_query, ) from api_server.fast_io import FastIORouter, SubscriptionRequest -from api_server.gateway import rmf_gateway from api_server.logging import LoggerAdapter, get_logger from api_server.models.tortoise_models import TaskState as DbTaskState -from api_server.repositories import AlertRepository, RmfRepository, TaskRepository +from api_server.repositories import RmfRepository, TaskRepository from api_server.response import RawJSONResponse -from api_server.rmf_io import alert_events, task_events, tasks_service +from api_server.rmf_io import task_events, tasks_service from api_server.routes.building_map import get_building_map router = FastIORouter(tags=["Tasks"]) @@ -390,28 +389,3 @@ async def post_undo_skip_phase( request: mdl.UndoPhaseSkipRequest = Body(...), ): return RawJSONResponse(await tasks_service().call(request.json(exclude_none=True))) - - -@router.post("/location_complete") -async def location_complete( - task_id: str, - location: str, - success: bool, - alert_repo: AlertRepository = Depends(AlertRepository), - logger: LoggerAdapter = Depends(get_logger), -): - """ - Warning: This endpoint is still actively being worked on and could be - subjected to modifications. - """ - response_model = await alert_repo.create_location_alert_response( - task_id, location, success - ) - if response_model is None: - raise HTTPException( - 422, - f"Failed to create location completion alert response to task [{task_id}] at location [{location}]", - ) - alert_events.alert_responses.on_next(response_model) - rmf_gateway().respond_to_alert(response_model.id, response_model.response) - logger.info(response_model) diff --git a/packages/api-server/api_server/routes/test_alerts.py b/packages/api-server/api_server/routes/test_alerts.py index a505a2f80..ac02277cc 100644 --- a/packages/api-server/api_server/routes/test_alerts.py +++ b/packages/api-server/api_server/routes/test_alerts.py @@ -223,36 +223,3 @@ def test_get_unresponded_alert_ids(self): returned_alert_ids = [a["id"] for a in returned_alerts] self.assertTrue(first_id not in returned_alert_ids) self.assertTrue(second_id in returned_alert_ids) - - def test_task_location_complete(self): - alert_id = str(uuid4()) - task_id = "test_task_id" - location_name = "test_location" - alert = make_alert_request(alert_id=alert_id, responses=["success", "fail"]) - alert.alert_parameters = [ - mdl.AlertParameter(name="type", value="location_result"), - mdl.AlertParameter(name="location_name", value=location_name), - ] - alert.task_id = task_id - resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True)) - self.assertEqual(201, resp.status_code, resp.content) - self.assertEqual(alert, resp.json(), resp.content) - - # complete wrong task ID - params = { - "task_id": "wrong_task_id", - "location": location_name, - "success": True, - } - resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}") - self.assertEqual(422, resp.status_code, resp.content) - - # complete missing location - params = {"task_id": task_id, "location": "wrong_location", "success": True} - resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}") - self.assertEqual(422, resp.status_code, resp.content) - - # complete location - params = {"task_id": task_id, "location": location_name, "success": True} - resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}") - self.assertEqual(200, resp.status_code, resp.content)