Skip to content

Commit 00779a4

Browse files
viambotstuqdognjooma
authored
AI update based on proto changes (#1005)
Co-authored-by: Ethan <[email protected]> Co-authored-by: Naveed Jooma <[email protected]>
1 parent 5b512de commit 00779a4

File tree

4 files changed

+201
-8
lines changed

4 files changed

+201
-8
lines changed

src/viam/app/data_client.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@
2727
CaptureInterval,
2828
CaptureMetadata,
2929
ConfigureDatabaseUserRequest,
30+
CreateIndexRequest,
3031
DataRequest,
3132
DataServiceStub,
3233
DeleteBinaryDataByFilterRequest,
3334
DeleteBinaryDataByFilterResponse,
3435
DeleteBinaryDataByIDsRequest,
3536
DeleteBinaryDataByIDsResponse,
37+
DeleteIndexRequest,
3638
DeleteTabularDataRequest,
3739
DeleteTabularDataResponse,
3840
ExportTabularDataRequest,
@@ -42,6 +44,10 @@
4244
GetDatabaseConnectionResponse,
4345
GetLatestTabularDataRequest,
4446
GetLatestTabularDataResponse,
47+
Index,
48+
IndexableCollection,
49+
ListIndexesRequest,
50+
ListIndexesResponse,
4551
Order,
4652
RemoveBinaryDataFromDatasetByIDsRequest,
4753
RemoveBoundingBoxFromImageByIDRequest,
@@ -499,6 +505,7 @@ async def tabular_data_by_mql(
499505
use_recent_data: Optional[bool] = None,
500506
tabular_data_source_type: TabularDataSourceType.ValueType = TabularDataSourceType.TABULAR_DATA_SOURCE_TYPE_STANDARD,
501507
pipeline_id: Optional[str] = None,
508+
query_prefix_name: Optional[str] = None,
502509
) -> List[Dict[str, Union[ValueTypes, datetime]]]:
503510
"""Obtain unified tabular data and metadata, queried with MQL.
504511
@@ -525,6 +532,7 @@ async def tabular_data_by_mql(
525532
Defaults to `TABULAR_DATA_SOURCE_TYPE_STANDARD`.
526533
pipeline_id (str): The ID of the data pipeline to query. Defaults to `None`.
527534
Required if `tabular_data_source_type` is `TABULAR_DATA_SOURCE_TYPE_PIPELINE_SINK`.
535+
query_prefix_name (str): Optional field that can be used to specify a saved query to run.
528536
529537
Returns:
530538
List[Dict[str, Union[ValueTypes, datetime]]]: An array of decoded BSON data objects.
@@ -535,7 +543,12 @@ async def tabular_data_by_mql(
535543
data_source = TabularDataSource(type=tabular_data_source_type, pipeline_id=pipeline_id)
536544
if use_recent_data:
537545
data_source.type = TabularDataSourceType.TABULAR_DATA_SOURCE_TYPE_HOT_STORAGE
538-
request = TabularDataByMQLRequest(organization_id=organization_id, mql_binary=binary, data_source=data_source)
546+
request = TabularDataByMQLRequest(
547+
organization_id=organization_id,
548+
mql_binary=binary,
549+
data_source=data_source,
550+
query_prefix_name=query_prefix_name,
551+
)
539552
response: TabularDataByMQLResponse = await self._data_client.TabularDataByMQL(request, metadata=self._metadata)
540553
return [bson.decode(bson_bytes) for bson_bytes in response.raw_data]
541554

@@ -1131,8 +1144,8 @@ async def add_bounding_box_to_image_by_id(
11311144
11321145
Args:
11331146
binary_id (Union[~viam.proto.app.data.BinaryID, str]): The binary data ID or :class:`BinaryID` of the image to add the bounding
1134-
box to. *DEPRECATED:* :class:`BinaryID` *is deprecated and will be removed in a future release. Instead, pass binary data
1135-
IDs as a list of strings.*
1147+
box to. *DEPRECATED:* :class:`BinaryID` *is deprecated and will be removed in a future release. Instead, pass binary data IDs as a
1148+
list of strings.*
11361149
label (str): A label for the bounding box.
11371150
x_min_normalized (float): Min X value of the bounding box normalized from 0 to 1.
11381151
y_min_normalized (float): Min Y value of the bounding box normalized from 0 to 1.
@@ -2061,6 +2074,86 @@ async def _list_data_pipeline_runs(self, id: str, page_size: int, page_token: st
20612074
response: ListDataPipelineRunsResponse = await self._data_pipelines_client.ListDataPipelineRuns(request, metadata=self._metadata)
20622075
return DataClient.DataPipelineRunsPage.from_proto(response, self, page_size)
20632076

2077+
async def create_index(
2078+
self,
2079+
organization_id: str,
2080+
collection_type: IndexableCollection.ValueType,
2081+
index_spec: Dict[str, Any],
2082+
pipeline_name: Optional[str] = None,
2083+
) -> None:
2084+
"""Starts a custom index build.
2085+
2086+
Args:
2087+
organization_id (str): The ID of the organization that owns the data.
2088+
To find your organization ID, visit the organization settings page.
2089+
collection_type (IndexableCollection.ValueType): The type of collection the index is on.
2090+
index_spec (List[Dict[str, Any]]): The MongoDB index specification defined in JSON format.
2091+
pipeline_name (Optional[str]): The name of the pipeline if the collection type is PIPELINE_SINK.
2092+
2093+
For more information, see `Data Client API <https://docs.viam.com/dev/reference/apis/data-client/#createindex>`_.
2094+
"""
2095+
index_spec_bytes = [bson.encode(index_spec)]
2096+
request = CreateIndexRequest(
2097+
organization_id=organization_id,
2098+
collection_type=collection_type,
2099+
index_spec=index_spec_bytes,
2100+
pipeline_name=pipeline_name,
2101+
)
2102+
await self._data_client.CreateIndex(request, metadata=self._metadata)
2103+
2104+
async def list_indexes(
2105+
self,
2106+
organization_id: str,
2107+
collection_type: IndexableCollection.ValueType,
2108+
pipeline_name: Optional[str] = None,
2109+
) -> Sequence[Index]:
2110+
"""Returns all the indexes for a given collection.
2111+
2112+
Args:
2113+
organization_id (str): The ID of the organization that owns the data.
2114+
To find your organization ID, visit the organization settings page.
2115+
collection_type (IndexableCollection.ValueType): The type of collection the index is on.
2116+
pipeline_name (Optional[str]): The name of the pipeline if the collection type is PIPELINE_SINK.
2117+
2118+
Returns:
2119+
List[Index]: A list of indexes.
2120+
2121+
For more information, see `Data Client API <https://docs.viam.com/dev/reference/apis/data-client/#listindexes>`_.
2122+
"""
2123+
request = ListIndexesRequest(
2124+
organization_id=organization_id,
2125+
collection_type=collection_type,
2126+
pipeline_name=pipeline_name,
2127+
)
2128+
response: ListIndexesResponse = await self._data_client.ListIndexes(request, metadata=self._metadata)
2129+
return response.indexes
2130+
2131+
async def delete_index(
2132+
self,
2133+
organization_id: str,
2134+
collection_type: IndexableCollection.ValueType,
2135+
index_name: str,
2136+
pipeline_name: Optional[str] = None,
2137+
) -> None:
2138+
"""Drops the specified custom index from a collection.
2139+
2140+
Args:
2141+
organization_id (str): The ID of the organization that owns the data.
2142+
To find your organization ID, visit the organization settings page.
2143+
collection_type (IndexableCollection.ValueType): The type of collection the index is on.
2144+
index_name (str): The name of the index to delete.
2145+
pipeline_name (Optional[str]): The name of the pipeline if the collection type is PIPELINE_SINK.
2146+
2147+
For more information, see `Data Client API <https://docs.viam.com/dev/reference/apis/data-client/#deleteindex>`_.
2148+
"""
2149+
request = DeleteIndexRequest(
2150+
organization_id=organization_id,
2151+
collection_type=collection_type,
2152+
index_name=index_name,
2153+
pipeline_name=pipeline_name,
2154+
)
2155+
await self._data_client.DeleteIndex(request, metadata=self._metadata)
2156+
20642157
@staticmethod
20652158
def create_filter(
20662159
component_name: Optional[str] = None,

src/viam/services/worldstatestore/worldstatestore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,6 @@ async def stream_transform_changes(
8585
print(f"Transform {change.transform.uuid} {change.change_type}")
8686
8787
Returns:
88-
AsyncIterator[StreamTransformChangesResponse]: A stream of transform changes
88+
AsyncGenerator[StreamTransformChangesResponse, None]: A stream of transform changes
8989
"""
9090
...

tests/mocks/services.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any, AsyncIterator, Dict, List, Mapping, Optional, Sequence, Union
2+
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
33

44
import bson
55
import numpy as np
@@ -213,10 +213,14 @@
213213
BoundingBoxLabelsByFilterResponse,
214214
ConfigureDatabaseUserRequest,
215215
ConfigureDatabaseUserResponse,
216+
CreateIndexRequest,
217+
CreateIndexResponse,
216218
DeleteBinaryDataByFilterRequest,
217219
DeleteBinaryDataByFilterResponse,
218220
DeleteBinaryDataByIDsRequest,
219221
DeleteBinaryDataByIDsResponse,
222+
DeleteIndexRequest,
223+
DeleteIndexResponse,
220224
DeleteTabularDataRequest,
221225
DeleteTabularDataResponse,
222226
ExportTabularDataRequest,
@@ -225,6 +229,9 @@
225229
GetDatabaseConnectionResponse,
226230
GetLatestTabularDataRequest,
227231
GetLatestTabularDataResponse,
232+
Index,
233+
ListIndexesRequest,
234+
ListIndexesResponse,
228235
RemoveBinaryDataFromDatasetByIDsRequest,
229236
RemoveBinaryDataFromDatasetByIDsResponse,
230237
RemoveBoundingBoxFromImageByIDRequest,
@@ -379,7 +386,7 @@
379386
class MockVision(Vision):
380387
def __init__(
381388
self,
382-
name: str,
389+
name,
383390
detectors: List[str],
384391
detections: List[Detection],
385392
classifiers: List[str],
@@ -845,6 +852,7 @@ def __init__(
845852
bbox_labels_response: List[str],
846853
hostname_response: str,
847854
additional_params: Mapping[str, ValueTypes],
855+
list_indexes_response: List[Index],
848856
):
849857
self.tabular_response = tabular_response
850858
self.tabular_export_response = tabular_export_response
@@ -857,6 +865,7 @@ def __init__(
857865
self.was_tabular_data_requested = False
858866
self.was_binary_data_requested = False
859867
self.additional_params = additional_params
868+
self.list_indexes_response = list_indexes_response
860869

861870
async def TabularDataByFilter(self, stream: Stream[TabularDataByFilterRequest, TabularDataByFilterResponse]) -> None:
862871
request = await stream.recv_message()
@@ -1043,6 +1052,7 @@ async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, Tabular
10431052
async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, TabularDataByMQLResponse]) -> None:
10441053
request = await stream.recv_message()
10451054
assert request is not None
1055+
self.query_prefix_name = request.query_prefix_name if request.HasField("query_prefix_name") else None
10461056
await stream.send_message(TabularDataByMQLResponse(raw_data=[bson.encode(dict) for dict in self.tabular_query_response]))
10471057

10481058
async def GetLatestTabularData(self, stream: Stream[GetLatestTabularDataRequest, GetLatestTabularDataResponse]) -> None:
@@ -1068,6 +1078,24 @@ async def ExportTabularData(self, stream: Stream[ExportTabularDataRequest, Expor
10681078
for tabular_data in self.tabular_export_response:
10691079
await stream.send_message(tabular_data)
10701080

1081+
async def CreateIndex(self, stream: Stream[CreateIndexRequest, CreateIndexResponse]) -> None:
1082+
request = await stream.recv_message()
1083+
assert request is not None
1084+
self.create_index_request = request
1085+
await stream.send_message(CreateIndexResponse())
1086+
1087+
async def ListIndexes(self, stream: Stream[ListIndexesRequest, ListIndexesResponse]) -> None:
1088+
request = await stream.recv_message()
1089+
assert request is not None
1090+
self.list_indexes_request = request
1091+
await stream.send_message(ListIndexesResponse(indexes=self.list_indexes_response))
1092+
1093+
async def DeleteIndex(self, stream: Stream[DeleteIndexRequest, DeleteIndexResponse]) -> None:
1094+
request = await stream.recv_message()
1095+
assert request is not None
1096+
self.delete_index_request = request
1097+
await stream.send_message(DeleteIndexResponse())
1098+
10711099

10721100
class MockDataset(DatasetServiceBase):
10731101
def __init__(self, create_response: str, datasets_response: Sequence[Dataset], merged_response: Optional[str] = None):
@@ -1931,7 +1959,7 @@ async def stream_transform_changes(
19311959
*,
19321960
extra: Optional[Mapping[str, Any]] = None,
19331961
timeout: Optional[float] = None,
1934-
) -> AsyncIterator["StreamTransformChangesResponse"]:
1962+
) -> AsyncGenerator[StreamTransformChangesResponse, None]:
19351963
self.extra = extra
19361964
self.timeout = timeout
19371965

tests/test_data_client.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
from typing import List
33

4+
import bson
45
import pytest
56
from google.protobuf.timestamp_pb2 import Timestamp
67
from grpclib.testing import ChannelFor
@@ -16,6 +17,9 @@
1617
CaptureMetadata,
1718
ExportTabularDataResponse,
1819
Filter,
20+
Index,
21+
IndexableCollection,
22+
IndexCreator,
1923
Order,
2024
)
2125
from viam.utils import create_filter, dict_to_struct, struct_to_dict
@@ -142,6 +146,23 @@
142146
TAGS_RESPONSE = ["tag"]
143147
HOSTNAME_RESPONSE = "host"
144148

149+
INDEX_NAME = "my_index"
150+
INDEX_SPEC = {"key": 1}
151+
INDEX_SPEC_BYTES = [bson.encode(INDEX_SPEC)]
152+
COLLECTION_TYPE = IndexableCollection.INDEXABLE_COLLECTION_PIPELINE_SINK
153+
PIPELINE_NAME = "my_pipeline"
154+
INDEX_CREATED_BY = IndexCreator.INDEX_CREATOR_CUSTOMER
155+
LIST_INDEXES_RESPONSE = [
156+
Index(
157+
collection_type=COLLECTION_TYPE,
158+
pipeline_name=PIPELINE_NAME,
159+
index_name=INDEX_NAME,
160+
index_spec=INDEX_SPEC_BYTES,
161+
created_by=INDEX_CREATED_BY,
162+
)
163+
]
164+
QUERY_PREFIX_NAME = "my_saved_query"
165+
145166
AUTH_TOKEN = "auth_token"
146167
DATA_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"}
147168

@@ -158,6 +179,7 @@ def service() -> MockData:
158179
bbox_labels_response=BBOX_LABELS,
159180
hostname_response=HOSTNAME_RESPONSE,
160181
additional_params=ADDITIONAL_PARAMS,
182+
list_indexes_response=LIST_INDEXES_RESPONSE,
161183
)
162184

