Skip to content

Commit

Permalink
First round of cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Chong <[email protected]>
  • Loading branch information
aaronchongth committed May 31, 2024
1 parent 78f5e0b commit eb83039
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 201 deletions.
49 changes: 27 additions & 22 deletions packages/api-server/api_server/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .models import (
AlertParameter,
AlertRequest,
AlertResponse,
BeaconState,
BuildingMap,
DeliveryAlert,
Expand All @@ -51,14 +52,7 @@
LiftState,
)
from .models.delivery_alerts import action_from_msg, category_from_msg, tier_from_msg
from .repositories import (
CachedFilesRepository,
LocationAlertFailResponse,
LocationAlertSuccessResponse,
cached_files_repo,
is_final_location_alert_check,
task_id_to_all_locations_success_cache,
)
from .repositories import CachedFilesRepository, cached_files_repo
from .rmf_io import alert_events, rmf_events
from .ros import ros_node

Expand Down Expand Up @@ -297,20 +291,6 @@ def convert_fleet_alert(fleet_alert: RmfFleetAlert):
def handle_fleet_alert(fleet_alert: AlertRequest):
logging.info("Received fleet alert:")
logging.info(fleet_alert)

# Handle request for checking all location completion success for
# this task
is_final_check = is_final_location_alert_check(fleet_alert)
if is_final_check:
successful_so_far = task_id_to_all_locations_success_cache.lookup(
fleet_alert.task_id
)
if successful_so_far is None or not successful_so_far:
self.respond_to_alert(fleet_alert.id, LocationAlertFailResponse)
else:
self.respond_to_alert(fleet_alert.id, LocationAlertSuccessResponse)
return

alert_events.alert_requests.on_next(fleet_alert)

fleet_alert_sub = ros_node().create_subscription(
Expand All @@ -326,6 +306,31 @@ def handle_fleet_alert(fleet_alert: AlertRequest):
)
self._subscriptions.append(fleet_alert_sub)

def convert_fleet_alert_response(fleet_alert_response: RmfFleetAlertResponse):
return AlertResponse(
id=fleet_alert_response.id,
unix_millis_response_time=round(datetime.now().timestamp() * 1000),
response=fleet_alert_response.response,
)

def handle_fleet_alert_response(alert_response: AlertResponse):
logging.info("Received alert response:")
logging.info(alert_response)
alert_events.alert_responses.on_next(alert_response)

fleet_alert_response_sub = ros_node().create_subscription(
RmfFleetAlertResponse,
"fleet_alert_response",
lambda msg: handle_fleet_alert_response(convert_fleet_alert_response(msg)),
rclpy.qos.QoSProfile(
history=rclpy.qos.HistoryPolicy.KEEP_LAST,
depth=10,
reliability=rclpy.qos.ReliabilityPolicy.RELIABLE,
durability=rclpy.qos.DurabilityPolicy.TRANSIENT_LOCAL,
),
)
self._subscriptions.append(fleet_alert_response_sub)

def handle_fire_alarm_trigger(fire_alarm_trigger_msg: BoolMsg):
if fire_alarm_trigger_msg.data:
logging.info("Fire alarm triggered")
Expand Down
8 changes: 8 additions & 0 deletions packages/api-server/api_server/models/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class AlertResponse(BaseModel):
def from_tortoise(tortoise: ttm.AlertResponse) -> "AlertResponse":
return AlertResponse(**tortoise.data)

async def save(self) -> None:
await ttm.AlertResponse.update_or_create(
{
"data": self.json(),
},
id=self.id,
)


