diff --git a/packages/api-client/lib/openapi/api.ts b/packages/api-client/lib/openapi/api.ts index 6e9353747..54371dacc 100644 --- a/packages/api-client/lib/openapi/api.ts +++ b/packages/api-client/lib/openapi/api.ts @@ -727,6 +727,12 @@ export interface DeliveryAlert { */ message: string; } +/** + * + * @export + * @interface Description + */ +export interface Description {} /** * Detailed information about a task, phase, or event * @export @@ -2403,47 +2409,10 @@ export interface Task { export interface TaskBookingLabel { /** * - * @type {TaskBookingLabelDescription} + * @type {{ [key: string]: Description; }} * @memberof TaskBookingLabel */ - description: TaskBookingLabelDescription; -} -/** - * This description holds several fields that could be useful for frontend dashboards when dispatching a task, to then be identified or rendered accordingly back on the same frontend. - * @export - * @interface TaskBookingLabelDescription - */ -export interface TaskBookingLabelDescription { - /** - * - * @type {string} - * @memberof TaskBookingLabelDescription - */ - task_definition_id: string; - /** - * - * @type {number} - * @memberof TaskBookingLabelDescription - */ - unix_millis_warn_time?: number; - /** - * - * @type {string} - * @memberof TaskBookingLabelDescription - */ - pickup?: string; - /** - * - * @type {string} - * @memberof TaskBookingLabelDescription - */ - destination?: string; - /** - * - * @type {string} - * @memberof TaskBookingLabelDescription - */ - cart_id?: string; + description: { [key: string]: Description }; } /** * Response to a request to cancel a task @@ -9529,15 +9498,16 @@ export const TasksApiAxiosParamCreator = function (configuration?: Configuration }; }, /** - * + * Note that sorting by `pickup` and `destination` is mutually exclusive and sorting by either of them will filter only tasks which has those labels. * @summary Query Task States * @param {string} [taskId] comma separated list of task ids * @param {string} [category] comma separated list of task categories * @param {string} [requester] comma separated list of requester names - * @param {string} [pickup] comma separated list of pickup names - * @param {string} [destination] comma separated list of destination names + * @param {string} [pickup] comma separated list of pickup names. [deprecated] use `label` instead + * @param {string} [destination] comma separated list of destination names, [deprecated] use `label` instead * @param {string} [assignedTo] comma separated list of assigned robot names * @param {string} [status] comma separated list of statuses + * @param {string} [label] comma separated list of labels, each item must be in the form <key>=<value>, multiple items will filter tasks with all the labels * @param {string} [requestTimeBetween] The period of request time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [startTimeBetween] The period of starting time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [finishTimeBetween] The period of finishing time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. \"-60000\" - Fetches logs in the last minute. @@ -9555,6 +9525,7 @@ export const TasksApiAxiosParamCreator = function (configuration?: Configuration destination?: string, assignedTo?: string, status?: string, + label?: string, requestTimeBetween?: string, startTimeBetween?: string, finishTimeBetween?: string, @@ -9603,6 +9574,10 @@ export const TasksApiAxiosParamCreator = function (configuration?: Configuration localVarQueryParameter['status'] = status; } + if (label !== undefined) { + localVarQueryParameter['label'] = label; + } + if (requestTimeBetween !== undefined) { localVarQueryParameter['request_time_between'] = requestTimeBetween; } @@ -10163,15 +10138,16 @@ export const TasksApiFp = function (configuration?: Configuration) { return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration); }, /** - * + * Note that sorting by `pickup` and `destination` is mutually exclusive and sorting by either of them will filter only tasks which has those labels. * @summary Query Task States * @param {string} [taskId] comma separated list of task ids * @param {string} [category] comma separated list of task categories * @param {string} [requester] comma separated list of requester names - * @param {string} [pickup] comma separated list of pickup names - * @param {string} [destination] comma separated list of destination names + * @param {string} [pickup] comma separated list of pickup names. [deprecated] use `label` instead + * @param {string} [destination] comma separated list of destination names, [deprecated] use `label` instead * @param {string} [assignedTo] comma separated list of assigned robot names * @param {string} [status] comma separated list of statuses + * @param {string} [label] comma separated list of labels, each item must be in the form <key>=<value>, multiple items will filter tasks with all the labels * @param {string} [requestTimeBetween] The period of request time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [startTimeBetween] The period of starting time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [finishTimeBetween] The period of finishing time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. \"-60000\" - Fetches logs in the last minute. @@ -10189,6 +10165,7 @@ export const TasksApiFp = function (configuration?: Configuration) { destination?: string, assignedTo?: string, status?: string, + label?: string, requestTimeBetween?: string, startTimeBetween?: string, finishTimeBetween?: string, @@ -10205,6 +10182,7 @@ export const TasksApiFp = function (configuration?: Configuration) { destination, assignedTo, status, + label, requestTimeBetween, startTimeBetween, finishTimeBetween, @@ -10625,15 +10603,16 @@ export const TasksApiFactory = function ( .then((request) => request(axios, basePath)); }, /** - * + * Note that sorting by `pickup` and `destination` is mutually exclusive and sorting by either of them will filter only tasks which has those labels. * @summary Query Task States * @param {string} [taskId] comma separated list of task ids * @param {string} [category] comma separated list of task categories * @param {string} [requester] comma separated list of requester names - * @param {string} [pickup] comma separated list of pickup names - * @param {string} [destination] comma separated list of destination names + * @param {string} [pickup] comma separated list of pickup names. [deprecated] use `label` instead + * @param {string} [destination] comma separated list of destination names, [deprecated] use `label` instead * @param {string} [assignedTo] comma separated list of assigned robot names * @param {string} [status] comma separated list of statuses + * @param {string} [label] comma separated list of labels, each item must be in the form <key>=<value>, multiple items will filter tasks with all the labels * @param {string} [requestTimeBetween] The period of request time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [startTimeBetween] The period of starting time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [finishTimeBetween] The period of finishing time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. \"-60000\" - Fetches logs in the last minute. @@ -10651,6 +10630,7 @@ export const TasksApiFactory = function ( destination?: string, assignedTo?: string, status?: string, + label?: string, requestTimeBetween?: string, startTimeBetween?: string, finishTimeBetween?: string, @@ -10668,6 +10648,7 @@ export const TasksApiFactory = function ( destination, assignedTo, status, + label, requestTimeBetween, startTimeBetween, finishTimeBetween, @@ -11114,15 +11095,16 @@ export class TasksApi extends BaseAPI { } /** - * + * Note that sorting by `pickup` and `destination` is mutually exclusive and sorting by either of them will filter only tasks which has those labels. * @summary Query Task States * @param {string} [taskId] comma separated list of task ids * @param {string} [category] comma separated list of task categories * @param {string} [requester] comma separated list of requester names - * @param {string} [pickup] comma separated list of pickup names - * @param {string} [destination] comma separated list of destination names + * @param {string} [pickup] comma separated list of pickup names. [deprecated] use `label` instead + * @param {string} [destination] comma separated list of destination names, [deprecated] use `label` instead * @param {string} [assignedTo] comma separated list of assigned robot names * @param {string} [status] comma separated list of statuses + * @param {string} [label] comma separated list of labels, each item must be in the form <key>=<value>, multiple items will filter tasks with all the labels * @param {string} [requestTimeBetween] The period of request time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [startTimeBetween] The period of starting time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. * @param {string} [finishTimeBetween] The period of finishing time to fetch, in unix millis. This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive. Example: \"1000,2000\" - Fetches logs between unix millis 1000 and 2000. \"-60000\" - Fetches logs in the last minute. @@ -11141,6 +11123,7 @@ export class TasksApi extends BaseAPI { destination?: string, assignedTo?: string, status?: string, + label?: string, requestTimeBetween?: string, startTimeBetween?: string, finishTimeBetween?: string, @@ -11158,6 +11141,7 @@ export class TasksApi extends BaseAPI { destination, assignedTo, status, + label, requestTimeBetween, startTimeBetween, finishTimeBetween, diff --git a/packages/api-client/lib/version.ts b/packages/api-client/lib/version.ts index d44ff17a7..8f08dd1f7 100644 --- a/packages/api-client/lib/version.ts +++ b/packages/api-client/lib/version.ts @@ -3,6 +3,6 @@ import { version as rmfModelVer } from 'rmf-models'; export const version = { rmfModels: rmfModelVer, - rmfServer: '98741b14ceca74208ca98e4bb0c3ca9e41ca1e3c', + rmfServer: 'd536f9525f277088015d827b6b7198035d1a856b', openapiGenerator: '6.2.1', }; diff --git a/packages/api-client/schema/index.ts b/packages/api-client/schema/index.ts index 37ea1f1b5..9d0215111 100644 --- a/packages/api-client/schema/index.ts +++ b/packages/api-client/schema/index.ts @@ -709,6 +709,8 @@ export default { get: { tags: ['Tasks'], summary: 'Query Task States', + description: + 'Note that sorting by `pickup` and `destination` is mutually exclusive and sorting\nby either of them will filter only tasks which has those labels.', operationId: 'query_task_states_tasks_get', parameters: [ { @@ -745,23 +747,27 @@ export default { in: 'query', }, { - description: 'comma separated list of pickup names', + description: 'comma separated list of pickup names. [deprecated] use `label` instead', required: false, + deprecated: true, schema: { title: 'Pickup', type: 'string', - description: 'comma separated list of pickup names', + description: 'comma separated list of pickup names. [deprecated] use `label` instead', }, name: 'pickup', in: 'query', }, { - description: 'comma separated list of destination names', + description: + 'comma separated list of destination names, [deprecated] use `label` instead', required: false, + deprecated: true, schema: { title: 'Destination', type: 'string', - description: 'comma separated list of destination names', + description: + 'comma separated list of destination names, [deprecated] use `label` instead', }, name: 'destination', in: 'query', @@ -788,6 +794,19 @@ export default { name: 'status', in: 'query', }, + { + description: + 'comma separated list of labels, each item must be in the form =, multiple items will filter tasks with all the labels', + required: false, + schema: { + title: 'Label', + type: 'string', + description: + 'comma separated list of labels, each item must be in the form =, multiple items will filter tasks with all the labels', + }, + name: 'label', + in: 'query', + }, { description: '\n The period of request time to fetch, in unix millis.\n\n This must be a comma separated string, \'X,Y\' to fetch between X millis and Y millis inclusive.\n\n Example:\n "1000,2000" - Fetches logs between unix millis 1000 and 2000.\n ', @@ -3675,23 +3694,17 @@ export default { title: 'TaskBookingLabel', required: ['description'], type: 'object', - properties: { description: { $ref: '#/components/schemas/TaskBookingLabelDescription' } }, - description: - 'This label is to be populated by any frontend during a task dispatch, by\nbeing added to TaskRequest.labels, which in turn populates\nTaskState.booking.labels, and can be used to display relevant information\nneeded for any frontends.', - }, - TaskBookingLabelDescription: { - title: 'TaskBookingLabelDescription', - required: ['task_definition_id'], - type: 'object', properties: { - task_definition_id: { title: 'Task Definition Id', type: 'string' }, - unix_millis_warn_time: { title: 'Unix Millis Warn Time', type: 'integer' }, - pickup: { title: 'Pickup', type: 'string' }, - destination: { title: 'Destination', type: 'string' }, - cart_id: { title: 'Cart Id', type: 'string' }, + description: { + title: 'Description', + type: 'object', + additionalProperties: { + anyOf: [{ type: 'string' }, { type: 'integer' }, { type: 'number' }], + }, + }, }, description: - 'This description holds several fields that could be useful for frontend\ndashboards when dispatching a task, to then be identified or rendered\naccordingly back on the same frontend.', + 'This label is to be populated by any frontend during a task dispatch, by\nbeing added to TaskRequest.labels, which in turn populates\nTaskState.booking.labels, and can be used to display relevant information\nneeded for any frontends.', }, TaskCancelResponse: { title: 'TaskCancelResponse', diff --git a/packages/api-server/api_server/models/__init__.py b/packages/api-server/api_server/models/__init__.py index 238fdf04b..414a446d0 100644 --- a/packages/api-server/api_server/models/__init__.py +++ b/packages/api-server/api_server/models/__init__.py @@ -6,6 +6,7 @@ from .doors import * from .health import * from .ingestors import * +from .labels import * from .lifts import * from .pagination import * from .rmf_api.activity_discovery_request import ActivityDiscoveryRequest diff --git a/packages/api-server/api_server/models/labels.py b/packages/api-server/api_server/models/labels.py new file mode 100644 index 000000000..209a8ff59 --- /dev/null +++ b/packages/api-server/api_server/models/labels.py @@ -0,0 +1,25 @@ +from typing import Sequence + +from pydantic import BaseModel + + +class Labels(BaseModel): + """ + Labels for a resource. + """ + + __root__: dict[str, str] + + @staticmethod + def _parse_label(s: str) -> tuple[str, str]: + sep = s.find("=") + if sep == -1: + return s, "" + return s[:sep], s[sep + 1 :] + + @staticmethod + def from_strings(labels: Sequence[str]) -> "Labels": + return Labels(__root__=dict(Labels._parse_label(s) for s in labels)) + + def to_strings(self) -> list[str]: + return [f"{k}={v}" for k, v in self.__root__.items()] diff --git a/packages/api-server/api_server/models/task_booking_label.py b/packages/api-server/api_server/models/task_booking_label.py index 1e59d77ab..3e5602680 100644 --- a/packages/api-server/api_server/models/task_booking_label.py +++ b/packages/api-server/api_server/models/task_booking_label.py @@ -3,30 +3,6 @@ import pydantic from pydantic import BaseModel -# NOTE: This label model needs to exactly match the fields that are defined and -# populated by the dashboard. Any changes to either side will require syncing. - - -class TaskBookingLabelDescription(BaseModel): - """ - This description holds several fields that could be useful for frontend - dashboards when dispatching a task, to then be identified or rendered - accordingly back on the same frontend. - """ - - task_definition_id: str - unix_millis_warn_time: Optional[int] - pickup: Optional[str] - destination: Optional[str] - cart_id: Optional[str] - - @staticmethod - def from_json_string(json_str: str) -> Optional["TaskBookingLabelDescription"]: - try: - return TaskBookingLabelDescription.parse_raw(json_str) - except pydantic.error_wrappers.ValidationError: - return None - class TaskBookingLabel(BaseModel): """ @@ -36,7 +12,7 @@ class TaskBookingLabel(BaseModel): needed for any frontends. """ - description: TaskBookingLabelDescription + description: dict[str, str | int | float] @staticmethod def from_json_string(json_str: str) -> Optional["TaskBookingLabel"]: diff --git a/packages/api-server/api_server/query.py b/packages/api-server/api_server/query.py index b04c3d565..35dcb62c1 100644 --- a/packages/api-server/api_server/query.py +++ b/packages/api-server/api_server/query.py @@ -1,5 +1,7 @@ from typing import Dict, Optional +import tortoise.functions as tfuncs +from tortoise.expressions import Q from tortoise.queryset import MODEL, QuerySet from api_server.models.pagination import Pagination @@ -9,24 +11,46 @@ def add_pagination( query: QuerySet[MODEL], pagination: Pagination, field_mappings: Optional[Dict[str, str]] = None, + group_by: str | None = None, ) -> QuerySet[MODEL]: """ - Adds pagination and ordering to a query. + Adds pagination and ordering to a query. If the order field starts with `label=`, it is + assumed to be a label and label sorting will used. In this case, the model must have + a reverse relation named "labels" and the `group_by` param is required. :param field_mapping: A dict mapping the order fields to the fields used to build the query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}` will order the query result according to `db_field`. + :param group_by: Required when sorting by labels, must be the foreign key column of the label table. """ field_mappings = field_mappings or {} + annotations = {} query = query.limit(pagination.limit).offset(pagination.offset) if pagination.order_by is not None: order_fields = [] order_values = pagination.order_by.split(",") for v in order_values: + # perform the mapping after stripping the order prefix + order_prefix = "" + order_field = v if v[0] in ["-", "+"]: - stripped = v[1:] - order_fields.append(v[0] + field_mappings.get(stripped, stripped)) - else: - order_fields.append(field_mappings.get(v, v)) + order_prefix = v[0] + order_field = v[1:] + order_field = field_mappings.get(order_field, order_field) + + # add annotations required for sorting by labels + if order_field.startswith("label="): + f = order_field[6:] + annotations[f"label_sort_{f}"] = tfuncs.Max( + "labels__label_value_str", + _filter=Q(labels__label_name=f), + ) + order_field = f"label_sort_{f}" + + order_fields.append(order_prefix + order_field) + + query = query.annotate(**annotations) + if group_by is not None: + query = query.group_by(group_by) query = query.order_by(*order_fields) return query diff --git a/packages/api-server/api_server/repositories/tasks.py b/packages/api-server/api_server/repositories/tasks.py index c857abfea..5fb1ee963 100644 --- a/packages/api-server/api_server/repositories/tasks.py +++ b/packages/api-server/api_server/repositories/tasks.py @@ -117,31 +117,40 @@ async def save_task_state(self, task_state: TaskState) -> None: # Here we generate the labels required for server-side sorting and # filtering. - if booking_label.description.pickup is not None: - await ttm.TaskLabel.create( - state=state, - label_name="pickup", - label_value_str=booking_label.description.pickup, - ) - if booking_label.description.destination is not None: - await ttm.TaskLabel.create( - state=state, - label_name="destination", - label_value_str=booking_label.description.destination, - ) - if booking_label.description.unix_millis_warn_time is not None: - await ttm.TaskLabel.create( - state=state, - label_name="unix_millis_warn_time", - label_value_num=booking_label.description.unix_millis_warn_time, - ) + async with in_transaction(): + for k, v in booking_label.description.items(): + if isinstance(v, str): + await ttm.TaskLabel.create( + state=state, label_name=k, label_value_str=v + ) + elif isinstance(v, int): + await ttm.TaskLabel.create( + state=state, + label_name=k, + label_value_num=v, + label_value_float=v, # also store float to make querying easier + ) + elif isinstance(v, float): + exact_val = int(v) if v.is_integer else None + await ttm.TaskLabel.create( + state=state, + label_name=k, + label_value_float=v, + label_value_num=exact_val, # also store int to make querying easier + ) async def query_task_states( self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None ) -> List[TaskState]: try: if pagination: - query = add_pagination(query, pagination) + query = add_pagination( + query, + pagination, + # TODO(koonpeng): remove this mapping after `pickup` and `destination` query is removed. + {"pickup": "label=pickup", "destination": "label=destination"}, + group_by="labels__state_id", + ) # TODO: enforce with authz results = await query.values_list("data", flat=True) return [TaskState(**r) for r in results] diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index 2a8030c91..11a43326f 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -1,8 +1,10 @@ from datetime import datetime from typing import List, Optional, Tuple, cast +import tortoise.functions as tfuncs from fastapi import Body, Depends, HTTPException, Path, Query from rx import operators as rxops +from tortoise.expressions import Q from api_server import models as mdl from api_server.dependencies import ( @@ -99,10 +101,14 @@ async def query_task_states( None, description="comma separated list of requester names" ), pickup: Optional[str] = Query( - None, description="comma separated list of pickup names" + None, + description="comma separated list of pickup names. [deprecated] use `label` instead", + deprecated=True, ), destination: Optional[str] = Query( - None, description="comma separated list of destination names" + None, + description="comma separated list of destination names, [deprecated] use `label` instead", + deprecated=True, ), assigned_to: Optional[str] = Query( None, description="comma separated list of assigned robot names" @@ -114,8 +120,17 @@ async def query_task_states( finish_time_between_query ), status: Optional[str] = Query(None, description="comma separated list of statuses"), + label: str + | None = Query( + None, + description="comma separated list of labels, each item must be in the form =, multiple items will filter tasks with all the labels", + ), pagination: mdl.Pagination = Depends(pagination_query), ): + """ + Note that sorting by `pickup` and `destination` is mutually exclusive and sorting + by either of them will filter only tasks which has those labels. + """ filters = {} if task_id is not None: filters["id___in"] = task_id.split(",") @@ -141,33 +156,47 @@ async def query_task_states( if status_string not in valid_values: continue filters["status__in"].append(mdl.Status(status_string)) + query = DbTaskState.filter(**filters) - # NOTE: in order to perform filtering based on the values in labels, a - # filter on the label_name will need to be applied as well as a filter on - # the label_value. + label_filters = {} if pickup is not None: - filters["labels__label_name"] = "pickup" - filters["labels__label_value_str__in"] = pickup.split(",") + label_filters["label_filter_pickup"] = tfuncs.Count( + "id_", + _filter=Q( + labels__label_name="pickup", + labels__label_value_str__in=pickup.split(","), + ), + ) if destination is not None: - filters["labels__label_name"] = "destination" - filters["labels__label_value_str__in"] = destination.split(",") - - # NOTE: In order to perform sorting based on the values in labels, a filter - # on the label_name has to be performed first. A side-effect of this would - # be that states that do not contain this field will not be returned. - if pagination.order_by is not None: - labels_fields = ["pickup", "destination"] - new_order = pagination.order_by - for field in labels_fields: - if field in pagination.order_by: - filters["labels__label_name"] = field - new_order = pagination.order_by.replace( - field, "labels__label_value_str" + label_filters["label_filter_destination"] = tfuncs.Count( + "id_", + _filter=Q( + labels__label_name="destination", + labels__label_value_str__in=destination.split(","), + ), + ) + if label is not None: + labels = mdl.Labels.from_strings(label.split(",")) + label_filters.update( + { + f"label_filter_{k}": tfuncs.Count( + "id_", _filter=Q(labels__label_name=k, labels__label_value_str=v) ) - break - pagination.order_by = new_order + for k, v in labels.__root__.items() + } + ) + + if len(label_filters) > 0: + filter_gt = {f"{f}__gt": 0 for f in label_filters} + query = ( + query.annotate(**label_filters) + .group_by( + "labels__state_id" + ) # need to group by a related field to make tortoise-orm generate joins + .filter(**filter_gt) + ) - return await task_repo.query_task_states(DbTaskState.filter(**filters), pagination) + return await task_repo.query_task_states(query, pagination) @router.get("/{task_id}/state", response_model=mdl.TaskState) diff --git a/packages/api-server/api_server/routes/tasks/test_tasks.py b/packages/api-server/api_server/routes/tasks/test_tasks.py index ffa78e9db..5c3c7f673 100644 --- a/packages/api-server/api_server/routes/tasks/test_tasks.py +++ b/packages/api-server/api_server/routes/tasks/test_tasks.py @@ -1,6 +1,8 @@ from unittest.mock import patch from uuid import uuid4 +import pydantic + from api_server import models as mdl from api_server.rmf_io import tasks_service from api_server.test import ( @@ -15,8 +17,32 @@ class TestTasksRoute(AppFixture): @classmethod def setUpClass(cls): super().setUpClass() - task_ids = [uuid4()] - cls.task_states = [make_task_state(task_id=f"test_{x}") for x in task_ids] + booking_labels = make_task_booking_label() + booking_labels.description["test_single"] = "" + booking_labels.description["test_single_2"] = "" + booking_labels.description["test_kv"] = "value" + booking_labels.description["test_label_sort"] = "zzz" + booking_labels.description["test_label_sort_2"] = "aaa" + booking_labels.description["test_label_sort_3"] = "bbb" + booking_labels_2 = make_task_booking_label() + booking_labels_2.description["test_label_sort"] = "aaa" + booking_labels_2.description["test_label_sort_3"] = "bbb" + + task_ids = [uuid4(), uuid4()] + cls.task_states = [ + make_task_state( + task_id=f"test_{task_ids[0]}", + booking_labels=[ + "dummy_label_1", + "dummy_label_2", + booking_labels.json(), + ], + ), + make_task_state( + task_id=f"test_{task_ids[1]}", + booking_labels=[booking_labels_2.json()], + ), + ] cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids] with cls.client.websocket_connect("/_internal") as ws: @@ -44,6 +70,107 @@ def test_query_task_states(self): self.assertEqual(1, len(results)) self.assertEqual(self.task_states[0].booking.id, results[0]["booking"]["id"]) + test_cases = [ + ({"pickup": "Kitchen"}, [self.task_states[0].booking.id]), + ({"destination": "room_203"}, [self.task_states[0].booking.id]), + ( + {"pickup": "Kitchen", "destination": "room_203"}, + [self.task_states[0].booking.id], + ), + ( + {"pickup": "Kitchen", "destination": "room_202"}, + [], + ), + ] + for tc in test_cases: + q = "&".join(f"{k}={v}" for k, v in tc[0].items()) + resp = self.client.get( + f"/tasks?task_id={self.task_states[0].booking.id}&{q}" + ) + self.assertEqual(200, resp.status_code, tc) + results = resp.json() + self.assertEqual(len(tc[1]), len(results), tc) + for a, b in zip(tc[1], results): + self.assertEqual(a, b["booking"]["id"], tc) + + def test_query_task_states_filter_by_label(self): + resp = self.client.get("/tasks?label=not_existing") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(0, len(results)) + + resp = self.client.get("/tasks?label=test_single") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(1, len(results)) + self.assertEqual(self.task_states[0].booking.id, results[0].booking.id) + + resp = self.client.get("/tasks?label=test_single=wrong_value") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(0, len(results)) + + resp = self.client.get("/tasks?label=test_single_2=") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(1, len(results)) + self.assertEqual(self.task_states[0].booking.id, results[0].booking.id) + + resp = self.client.get("/tasks?label=test_kv=value") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(1, len(results)) + self.assertEqual(self.task_states[0].booking.id, results[0].booking.id) + + resp = self.client.get("/tasks?label=test_kv=wrong_value") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(0, len(results)) + + resp = self.client.get("/tasks?label=test_single,test_kv=value") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(1, len(results)) + self.assertEqual(self.task_states[0].booking.id, results[0].booking.id) + + resp = self.client.get("/tasks?label=test_single,test_kv=wrong_value") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(0, len(results)) + + def test_query_task_states_sort_by_label(self): + resp = self.client.get("/tasks?order_by=-label=test_label_sort") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(2, len(results)) + for a, b in zip(self.task_states, results): + self.assertEqual(a, b) + + resp = self.client.get("/tasks?order_by=label=test_label_sort") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(2, len(results)) + for a, b in zip(self.task_states[::-1], results): + self.assertEqual(a, b) + + # test sorting by multiple labels + resp = self.client.get( + "/tasks?order_by=label=test_label_sort,label=test_label_sort_3" + ) + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(2, len(results)) + for a, b in zip(self.task_states[::-1], results): + self.assertEqual(a, b) + + # test that tasks without the label are not filtered out + # we don't test the result order because different db has different behavior + # of sorting NULL. + resp = self.client.get("/tasks?order_by=label=test_label_sort_not_existing") + self.assertEqual(200, resp.status_code) + results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) + self.assertEqual(2, len(results)) + def test_sub_task_state(self): task_id = self.task_states[0].booking.id gen = self.subscribe_sio(f"/tasks/{task_id}/state") @@ -59,10 +186,13 @@ def test_sub_task_state(self): def test_get_task_booking_label(self): resp = self.client.get(f"/tasks/{self.task_states[0].booking.id}/booking_label") self.assertEqual(200, resp.status_code) - self.assertEqual( - make_task_booking_label(), - mdl.TaskBookingLabel(**resp.json()), - ) + labels = mdl.TaskBookingLabel.parse_raw(resp.content) + expected = make_task_booking_label() + for k, v in expected.description.items(): + self.assertEqual( + v, + labels.description[k], + ) def test_get_task_log(self): resp = self.client.get( diff --git a/packages/api-server/api_server/test/__init__.py b/packages/api-server/api_server/test/__init__.py index cf62c9c9f..0cfa7abe6 100644 --- a/packages/api-server/api_server/test/__init__.py +++ b/packages/api-server/api_server/test/__init__.py @@ -2,7 +2,6 @@ from api_server.models import User from .mocks import * -from .test_client import client from .test_data import * from .test_fixtures import * from .test_utils import * diff --git a/packages/api-server/api_server/test/test_client.py b/packages/api-server/api_server/test/test_client.py index 20975db06..0539a7235 100644 --- a/packages/api-server/api_server/test/test_client.py +++ b/packages/api-server/api_server/test/test_client.py @@ -33,38 +33,15 @@ def _generate_token(username: str): class TestClient(BaseTestClient): - _admin_token: Optional[str] = None - def __init__(self): super().__init__(app) + self.current_user: str + self.set_user("admin") @classmethod def token(cls, username: str) -> str: - if username == "admin": - if cls._admin_token is None: - cls._admin_token = _generate_token("admin") - return cls._admin_token - return _generate_token(username) - def set_user(self, user): + def set_user(self, user: str): + self.current_user = user self.headers["Authorization"] = f"bearer {self.token(user)}" - - -_client: Optional[TestClient] = None - - -def client(user="admin") -> TestClient: - global _client - if _client is None: - _client = TestClient() - _client.__enter__() - _client.headers["Content-Type"] = "application/json" - _client.set_user(user) - return _client - - -def shutdown(): - global _client - if _client is not None: - _client.__exit__() diff --git a/packages/api-server/api_server/test/test_data.py b/packages/api-server/api_server/test/test_data.py index cf0f693d2..3a1970c5c 100644 --- a/packages/api-server/api_server/test/test_data.py +++ b/packages/api-server/api_server/test/test_data.py @@ -24,7 +24,6 @@ LiftState, RobotState, TaskBookingLabel, - TaskBookingLabelDescription, TaskEventLog, TaskFavorite, TaskState, @@ -134,17 +133,20 @@ def make_fleet_log() -> FleetLog: def make_task_booking_label() -> TaskBookingLabel: return TaskBookingLabel( - description=TaskBookingLabelDescription( - task_definition_id="multi-delivery", - unix_millis_warn_time=1636388400000, - pickup="Kitchen", - destination="room_203", - cart_id="soda", - ) + description={ + "task_definition_id": "multi-delivery", + "unix_millis_warn_time": 1636388400000, + "pickup": "Kitchen", + "destination": "room_203", + "cart_id": "soda", + } ) -def make_task_state(task_id: str = "test_task") -> TaskState: +def make_task_state( + task_id: str = "test_task", + booking_labels: list[str] | None = None, +) -> TaskState: # from https://raw.githubusercontent.com/open-rmf/rmf_api_msgs/960b286d9849fc716a3043b8e1f5fb341bdf5778/rmf_api_msgs/samples/task_state/multi_dropoff_delivery.json sample_task = json.loads( """ @@ -444,11 +446,12 @@ def make_task_state(task_id: str = "test_task") -> TaskState: ) sample_task["booking"]["id"] = task_id - booking_labels = [ - "dummy_label_1", - "dummy_label_2", - make_task_booking_label().json(), - ] + if booking_labels is None: + booking_labels = [ + "dummy_label_1", + "dummy_label_2", + make_task_booking_label().json(), + ] sample_task["booking"]["labels"] = booking_labels return TaskState(**sample_task) diff --git a/packages/api-server/api_server/test/test_fixtures.py b/packages/api-server/api_server/test/test_fixtures.py index 32a2e6d2f..261bdb923 100644 --- a/packages/api-server/api_server/test/test_fixtures.py +++ b/packages/api-server/api_server/test/test_fixtures.py @@ -9,9 +9,10 @@ from uuid import uuid4 from api_server.app import app, on_sio_connect +from api_server.models import User from .mocks import patch_sio -from .test_client import client +from .test_client import TestClient T = TypeVar("T") @@ -79,8 +80,11 @@ async def async_try_until( class AppFixture(unittest.TestCase): @classmethod def setUpClass(cls): - cls.client = client() - cls.client.set_user("admin") + cls.admin_user = User(username="admin", is_admin=True) + cls.client = TestClient() + cls.client.headers["Content-Type"] = "application/json" + cls.client.__enter__() + cls.addClassCleanup(cls.client.__exit__) def subscribe_sio(self, room: str, *, user="admin"): """ diff --git a/packages/api-server/api_server/test_sio_auth.py b/packages/api-server/api_server/test_sio_auth.py index c48baa611..0ddbda633 100644 --- a/packages/api-server/api_server/test_sio_auth.py +++ b/packages/api-server/api_server/test_sio_auth.py @@ -4,7 +4,6 @@ from api_server.app import app, on_sio_connect -from .test import client from .test.test_fixtures import AppFixture @@ -33,4 +32,4 @@ def test_fail_with_invalid_token(self): self.assertFalse(self.try_connect("invalid")) def test_success_with_valid_token(self): - self.assertTrue(self.try_connect(client().token("admin"))) + self.assertTrue(self.try_connect(self.client.token("admin"))) diff --git a/packages/api-server/scripts/test.py b/packages/api-server/scripts/test.py index 4acdf5e90..e164368e6 100644 --- a/packages/api-server/scripts/test.py +++ b/packages/api-server/scripts/test.py @@ -6,9 +6,5 @@ import unittest -from api_server.test import test_client - -test_client.client() result = unittest.main(module=None, exit=False) -test_client.shutdown() exit(1 if not result.result.wasSuccessful() else 0) diff --git a/packages/dashboard/src/components/tasks/tasks-app.tsx b/packages/dashboard/src/components/tasks/tasks-app.tsx index 2716e99cc..3980d3e37 100644 --- a/packages/dashboard/src/components/tasks/tasks-app.tsx +++ b/packages/dashboard/src/components/tasks/tasks-app.tsx @@ -195,6 +195,7 @@ export const TasksApp = React.memo( filterColumn && filterColumn === 'destination' ? filterValue : undefined, filterColumn && filterColumn === 'assigned_to' ? filterValue : undefined, filterColumn && filterColumn === 'status' ? filterValue : undefined, + undefined, filterColumn && filterColumn === 'unix_millis_request_time' ? filterValue : undefined, filterColumn && filterColumn === 'unix_millis_start_time' ? filterValue : undefined, filterColumn && filterColumn === 'unix_millis_finish_time' ? filterValue : undefined, @@ -263,6 +264,7 @@ export const TasksApp = React.memo( undefined, undefined, undefined, + undefined, `${currentMillis - oneMonthMillis},${currentMillis}`, undefined, QueryLimit,