diff --git a/python_otbr_api/__init__.py b/python_otbr_api/__init__.py index b3bbf44..a4c4e1e 100644 --- a/python_otbr_api/__init__.py +++ b/python_otbr_api/__init__.py @@ -4,11 +4,17 @@ import aiohttp +from .models import OperationalDataSet + class OTBRError(Exception): """Raised on error.""" +class ThreadNetworkActiveError(OTBRError): + """Raised on attempts to modify the active dataset when thread network is active.""" + + class OTBR: # pylint: disable=too-few-public-methods """Class to interact with the Open Thread Border Router REST API.""" @@ -20,6 +26,18 @@ def __init__( self._url = url self._timeout = timeout + async def set_enabled(self, enabled: bool) -> None: + """Enable or disable the router.""" + + response = await self._session.post( + f"{self._url}/node/state", + json="enabled" if enabled else "disabled", + timeout=aiohttp.ClientTimeout(total=10), + ) + + if response.status != HTTPStatus.OK: + raise OTBRError(f"unexpected http status {response.status}") + async def get_active_dataset_tlvs(self) -> bytes | None: """Get current active operational dataset in TLVS format, or None. @@ -29,7 +47,6 @@ async def get_active_dataset_tlvs(self) -> bytes | None: response = await self._session.get( f"{self._url}/node/dataset/active", headers={"Accept": "text/plain"}, - raise_for_status=True, timeout=aiohttp.ClientTimeout(total=self._timeout), ) @@ -43,3 +60,40 @@ async def get_active_dataset_tlvs(self) -> bytes | None: return bytes.fromhex(await response.text("ASCII")) except ValueError as exc: raise OTBRError("unexpected API response") from exc + + async def create_active_dataset(self, dataset: OperationalDataSet) -> None: + """Create active operational dataset. + + The passed in OperationalDataSet does not need to be fully populated, any fields + not set will be automatically set by the open thread border router. + Raises if the http status is 400 or higher or if the response is invalid. + """ + + response = await self._session.post( + f"{self._url}/node/dataset/active", + json=dataset.as_json(), + timeout=aiohttp.ClientTimeout(total=self._timeout), + ) + + if response.status == HTTPStatus.CONFLICT: + raise ThreadNetworkActiveError + if response.status != HTTPStatus.ACCEPTED: + raise OTBRError(f"unexpected http status {response.status}") + + async def set_active_dataset_tlvs(self, dataset: bytes) -> None: + """Set current active operational dataset. + + Raises if the http status is 400 or higher or if the response is invalid. + """ + + response = await self._session.put( + f"{self._url}/node/dataset/active", + data=dataset.hex(), + headers={"Content-Type": "text/plain"}, + timeout=aiohttp.ClientTimeout(total=10), + ) + + if response.status == HTTPStatus.CONFLICT: + raise ThreadNetworkActiveError + if response.status != HTTPStatus.ACCEPTED: + raise OTBRError(f"unexpected http status {response.status}") diff --git a/python_otbr_api/models.py b/python_otbr_api/models.py new file mode 100644 index 0000000..8ab7c43 --- /dev/null +++ b/python_otbr_api/models.py @@ -0,0 +1,211 @@ +"""Data models.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import voluptuous as vol # type:ignore[import] + + +@dataclass +class Timestamp: + """Timestamp.""" + + SCHEMA = vol.Schema( + { + vol.Optional("Authoritative"): bool, + vol.Optional("Seconds"): int, + vol.Optional("Ticks"): int, + } + ) + + authoritative: bool | None = None + seconds: int | None = None + ticks: int | None = None + + def as_json(self) -> dict: + """Serialize to JSON.""" + result: dict[str, Any] = {} + if self.authoritative is not None: + result["Authoritative"] = self.authoritative + if self.seconds is not None: + result["Seconds"] = self.seconds + if self.ticks is not None: + result["Ticks"] = self.ticks + return result + + @classmethod + def from_json(cls, json_data: Any) -> Timestamp: + """Deserialize from JSON.""" + cls.SCHEMA(json_data) + return cls( + json_data.get("Authoritative"), + json_data.get("Seconds"), + json_data.get("Ticks"), + ) + + +@dataclass +class SecurityPolicy: # pylint: disable=too-many-instance-attributes + """Security policy.""" + + SCHEMA = vol.Schema( + { + vol.Optional("AutonomousEnrollment"): bool, + vol.Optional("CommercialCommissioning"): bool, + vol.Optional("ExternalCommissioning"): bool, + vol.Optional("NativeCommissioning"): bool, + vol.Optional("NetworkKeyProvisioning"): bool, + vol.Optional("NonCcmRouters"): bool, + vol.Optional("ObtainNetworkKey"): bool, + vol.Optional("RotationTime"): int, + vol.Optional("Routers"): bool, + vol.Optional("TobleLink"): bool, + } + ) + + autonomous_enrollment: bool | None = None + commercial_commissioning: bool | None = None + external_commissioning: bool | None = None + native_commissioning: bool | None = None + network_key_provisioning: bool | None = None + non_ccm_routers: bool | None = None + obtain_network_key: bool | None = None + rotation_time: int | None = None + routers: bool | None = None + to_ble_link: bool | None = None + + def as_json(self) -> dict: + """Serialize to JSON.""" + result: dict[str, Any] = {} + if self.autonomous_enrollment is not None: + result["AutonomousEnrollment"] = self.autonomous_enrollment + if self.commercial_commissioning is not None: + result["CommercialCommissioning"] = self.commercial_commissioning + if self.external_commissioning is not None: + result["ExternalCommissioning"] = self.external_commissioning + if self.native_commissioning is not None: + result["NativeCommissioning"] = self.native_commissioning + if self.network_key_provisioning is not None: + result["NetworkKeyProvisioning"] = self.network_key_provisioning + if self.non_ccm_routers is not None: + result["NonCcmRouters"] = self.non_ccm_routers + if self.obtain_network_key is not None: + result["ObtainNetworkKey"] = self.obtain_network_key + if self.rotation_time is not None: + result["RotationTime"] = self.rotation_time + if self.routers is not None: + result["Routers"] = self.routers + if self.to_ble_link is not None: + result["TobleLink"] = self.to_ble_link + return result + + @classmethod + def from_json(cls, json_data: Any) -> SecurityPolicy: + """Deserialize from JSON.""" + cls.SCHEMA(json_data) + return cls( + json_data.get("AutonomousEnrollment"), + json_data.get("CommercialCommissioning"), + json_data.get("ExternalCommissioning"), + json_data.get("NativeCommissioning"), + json_data.get("NetworkKeyProvisioning"), + json_data.get("NonCcmRouters"), + json_data.get("ObtainNetworkKey"), + json_data.get("RotationTime"), + json_data.get("Routers"), + json_data.get("TobleLink"), + ) + + +@dataclass +class OperationalDataSet: # pylint: disable=too-many-instance-attributes + """Operational dataset.""" + + SCHEMA = vol.Schema( + { + vol.Optional("ActiveTimestamp"): dict, + vol.Optional("ChannelMask"): int, + vol.Optional("Channel"): int, + vol.Optional("Delay"): int, + vol.Optional("ExtPanId"): str, + vol.Optional("MeshLocalPrefix"): str, + vol.Optional("NetworkKey"): str, + vol.Optional("NetworkName"): str, + vol.Optional("PanId"): int, + vol.Optional("PendingTimestamp"): dict, + vol.Optional("PSKc"): str, + vol.Optional("SecurityPolicy"): dict, + } + ) + + active_timestamp: Timestamp | None = None + channel_mask: int | None = None + channel: int | None = None + delay: int | None = None + extended_pan_id: str | None = None + mesh_local_prefix: str | None = None + network_key: str | None = None + network_name: str | None = None + pan_id: int | None = None + pending_timestamp: Timestamp | None = None + psk_c: str | None = None + security_policy: SecurityPolicy | None = None + + def as_json(self) -> dict: + """Serialize to JSON.""" + result: dict[str, Any] = {} + if self.active_timestamp is not None: + result["ActiveTimestamp"] = self.active_timestamp.as_json() + if self.channel_mask is not None: + result["ChannelMask"] = self.channel_mask + if self.channel is not None: + result["Channel"] = self.channel + if self.delay is not None: + result["Delay"] = self.delay + if self.extended_pan_id is not None: + result["ExtPanId"] = self.extended_pan_id + if self.mesh_local_prefix is not None: + result["MeshLocalPrefix"] = self.mesh_local_prefix + if self.network_key is not None: + result["NetworkKey"] = self.network_key + if self.network_name is not None: + result["NetworkName"] = self.network_name + if self.pan_id is not None: + result["PanId"] = self.pan_id + if self.pending_timestamp is not None: + result["PendingTimestamp"] = self.pending_timestamp.as_json() + if self.psk_c is not None: + result["PSKc"] = self.psk_c + if self.security_policy is not None: + result["SecurityPolicy"] = self.security_policy.as_json() + return result + + @classmethod + def from_json(cls, json_data: Any) -> OperationalDataSet: + """Deserialize from JSON.""" + cls.SCHEMA(json_data) + active_timestamp = None + pending_timestamp = None + security_policy = None + if "ActiveTimestamp" in json_data: + active_timestamp = Timestamp.from_json(json_data["ActiveTimestamp"]) + if "PendingTimestamp" in json_data: + pending_timestamp = Timestamp.from_json(json_data["PendingTimestamp"]) + if "SecurityPolicy" in json_data: + security_policy = SecurityPolicy.from_json(json_data["SecurityPolicy"]) + + return OperationalDataSet( + active_timestamp, + json_data.get("ChannelMask"), + json_data.get("Channel"), + json_data.get("Delay"), + json_data.get("ExtPanId"), + json_data.get("MeshLocalPrefix"), + json_data.get("NetworkKey"), + json_data.get("NetworkName"), + json_data.get("PanId"), + pending_timestamp, + json_data.get("PSKc"), + security_policy, + ) diff --git a/requirements.txt b/requirements.txt index ee4ba4f..4f60397 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ aiohttp +voluptuous