class AlertRequest(BaseModel):
class Tier(str, Enum):
Expand Down
8 changes: 1 addition & 7 deletions packages/api-server/api_server/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from .alerts import (
AlertRepository,
LocationAlertFailResponse,
LocationAlertSuccessResponse,
is_final_location_alert_check,
task_id_to_all_locations_success_cache,
)
from .alerts import AlertRepository
from .cached_files import CachedFilesRepository, cached_files_repo
from .fleets import FleetRepository
from .rmf import RmfRepository
Expand Down
108 changes: 0 additions & 108 deletions packages/api-server/api_server/repositories/alerts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from collections import deque
from datetime import datetime
from typing import List, Optional

Expand Down Expand Up @@ -47,61 +46,6 @@ def get_location_from_location_alert(alert: AlertRequest) -> Optional[str]:
return None


def is_final_location_alert_check(alert: AlertRequest) -> bool:
"""
Checks if the alert request requires a check on all location alerts of this
task.
Note: This is an experimental feature and may be subjected to
modifications often.
"""
if (
alert.task_id is None
or len(alert.alert_parameters) < 1
or LocationAlertSuccessResponse not in alert.responses_available
or LocationAlertFailResponse not in alert.responses_available
):
return False

# Check type
for param in alert.alert_parameters:
if param.name == LocationAlertTypeParameterName:
if param.value == LocationAlertFinalCheckTypeParameterValue:
return True
return False
return False


class LRUCache:
def __init__(self, capacity: int):
self._cache = deque(maxlen=capacity)
self._lookup = {}

def add(self, key, value):
if key in self._lookup:
self._cache.remove(key)
elif len(self._cache) == self._cache.maxlen:
oldest_key = self._cache.popleft()
del self._lookup[oldest_key]

self._cache.append(key)
self._lookup[key] = value

def remove(self, key):
if key in self._lookup:
self._cache.remove(key)
del self._lookup[key]

def lookup(self, key):
if key in self._lookup:
self._cache.remove(key)
self._cache.append(key)
return self._lookup[key]
return None


task_id_to_all_locations_success_cache: LRUCache = LRUCache(20)


class AlertRepository:
async def create_new_alert(self, alert: AlertRequest) -> Optional[AlertRequest]:
exists = await ttm.AlertRequest.exists(id=alert.id)
Expand Down Expand Up @@ -180,55 +124,3 @@ async def get_unresponded_alerts(self) -> List[AlertRequest]:
alert_response=None, response_expected=True
)
return [AlertRequest.from_tortoise(alert) for alert in unresponded_alerts]

async def create_location_alert_response(
self,
task_id: str,
location: str,
success: bool,
) -> Optional[AlertResponse]:
"""
Creates an alert response for a location alert of the task.
Note: This is an experimental feature and may be subjected to
modifications often.
"""
alerts = await self.get_alerts_of_task(task_id=task_id, unresponded=True)
if len(alerts) == 0:
logging.error(
f"There are no location alerts awaiting response for task {task_id}"
)
return None

for alert in alerts:
location_alert_location = get_location_from_location_alert(alert)
if location_alert_location is None:
continue

if location_alert_location == location:
response = (
LocationAlertSuccessResponse
if success
else LocationAlertFailResponse
)
alert_response_model = await self.create_response(alert.id, response)
if alert_response_model is None:
logging.error(
f"Failed to create response {response} to alert with ID {alert.id}"
)
return None

# Cache if all locations of this task has been successful so far
cache = task_id_to_all_locations_success_cache.lookup(task_id)
if cache is None:
task_id_to_all_locations_success_cache.add(task_id, success)
else:
task_id_to_all_locations_success_cache.add(
task_id, cache and success
)

return alert_response_model

logging.error(
f"Task {task_id} is not awaiting completion of location {location}"
)
return None
17 changes: 14 additions & 3 deletions packages/api-server/api_server/rmf_io/book_keeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from api_server.models import (
AlertRequest,
AlertResponse,
BeaconState,
BuildingMap,
DispenserHealth,
Expand All @@ -21,9 +22,7 @@
LiftState,
)
from api_server.models.health import BaseBasicHealth
from api_server.repositories import (
AlertRepository, # , is_final_location_alert_check, LocationAlertFailResponse, LocationAlertSuccessResponse
)
from api_server.repositories import AlertRepository

