Skip to content

Commit fd3e234

Browse files
authored
One Runtime to Rule them All (#157)
1 parent 9e1fae2 commit fd3e234

File tree

4 files changed

+117
-30
lines changed

4 files changed

+117
-30
lines changed

src/viam/robot/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def at_address(cls, address: str, options: Options) -> Self:
9090
Returns:
9191
Self: the RobotClient
9292
"""
93+
logging.setLevel(options.log_level)
9394
channel = await dial(address, options.dial_options)
9495
robot = await RobotClient.with_channel(channel, options)
9596
robot._should_close_channel = True

src/viam/rpc/dial.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import socket
55
import ssl
66
import sys
7+
from typing_extensions import Self
78
import warnings
89
from dataclasses import dataclass
9-
from typing import Callable, Literal, Optional, Tuple, Type
10+
from typing import Callable, ClassVar, Literal, Optional, Tuple, Type
1011

1112
from grpclib.client import Channel, Stream
1213
from grpclib.const import Cardinality
@@ -17,6 +18,7 @@
1718
from viam.errors import InsecureConnectionError, ViamError
1819
from viam.proto.rpc.auth import AuthenticateRequest, AuthServiceStub
1920
from viam.proto.rpc.auth import Credentials as PBCredentials
21+
from viam.utils import PointerCounter
2022

2123
LOGGER = logging.getLogger(__name__)
2224

@@ -151,51 +153,83 @@ async def __aexit__(self, exc_type, exc_value, traceback):
151153
self.close()
152154

153155

154-
async def dial(address: str, options: Optional[DialOptions] = None) -> ViamChannel:
155-
opts = options if options else DialOptions()
156-
if opts.disable_webrtc:
157-
channel = await _dial_direct(address, options)
158-
return ViamChannel(channel, lambda: None)
156+
class _Runtime:
159157

160-
creds = opts.credentials.payload if opts.credentials else ""
161-
insecure = opts.insecure or opts.allow_insecure_with_creds_downgrade or (not creds and opts.allow_insecure_downgrade)
158+
_shared: ClassVar[Self]
162159

163-
libname = pathlib.Path(__file__).parent.absolute() / f"libviam.{'dylib' if sys.platform == 'darwin' else 'so'}"
164-
c_lib = ctypes.CDLL(libname.__str__())
165-
c_lib.init_rust_runtime.argtypes = ()
166-
c_lib.init_rust_runtime.restype = ctypes.c_void_p
160+
_lib: ctypes.CDLL
161+
_ptr: ctypes.c_void_p
162+
_semaphore: PointerCounter = PointerCounter()
167163

168-
c_lib.dial.argtypes = (ctypes.c_char_p, ctypes.c_char_p, ctypes.c_bool, ctypes.c_void_p)
169-
c_lib.dial.restype = ctypes.c_void_p
164+
def __new__(cls):
165+
if not hasattr(cls, "_shared"):
166+
cls._shared = super(_Runtime, cls).__new__(cls)
170167

171-
c_lib.free_rust_runtime.argtypes = (ctypes.c_void_p,)
172-
c_lib.free_rust_runtime.restype = None
168+
libname = pathlib.Path(__file__).parent.absolute() / f"libviam.{'dylib' if sys.platform == 'darwin' else 'so'}"
169+
cls._shared._lib = ctypes.CDLL(libname.__str__())
170+
cls._shared._lib.init_rust_runtime.argtypes = ()
171+
cls._shared._lib.init_rust_runtime.restype = ctypes.c_void_p
173172

174-
c_lib.free_string.argtypes = (ctypes.c_void_p,)
175-
c_lib.free_string.restype = None
173+
cls._shared._lib.dial.argtypes = (ctypes.c_char_p, ctypes.c_char_p, ctypes.c_bool, ctypes.c_void_p)
174+
cls._shared._lib.dial.restype = ctypes.c_void_p
176175

177-
ptr = c_lib.init_rust_runtime()
176+
cls._shared._lib.free_rust_runtime.argtypes = (ctypes.c_void_p,)
177+
cls._shared._lib.free_rust_runtime.restype = None
178178

179-
path_ptr = c_lib.dial(
180-
address.encode("utf-8"),
181-
creds.encode("utf-8") if creds else None,
182-
insecure,
183-
ptr,
184-
)
185-
path = ctypes.cast(path_ptr, ctypes.c_char_p).value
179+
cls._shared._lib.free_string.argtypes = (ctypes.c_void_p,)
180+
cls._shared._lib.free_string.restype = None
181+
182+
cls._shared._ptr = cls._shared._lib.init_rust_runtime()
183+
184+
return cls._shared
185+
186+
def dial(self, address: str, options: DialOptions) -> Tuple[Optional[str], ctypes.c_void_p]:
187+
creds = options.credentials.payload if options.credentials else ""
188+
insecure = options.insecure or options.allow_insecure_with_creds_downgrade or (not creds and options.allow_insecure_downgrade)
189+
190+
path_ptr = self._lib.dial(
191+
address.encode("utf-8"),
192+
creds.encode("utf-8") if creds else None,
193+
insecure,
194+
self._ptr,
195+
)
196+
path = ctypes.cast(path_ptr, ctypes.c_char_p).value
197+
path = path.decode("utf-8") if path else ""
198+
199+
self._semaphore.increment()
200+
201+
return (path, path_ptr)
202+
203+
def release(self):
204+
self._semaphore.decrement()
205+
if self._semaphore.count == 0:
206+
self._lib.free_rust_runtime(self._ptr)
207+
208+
def free_str(self, ptr: ctypes.c_void_p):
209+
self._lib.free_string(ptr)
210+
211+
212+
async def dial(address: str, options: Optional[DialOptions] = None) -> ViamChannel:
213+
opts = options if options else DialOptions()
214+
if opts.disable_webrtc:
215+
channel = await _dial_direct(address, options)
216+
return ViamChannel(channel, lambda: None)
217+
218+
runtime = _Runtime()
186219

220+
path, path_ptr = runtime.dial(address, opts)
187221
if path:
188222
LOGGER.info(f"Connecting to socket: {path}")
189-
chan = Channel(path=path.decode("utf-8"), ssl=None)
223+
chan = Channel(path=path, ssl=None)
190224

191225
def release():
192-
c_lib.free_string(path_ptr)
193-
c_lib.free_rust_runtime(ptr)
226+
runtime.free_str(path_ptr)
227+
runtime.release()
194228

195229
channel = ViamChannel(chan, release)
196230
return channel
197231

198-
c_lib.free_rust_runtime(ptr)
232+
runtime.release()
199233
raise ViamError(f"Unable to establish a connection to {address}")
200234

201235

src/viam/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from asyncio import Event
12
from typing import Any, Dict, List, Mapping, SupportsFloat, Type, TypeVar
23

34
from google.protobuf.json_format import MessageToDict, ParseDict
@@ -157,3 +158,29 @@ def sensor_readings_value_to_native(readings: Mapping[str, Value]) -> Mapping[st
157158
elif kind == "orientation_vector_degrees":
158159
prim_readings[key] = Orientation(o_x=reading["ox"], o_y=reading["oy"], o_z=reading["oz"], theta=reading["theta"])
159160
return prim_readings
161+
162+
163+
class PointerCounter:
164+
def __init__(self) -> None:
165+
self._event = Event()
166+
self._count = 0
167+
self._event.set()
168+
169+
def increment(self) -> int:
170+
self._count += 1
171+
self._event.clear()
172+
return self._count
173+
174+
def decrement(self) -> int:
175+
assert self._count > 0, "Pointer count cannot go below zero"
176+
self._count -= 1
177+
if self._count == 0:
178+
self._event.set()
179+
return self._count
180+
181+
async def wait(self) -> None:
182+
await self._event.wait()
183+
184+
@property
185+
def count(self) -> int:
186+
return self._count

tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import pytest
23
from google.protobuf.json_format import ParseError
34
from google.protobuf.struct_pb2 import ListValue, Struct, Value
@@ -10,6 +11,7 @@
1011
Vector3,
1112
)
1213
from viam.utils import (
14+
PointerCounter,
1315
dict_to_struct,
1416
message_to_struct,
1517
primitive_to_value,
@@ -221,3 +223,26 @@ def test_sensor_readings():
221223
response = sensor_readings_value_to_native(test)
222224

223225
assert response == expected
226+
227+
228+
@pytest.mark.asyncio
229+
async def test_pointer_counter():
230+
counter = PointerCounter()
231+
232+
assert counter.count == 0
233+
234+
counter.increment()
235+
assert counter.count == 1
236+
237+
async def final_test():
238+
assert counter.count > 0
239+
await counter.wait()
240+
assert counter.count == 0
241+
242+
task = asyncio.get_running_loop().create_task(final_test())
243+
244+
await asyncio.sleep(0) # Needed to start the final_test task
245+
246+
counter.decrement()
247+
248+
await task

0 commit comments

Comments
 (0)