Skip to content

Commit 1e31db8

Browse files
authored
RSDK-7192 - Provisioning wrappers (#577)
1 parent 313ca7d commit 1e31db8

File tree

4 files changed

+246
-0
lines changed

4 files changed

+246
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Mapping, List, Optional
2+
3+
from grpclib.client import Channel
4+
5+
from viam import logging
6+
from viam.proto.provisioning import (
7+
CloudConfig,
8+
GetNetworkListRequest,
9+
GetNetworkListResponse,
10+
GetSmartMachineStatusRequest,
11+
GetSmartMachineStatusResponse,
12+
NetworkInfo,
13+
ProvisioningServiceStub,
14+
SetNetworkCredentialsRequest,
15+
SetSmartMachineCredentialsRequest,
16+
)
17+
18+
LOGGER = logging.getLogger(__name__)
19+
20+
21+
class ProvisioningClient:
22+
"""gRPC client for getting and setting smart machine info.
23+
24+
Constructor is used by `ViamClient` to instantiate relevant service stubs. Calls to
25+
`ProvisioningClient` methods should be made through `ViamClient`.
26+
27+
Establish a connection::
28+
29+
import asyncio
30+
31+
from viam.rpc.dial import DialOptions, Credentials
32+
from viam.app.viam_client import ViamClient
33+
34+
35+
async def connect() -> ViamClient:
36+
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
37+
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
38+
return await ViamClient.create_from_dial_options(dial_options)
39+
40+
41+
async def main():
42+
43+
# Make a ViamClient
44+
viam_client = await connect()
45+
# Instantiate a ProvisioningClient to run provisioning client API methods on
46+
provisioning_client = viam_client.provisioning_client
47+
48+
viam_client.close()
49+
50+
if __name__ == '__main__':
51+
asyncio.run(main())
52+
53+
"""
54+
55+
def __init__(self, channel: Channel, metadata: Mapping[str, str]):
56+
"""Create a `ProvisioningClient` that maintains a connection to app.
57+
58+
Args:
59+
channel (grpclib.client.Channel): Connection to app.
60+
metadata (Mapping[str, str]): Required authorization token to send requests to app.
61+
"""
62+
self._metadata = metadata
63+
self._provisioning_client = ProvisioningServiceStub(channel)
64+
self._channel = channel
65+
66+
_provisioning_client: ProvisioningServiceStub
67+
_metadata: Mapping[str, str]
68+
_channel: Channel
69+
70+
async def get_network_list(self) -> List[NetworkInfo]:
71+
"""Returns list of networks that are visible to the Smart Machine."""
72+
request = GetNetworkListRequest()
73+
resp: GetNetworkListResponse = await self._provisioning_client.GetNetworkList(request, metadata=self._metadata)
74+
return list(resp.networks)
75+
76+
async def get_smart_machine_status(self) -> GetSmartMachineStatusResponse:
77+
"""Returns the status of the smart machine."""
78+
request = GetSmartMachineStatusRequest()
79+
return await self._provisioning_client.GetSmartMachineStatus(request, metadata=self._metadata)
80+
81+
async def set_network_credentials(self, network_type: str, ssid: str, psk: str) -> None:
82+
"""Sets the network credentials of the Smart Machine.
83+
84+
Args:
85+
network_type (str): The type of the network.
86+
ssid (str): The SSID of the network.
87+
psk (str): The network's passkey.
88+
"""
89+
90+
request = SetNetworkCredentialsRequest(type=network_type, ssid=ssid, psk=psk)
91+
await self._provisioning_client.SetNetworkCredentials(request, metadata=self._metadata)
92+
93+
async def set_smart_machine_credentials(self, cloud_config: Optional[CloudConfig] = None) -> None:
94+
request = SetSmartMachineCredentialsRequest(cloud=cloud_config)
95+
await self._provisioning_client.SetSmartMachineCredentials(request, metadata=self._metadata)

src/viam/app/viam_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from viam.app.billing_client import BillingClient
99
from viam.app.data_client import DataClient
1010
from viam.app.ml_training_client import MLTrainingClient
11+
from viam.app.provisioning_client import ProvisioningClient
1112
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token
1213

1314
LOGGER = logging.getLogger(__name__)
@@ -149,6 +150,27 @@ async def main():
149150

150151
return BillingClient(self._channel, self._metadata)
151152

153+
@property
154+
def provisioning_client(self) -> ProvisioningClient:
155+
"""Instantiate and return a `ProvisioningClient` used to make `provisioning` method calls.
156+
To use the `ProvisioningClient`, you must first instantiate a `ViamClient`.
157+
158+
::
159+
160+
async def connect() -> ViamClient:
161+
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
162+
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
163+
return await ViamClient.create_from_dial_options(dial_options)
164+
165+
166+
async def main():
167+
viam_client = await connect()
168+
169+
# Instantiate a ProvisioningClient to run provisioning API methods on
170+
provisioning_client = viam_client.provisioning_client
171+
"""
172+
return ProvisioningClient(self._channel, self._metadata)
173+
152174
def close(self):
153175
"""Close opened channels used for the various service stubs initialized."""
154176
if self._closed:

tests/mocks/services.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,18 @@
279279
FlatTensorDataUInt64,
280280
FlatTensors,
281281
)
282+
from viam.proto.provisioning import (
283+
NetworkInfo,
284+
ProvisioningServiceBase,
285+
GetNetworkListRequest,
286+
GetNetworkListResponse,
287+
GetSmartMachineStatusRequest,
288+
GetSmartMachineStatusResponse,
289+
SetNetworkCredentialsRequest,
290+
SetNetworkCredentialsResponse,
291+
SetSmartMachineCredentialsRequest,
292+
SetSmartMachineCredentialsResponse,
293+
)
282294
from viam.proto.service.motion import (
283295
Constraints,
284296
GetPlanRequest,
@@ -698,6 +710,43 @@ async def do_command(self, command: Mapping[str, ValueTypes], *, timeout: Option
698710
return {"command": command}
699711

700712

713+
class MockProvisioning(ProvisioningServiceBase):
714+
def __init__(
715+
self,
716+
smart_machine_status: GetSmartMachineStatusResponse,
717+
network_info: List[NetworkInfo],
718+
):
719+
self.smart_machine_status = smart_machine_status
720+
self.network_info = network_info
721+
722+
async def GetNetworkList(self, stream: Stream[GetNetworkListRequest, GetNetworkListResponse]) -> None:
723+
request = await stream.recv_message()
724+
assert request is not None
725+
await stream.send_message(GetNetworkListResponse(networks=self.network_info))
726+
727+
async def GetSmartMachineStatus(self, stream: Stream[GetSmartMachineStatusRequest, GetSmartMachineStatusResponse]) -> None:
728+
request = await stream.recv_message()
729+
assert request is not None
730+
await stream.send_message(self.smart_machine_status)
731+
732+
async def SetNetworkCredentials(self, stream: Stream[SetNetworkCredentialsRequest, SetNetworkCredentialsResponse]) -> None:
733+
request = await stream.recv_message()
734+
assert request is not None
735+
self.network_type = request.type
736+
self.ssid = request.ssid
737+
self.psk = request.psk
738+
await stream.send_message(SetNetworkCredentialsResponse())
739+
740+
async def SetSmartMachineCredentials(
741+
self,
742+
stream: Stream[SetSmartMachineCredentialsRequest, SetSmartMachineCredentialsResponse],
743+
) -> None:
744+
request = await stream.recv_message()
745+
assert request is not None
746+
self.cloud_config = request.cloud
747+
await stream.send_message(SetSmartMachineCredentialsResponse())
748+
749+
701750
class MockData(DataServiceBase):
702751
def __init__(
703752
self,

tests/test_provisioning_client.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pytest
2+
3+
from grpclib.testing import ChannelFor
4+
5+
from viam.app.provisioning_client import ProvisioningClient
6+
7+
from viam.proto.provisioning import GetSmartMachineStatusResponse, NetworkInfo, ProvisioningInfo, CloudConfig
8+
9+
from .mocks.services import MockProvisioning
10+
11+
ID = "id"
12+
MODEL = "model"
13+
MANUFACTURER = "acme"
14+
PROVISIONING_INFO = ProvisioningInfo(fragment_id=ID, model=MODEL, manufacturer=MANUFACTURER)
15+
HAS_CREDENTIALS = True
16+
IS_ONLINE = True
17+
NETWORK_TYPE = "type"
18+
SSID = "ssid"
19+
ERROR = "error"
20+
ERRORS = [ERROR]
21+
PSK = "psk"
22+
SECRET = "secret"
23+
APP_ADDRESS = "address"
24+
NETWORK_INFO_LATEST = NetworkInfo(
25+
type=NETWORK_TYPE,
26+
ssid=SSID,
27+
security="security",
28+
signal=12,
29+
connected=IS_ONLINE,
30+
last_error=ERROR,
31+
)
32+
NETWORK_INFO = [NETWORK_INFO_LATEST]
33+
SMART_MACHINE_STATUS_RESPONSE = GetSmartMachineStatusResponse(
34+
provisioning_info=PROVISIONING_INFO,
35+
has_smart_machine_credentials=HAS_CREDENTIALS,
36+
is_online=IS_ONLINE,
37+
latest_connection_attempt=NETWORK_INFO_LATEST,
38+
errors=ERRORS
39+
)
40+
CLOUD_CONFIG = CloudConfig(id=ID, secret=SECRET, app_address=APP_ADDRESS)
41+
42+
AUTH_TOKEN = "auth_token"
43+
PROVISIONING_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"}
44+
45+
46+
@pytest.fixture(scope="function")
47+
def service() -> MockProvisioning:
48+
return MockProvisioning(smart_machine_status=SMART_MACHINE_STATUS_RESPONSE, network_info=NETWORK_INFO)
49+
50+
51+
class TestClient:
52+
@pytest.mark.asyncio
53+
async def test_get_network_list(self, service: MockProvisioning):
54+
async with ChannelFor([service]) as channel:
55+
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
56+
network_info = await client.get_network_list()
57+
assert network_info == NETWORK_INFO
58+
59+
@pytest.mark.asyncio
60+
async def test_get_smart_machine_status(self, service: MockProvisioning):
61+
async with ChannelFor([service]) as channel:
62+
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
63+
smart_machine_status = await client.get_smart_machine_status()
64+
assert smart_machine_status == SMART_MACHINE_STATUS_RESPONSE
65+
66+
@pytest.mark.asyncio
67+
async def test_set_network_credentials(self, service: MockProvisioning):
68+
async with ChannelFor([service]) as channel:
69+
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
70+
await client.set_network_credentials(network_type=NETWORK_TYPE, ssid=SSID, psk=PSK)
71+
assert service.network_type == NETWORK_TYPE
72+
assert service.ssid == SSID
73+
assert service.psk == PSK
74+
75+
@pytest.mark.asyncio
76+
async def test_set_smart_machine_credentials(self, service: MockProvisioning):
77+
async with ChannelFor([service]) as channel:
78+
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
79+
await client.set_smart_machine_credentials(cloud_config=CLOUD_CONFIG)
80+
assert service.cloud_config == CLOUD_CONFIG

0 commit comments

Comments
 (0)