Skip to content

Commit

Permalink
Lint and better typing with static methods
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Chong <[email protected]>
  • Loading branch information
aaronchongth committed May 30, 2024
1 parent 99855f8 commit 78f5e0b
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 131 deletions.
8 changes: 8 additions & 0 deletions packages/api-server/api_server/models/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class AlertResponse(BaseModel):
unix_millis_response_time: int
response: str

@staticmethod
def from_tortoise(tortoise: ttm.AlertResponse) -> "AlertResponse":
return AlertResponse(**tortoise.data)


class AlertRequest(BaseModel):
class Tier(str, Enum):
Expand All @@ -34,6 +38,10 @@ class Tier(str, Enum):
alert_parameters: List[AlertParameter]
task_id: Optional[str]

@staticmethod
def from_tortoise(tortoise: ttm.AlertRequest) -> "AlertRequest":
return AlertRequest(**tortoise.data)

async def save(self) -> None:
await ttm.AlertRequest.update_or_create(
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from tortoise.fields import BooleanField, CharField, JSONField, OneToOneField
from tortoise.fields import (
BooleanField,
CharField,
JSONField,
OneToOneField,
ReverseRelation,
)
from tortoise.models import Model


Expand All @@ -15,3 +21,4 @@ class AlertRequest(Model):
data = JSONField()
response_expected = BooleanField(null=False, index=True)
task_id = CharField(255, null=True, index=True)
alert_response = ReverseRelation["AlertResponse"]
46 changes: 5 additions & 41 deletions packages/api-server/api_server/repositories/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from api_server.models import AlertRequest, AlertResponse
from api_server.models import tortoise_models as ttm

# from api_server.gateway import rmf_gateway


# TODO: not hardcode all these expected values
LocationAlertSuccessResponse = "success"
LocationAlertFailResponse = "fail"
Expand Down Expand Up @@ -126,7 +123,7 @@ async def get_alert(self, alert_id: str) -> Optional[AlertRequest]:
logging.error(f"Alert with ID {alert_id} does not exists")
return None

alert_model = AlertRequest(**alert.data)
alert_model = AlertRequest.from_tortoise(alert)
return alert_model

async def create_response(
Expand All @@ -137,7 +134,7 @@ async def create_response(
logging.error(f"Alert with ID {alert_id} does not exists")
return None

alert_model = AlertRequest(**alert.data)
alert_model = AlertRequest.from_tortoise(alert)
if response not in alert_model.responses_available:
logging.error(
f"Alert with ID {alert_model.id} does not have allow response of {response}"
Expand All @@ -160,7 +157,7 @@ async def get_alert_response(self, alert_id: str) -> Optional[AlertResponse]:
logging.error(f"Response to alert with ID {alert_id} does not exists")
return None

response_model = AlertResponse(**response.data)
response_model = AlertResponse.from_tortoise(response)
return response_model

async def get_alerts_of_task(
Expand All @@ -175,14 +172,14 @@ async def get_alerts_of_task(
else:
task_id_alerts = await ttm.AlertRequest.filter(task_id=task_id)

alert_models = [AlertRequest(**alert.data) for alert in task_id_alerts]
alert_models = [AlertRequest.from_tortoise(alert) for alert in task_id_alerts]
return alert_models

async def get_unresponded_alerts(self) -> List[AlertRequest]:
unresponded_alerts = await ttm.AlertRequest.filter(
alert_response=None, response_expected=True
)
return [AlertRequest(**alert.data) for alert in unresponded_alerts]
return [AlertRequest.from_tortoise(alert) for alert in unresponded_alerts]

async def create_location_alert_response(
self,
Expand Down Expand Up @@ -235,36 +232,3 @@ async def create_location_alert_response(
f"Task {task_id} is not awaiting completion of location {location}"
)
return None

async def check_all_task_location_alerts_if_succeeded(self, task_id: str) -> bool:
"""
Checks if all location alert reponses for the task were successful.
Note: This is an experimental feature and may be subjected to
modifications often.
"""
task_id_alerts = await ttm.AlertRequest.filter(task_id=task_id)
if len(task_id_alerts) == 0:
logging.info(f"There were no location alerts for task {task_id}")
return False

for alert in task_id_alerts:
alert_model = AlertRequest(**alert.data)
location_alert_location = get_location_from_location_alert(alert_model)
if location_alert_location is None:
continue

if alert.alert_response is None:
logging.info(
f"Alert {alert_model.id} does not have a response, check return False"
)
return False

alert_response_model = AlertResponse(**alert.alert_response.data)
if alert_response_model.response != LocationAlertSuccessResponse:
logging.info(
f"Alert {alert_model.id} has a response {alert_response_model.response}, check return False"
)
return False

logging.info(f"All location alerts for task {task_id} succeeded")
return True
59 changes: 0 additions & 59 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, cast

Expand All @@ -21,12 +20,9 @@
User,
)
from api_server.models import tortoise_models as ttm
from api_server.models.rmf_api.log_entry import Tier
from api_server.models.rmf_api.task_state import Category, Id, Phase
from api_server.models.tortoise_models import TaskRequest as DbTaskRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.query import add_pagination
from api_server.rmf_io import task_events


class TaskRepository:
Expand Down Expand Up @@ -244,61 +240,6 @@ async def _saveTaskLogs(
text=log.text,
)

async def save_log_acknowledged_task_completion(
self,
task_id: str,
acknowledged_by: str,
unix_millis_acknowledged_time: int,
action: str = "Task completion",
) -> None:
async with in_transaction():
task_logs = await self.get_task_log(task_id, (0, sys.maxsize))
task_state = await self.get_task_state(task_id=task_id)
# A try could be used here to avoid using so many "ands"
# but the configured lint suggests comparing that no value is None
if task_logs and task_state and task_logs.phases and task_state.phases:
# The next phase key value matches in both `task_logs` and `task_state`.
# It is the same whether it is obtained from `task_logs` or from `task_state`.
# In this case, it is obtained from `task_logs` and is also used to assign the next
# phase in `task_state`.
next_phase_key = str(int(list(task_logs.phases)[-1]) + 1)
else:
raise ValueError("Phases can't be null")

event = LogEntry(
seq=0,
tier=Tier.warning,
unix_millis_time=unix_millis_acknowledged_time,
text=f"{action} acknowledged by {acknowledged_by}",
)
task_logs.phases = {
**task_logs.phases,
next_phase_key: Phases(log=[], events={"0": [event]}),
}

await self.save_task_log(task_logs)

task_state.phases = {
**task_state.phases,
next_phase_key: Phase(
id=Id(__root__=next_phase_key),
category=Category(__root__="Task completed"),
detail=None,
unix_millis_start_time=None,
unix_millis_finish_time=None,
original_estimate_millis=None,
estimate_millis=None,
final_event_id=None,
events=None,
skip_requests=None,
),
}

await self.save_task_state(task_state)
# Notifies observers of the next task_state value to correctly display the title of the
# logs when acknowledged by a user without reloading the page.
task_events.task_states.on_next(task_state)

async def save_task_log(self, task_log: TaskEventLog) -> None:
async with in_transaction():
db_task_log = (
Expand Down
2 changes: 0 additions & 2 deletions packages/api-server/api_server/routes/alerts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import datetime
from typing import List

from fastapi import Depends, HTTPException
Expand All @@ -7,7 +6,6 @@
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.gateway import rmf_gateway
from api_server.models import AlertRequest, AlertResponse
from api_server.models import tortoise_models as ttm
from api_server.repositories import AlertRepository
from api_server.rmf_io import alert_events

Expand Down
15 changes: 9 additions & 6 deletions packages/api-server/api_server/routes/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from api_server import models as mdl
from api_server.app_config import app_config
from api_server.logging import LoggerAdapter, get_logger
from api_server.repositories import FleetRepository, TaskRepository
from api_server.rmf_io import fleet_events, task_events
from api_server.routes.alerts import create_new_alert
from api_server.repositories import AlertRepository, FleetRepository, TaskRepository
from api_server.rmf_io import alert_events, fleet_events, task_events

router = APIRouter(tags=["_internal"])
user: mdl.User = mdl.User(username="__rmf_internal__", is_admin=True)
Expand Down Expand Up @@ -54,6 +53,7 @@ def disconnect(self, websocket: WebSocket):

async def process_msg(
msg: Dict[str, Any],
alert_repo: AlertRepository,
fleet_repo: FleetRepository,
task_repo: TaskRepository,
logger: LoggerAdapter,
Expand Down Expand Up @@ -86,7 +86,8 @@ async def process_msg(
alert_parameters=[],
task_id=task_state.booking.id,
)
await create_new_alert(alert_request)
created_alert = await alert_repo.create_new_alert(alert_request)
alert_events.alert_requests.on_next(created_alert)
elif task_state.status == mdl.Status.failed:
errorMessage = ""
if (
Expand All @@ -110,7 +111,8 @@ async def process_msg(
alert_parameters=[],
task_id=task_state.booking.id,
)
await create_new_alert(alert_request)
created_alert = await alert_repo.create_new_alert(alert_request)
alert_events.alert_requests.on_next(created_alert)

elif payload_type == "task_log_update":
task_log = mdl.TaskEventLog(**msg["data"])
Expand All @@ -134,12 +136,13 @@ async def rmf_gateway(
logger: LoggerAdapter = Depends(get_logger),
):
await connection_manager.connect(websocket)
alert_repo = AlertRepository()
fleet_repo = FleetRepository(user, logger)
task_repo = TaskRepository(user, logger)
try:
while True:
msg: Dict[str, Any] = await websocket.receive_json()
await process_msg(msg, fleet_repo, task_repo, logger)
await process_msg(msg, alert_repo, fleet_repo, task_repo, logger)
except (WebSocketDisconnect, ConnectionClosed):
connection_manager.disconnect(websocket)
logger.warning("Client websocket disconnected")
46 changes: 24 additions & 22 deletions packages/dashboard/src/components/alert-store.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,30 @@ const AlertDialog = React.memo((props: AlertDialogProps) => {
margin="dense"
value={alertRequest.subtitle.length > 0 ? alertRequest.subtitle : 'n/a'}
/>
<TextField
label="Message"
id="standard-size-small"
size="small"
variant="filled"
sx={{
'& .MuiFilledInput-root': {
fontSize: isScreenHeightLessThan800 ? '0.8rem' : '1.15',
},
}}
InputLabelProps={{ style: { fontSize: isScreenHeightLessThan800 ? 16 : 20 } }}
InputProps={{ readOnly: true, className: classes.textField }}
fullWidth={true}
multiline
maxRows={4}
margin="dense"
value={
(alertRequest.message.length > 0 ? alertRequest.message : 'n/a') +
'\n' +
(additionalAlertMessage ?? '')
}
/>
{(alertRequest.message.length > 0 || additionalAlertMessage !== null) && (
<TextField
label="Message"
id="standard-size-small"
size="small"
variant="filled"
sx={{
'& .MuiFilledInput-root': {
fontSize: isScreenHeightLessThan800 ? '0.8rem' : '1.15',
},
}}
InputLabelProps={{ style: { fontSize: isScreenHeightLessThan800 ? 16 : 20 } }}
InputProps={{ readOnly: true, className: classes.textField }}
fullWidth={true}
multiline
maxRows={4}
margin="dense"
value={
(alertRequest.message.length > 0 ? alertRequest.message : 'n/a') +
'\n' +
(additionalAlertMessage ?? '')
}
/>
)}
</DialogContent>
<DialogActions>
{alertRequest.responses_available.map((response) => {
Expand Down

0 comments on commit 78f5e0b

Please sign in to comment.