from .events import AlertEvents, RmfEvents

Expand Down Expand Up @@ -62,6 +61,7 @@ async def start(self):
self._record_ingestor_state()
self._record_ingestor_health()
self._record_alert_request()
self._record_alert_response()

async def stop(self):
for sub in self._subscriptions:
Expand Down Expand Up @@ -205,3 +205,14 @@ async def update(alert_request: AlertRequest):
lambda x: self._create_task(update(x))
)
)

def _record_alert_response(self):
async def update(alert_response: AlertResponse):
await alert_response.save()
logging.debug(json.dumps(alert_response.dict()))

self._subscriptions.append(
self.alert_events.alert_responses.subscribe(
lambda x: self._create_task(update(x))
)
)
30 changes: 2 additions & 28 deletions packages/api-server/api_server/routes/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
start_time_between_query,
)
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.gateway import rmf_gateway
from api_server.logging import LoggerAdapter, get_logger
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.repositories import AlertRepository, RmfRepository, TaskRepository
from api_server.repositories import RmfRepository, TaskRepository
from api_server.response import RawJSONResponse
from api_server.rmf_io import alert_events, task_events, tasks_service
from api_server.rmf_io import task_events, tasks_service
from api_server.routes.building_map import get_building_map

router = FastIORouter(tags=["Tasks"])
Expand Down Expand Up @@ -390,28 +389,3 @@ async def post_undo_skip_phase(
request: mdl.UndoPhaseSkipRequest = Body(...),
):
return RawJSONResponse(await tasks_service().call(request.json(exclude_none=True)))


@router.post("/location_complete")
async def location_complete(
task_id: str,
location: str,
success: bool,
alert_repo: AlertRepository = Depends(AlertRepository),
logger: LoggerAdapter = Depends(get_logger),
):
"""
Warning: This endpoint is still actively being worked on and could be
subjected to modifications.
"""
response_model = await alert_repo.create_location_alert_response(
task_id, location, success
)
if response_model is None:
raise HTTPException(
422,
f"Failed to create location completion alert response to task [{task_id}] at location [{location}]",
)
alert_events.alert_responses.on_next(response_model)
rmf_gateway().respond_to_alert(response_model.id, response_model.response)
logger.info(response_model)
33 changes: 0 additions & 33 deletions packages/api-server/api_server/routes/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,36 +223,3 @@ def test_get_unresponded_alert_ids(self):
returned_alert_ids = [a["id"] for a in returned_alerts]
self.assertTrue(first_id not in returned_alert_ids)
self.assertTrue(second_id in returned_alert_ids)

def test_task_location_complete(self):
alert_id = str(uuid4())
task_id = "test_task_id"
location_name = "test_location"
alert = make_alert_request(alert_id=alert_id, responses=["success", "fail"])
alert.alert_parameters = [
mdl.AlertParameter(name="type", value="location_result"),
mdl.AlertParameter(name="location_name", value=location_name),
]
alert.task_id = task_id
resp = self.client.post("/alerts/request", data=alert.json(exclude_none=True))
self.assertEqual(201, resp.status_code, resp.content)
self.assertEqual(alert, resp.json(), resp.content)

# complete wrong task ID
params = {
"task_id": "wrong_task_id",
"location": location_name,
"success": True,
}
resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}")
self.assertEqual(422, resp.status_code, resp.content)

# complete missing location
params = {"task_id": task_id, "location": "wrong_location", "success": True}
resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}")
self.assertEqual(422, resp.status_code, resp.content)

# complete location
params = {"task_id": task_id, "location": location_name, "success": True}
resp = self.client.post(f"/tasks/location_complete?{urlencode(params)}")
self.assertEqual(200, resp.status_code, resp.content)

0 comments on commit eb83039

Please sign in to comment.