163185

@@ -195,9 +217,10 @@ async def test_tabular_data_by_sql(self, service: MockData):
195217
async def test_tabular_data_by_mql(self, service: MockData):
196218
async with ChannelFor([service]) as channel:
197219
client = DataClient(channel, DATA_SERVICE_METADATA)
198-
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY)
220+
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY, query_prefix_name=QUERY_PREFIX_NAME)
199221
assert isinstance(response[0]["key1"], datetime)
200222
assert response == TABULAR_QUERY_RESPONSE
223+
assert service.query_prefix_name == QUERY_PREFIX_NAME
201224
response = await client.tabular_data_by_mql(ORG_ID, mql_binary=[b"mql_binary"])
202225
assert isinstance(response[0]["key1"], datetime)
203226
assert response == TABULAR_QUERY_RESPONSE
@@ -429,6 +452,55 @@ async def test_remove_binary_data_to_dataset_by_ids(self, service: MockData):
429452
assert service.removed_binary_data_ids == BINARY_DATA_IDS
430453
assert service.dataset_id == DATASET_ID
431454

455+
async def test_create_index(self, service: MockData):
456+
async with ChannelFor([service]) as channel:
457+
client = DataClient(channel, DATA_SERVICE_METADATA)
458+
await client.create_index(
459+
organization_id=ORG_ID,
460+
collection_type=COLLECTION_TYPE,
461+
index_spec=INDEX_SPEC,
462+
pipeline_name=PIPELINE_NAME,
463+
)
464+
assert service.create_index_request.organization_id == ORG_ID
465+
assert service.create_index_request.collection_type == COLLECTION_TYPE
466+
assert service.create_index_request.index_spec == INDEX_SPEC_BYTES
467+
assert service.create_index_request.pipeline_name == PIPELINE_NAME
468+
469+
async def test_list_indexes(self, service: MockData):
470+
async with ChannelFor([service]) as channel:
471+
client = DataClient(channel, DATA_SERVICE_METADATA)
472+
indexes = await client.list_indexes(
473+
organization_id=ORG_ID,
474+
collection_type=COLLECTION_TYPE,
475+
pipeline_name=PIPELINE_NAME,
476+
)
477+
assert service.list_indexes_request.organization_id == ORG_ID
478+
assert service.list_indexes_request.collection_type == COLLECTION_TYPE
479+
assert service.list_indexes_request.pipeline_name == PIPELINE_NAME
480+
481+
assert len(indexes) == len(LIST_INDEXES_RESPONSE)
482+
for i, index in enumerate(indexes):
483+
expected_index = LIST_INDEXES_RESPONSE[i]
484+
assert index.collection_type == expected_index.collection_type
485+
assert index.pipeline_name == expected_index.pipeline_name
486+
assert index.index_name == expected_index.index_name
487+
assert index.index_spec == expected_index.index_spec
488+
assert index.created_by == expected_index.created_by
489+
490+
async def test_delete_index(self, service: MockData):
491+
async with ChannelFor([service]) as channel:
492+
client = DataClient(channel, DATA_SERVICE_METADATA)
493+
await client.delete_index(
494+
organization_id=ORG_ID,
495+
collection_type=COLLECTION_TYPE,
496+
index_name=INDEX_NAME,
497+
pipeline_name=PIPELINE_NAME,
498+
)
499+
assert service.delete_index_request.organization_id == ORG_ID
500+
assert service.delete_index_request.collection_type == COLLECTION_TYPE
501+
assert service.delete_index_request.index_name == INDEX_NAME
502+
assert service.delete_index_request.pipeline_name == PIPELINE_NAME
503+
432504
def assert_filter(self, filter: Filter) -> None:
433505
assert filter.component_name == COMPONENT_NAME
434506
assert filter.component_type == COMPONENT_TYPE

0 commit comments

Comments
 (0)