|
4 | 4 | import socket |
5 | 5 | import ssl |
6 | 6 | import sys |
| 7 | +from typing_extensions import Self |
7 | 8 | import warnings |
8 | 9 | from dataclasses import dataclass |
9 | | -from typing import Callable, Literal, Optional, Tuple, Type |
| 10 | +from typing import Callable, ClassVar, Literal, Optional, Tuple, Type |
10 | 11 |
|
11 | 12 | from grpclib.client import Channel, Stream |
12 | 13 | from grpclib.const import Cardinality |
|
17 | 18 | from viam.errors import InsecureConnectionError, ViamError |
18 | 19 | from viam.proto.rpc.auth import AuthenticateRequest, AuthServiceStub |
19 | 20 | from viam.proto.rpc.auth import Credentials as PBCredentials |
| 21 | +from viam.utils import PointerCounter |
20 | 22 |
|
21 | 23 | LOGGER = logging.getLogger(__name__) |
22 | 24 |
|
@@ -151,51 +153,83 @@ async def __aexit__(self, exc_type, exc_value, traceback): |
151 | 153 | self.close() |
152 | 154 |
|
153 | 155 |
|
154 | | -async def dial(address: str, options: Optional[DialOptions] = None) -> ViamChannel: |
155 | | - opts = options if options else DialOptions() |
156 | | - if opts.disable_webrtc: |
157 | | - channel = await _dial_direct(address, options) |
158 | | - return ViamChannel(channel, lambda: None) |
| 156 | +class _Runtime: |
159 | 157 |
|
160 | | - creds = opts.credentials.payload if opts.credentials else "" |
161 | | - insecure = opts.insecure or opts.allow_insecure_with_creds_downgrade or (not creds and opts.allow_insecure_downgrade) |
| 158 | + _shared: ClassVar[Self] |
162 | 159 |
|
163 | | - libname = pathlib.Path(__file__).parent.absolute() / f"libviam.{'dylib' if sys.platform == 'darwin' else 'so'}" |
164 | | - c_lib = ctypes.CDLL(libname.__str__()) |
165 | | - c_lib.init_rust_runtime.argtypes = () |
166 | | - c_lib.init_rust_runtime.restype = ctypes.c_void_p |
| 160 | + _lib: ctypes.CDLL |
| 161 | + _ptr: ctypes.c_void_p |
| 162 | + _semaphore: PointerCounter = PointerCounter() |
167 | 163 |
|
168 | | - c_lib.dial.argtypes = (ctypes.c_char_p, ctypes.c_char_p, ctypes.c_bool, ctypes.c_void_p) |
169 | | - c_lib.dial.restype = ctypes.c_void_p |
| 164 | + def __new__(cls): |
| 165 | + if not hasattr(cls, "_shared"): |
| 166 | + cls._shared = super(_Runtime, cls).__new__(cls) |
170 | 167 |
|
171 | | - c_lib.free_rust_runtime.argtypes = (ctypes.c_void_p,) |
172 | | - c_lib.free_rust_runtime.restype = None |
| 168 | + libname = pathlib.Path(__file__).parent.absolute() / f"libviam.{'dylib' if sys.platform == 'darwin' else 'so'}" |
| 169 | + cls._shared._lib = ctypes.CDLL(libname.__str__()) |
| 170 | + cls._shared._lib.init_rust_runtime.argtypes = () |
| 171 | + cls._shared._lib.init_rust_runtime.restype = ctypes.c_void_p |
173 | 172 |
|
174 | | - c_lib.free_string.argtypes = (ctypes.c_void_p,) |
175 | | - c_lib.free_string.restype = None |
| 173 | + cls._shared._lib.dial.argtypes = (ctypes.c_char_p, ctypes.c_char_p, ctypes.c_bool, ctypes.c_void_p) |
| 174 | + cls._shared._lib.dial.restype = ctypes.c_void_p |
176 | 175 |
|
177 | | - ptr = c_lib.init_rust_runtime() |
| 176 | + cls._shared._lib.free_rust_runtime.argtypes = (ctypes.c_void_p,) |
| 177 | + cls._shared._lib.free_rust_runtime.restype = None |
178 | 178 |
|
179 | | - path_ptr = c_lib.dial( |
180 | | - address.encode("utf-8"), |
181 | | - creds.encode("utf-8") if creds else None, |
182 | | - insecure, |
183 | | - ptr, |
184 | | - ) |
185 | | - path = ctypes.cast(path_ptr, ctypes.c_char_p).value |
| 179 | + cls._shared._lib.free_string.argtypes = (ctypes.c_void_p,) |
| 180 | + cls._shared._lib.free_string.restype = None |
| 181 | + |
| 182 | + cls._shared._ptr = cls._shared._lib.init_rust_runtime() |
| 183 | + |
| 184 | + return cls._shared |
| 185 | + |
| 186 | + def dial(self, address: str, options: DialOptions) -> Tuple[Optional[str], ctypes.c_void_p]: |
| 187 | + creds = options.credentials.payload if options.credentials else "" |
| 188 | + insecure = options.insecure or options.allow_insecure_with_creds_downgrade or (not creds and options.allow_insecure_downgrade) |
| 189 | + |
| 190 | + path_ptr = self._lib.dial( |
| 191 | + address.encode("utf-8"), |
| 192 | + creds.encode("utf-8") if creds else None, |
| 193 | + insecure, |
| 194 | + self._ptr, |
| 195 | + ) |
| 196 | + path = ctypes.cast(path_ptr, ctypes.c_char_p).value |
| 197 | + path = path.decode("utf-8") if path else "" |
| 198 | + |
| 199 | + self._semaphore.increment() |
| 200 | + |
| 201 | + return (path, path_ptr) |
| 202 | + |
| 203 | + def release(self): |
| 204 | + self._semaphore.decrement() |
| 205 | + if self._semaphore.count == 0: |
| 206 | + self._lib.free_rust_runtime(self._ptr) |
| 207 | + |
| 208 | + def free_str(self, ptr: ctypes.c_void_p): |
| 209 | + self._lib.free_string(ptr) |
| 210 | + |
| 211 | + |
| 212 | +async def dial(address: str, options: Optional[DialOptions] = None) -> ViamChannel: |
| 213 | + opts = options if options else DialOptions() |
| 214 | + if opts.disable_webrtc: |
| 215 | + channel = await _dial_direct(address, options) |
| 216 | + return ViamChannel(channel, lambda: None) |
| 217 | + |
| 218 | + runtime = _Runtime() |
186 | 219 |
|
| 220 | + path, path_ptr = runtime.dial(address, opts) |
187 | 221 | if path: |
188 | 222 | LOGGER.info(f"Connecting to socket: {path}") |
189 | | - chan = Channel(path=path.decode("utf-8"), ssl=None) |
| 223 | + chan = Channel(path=path, ssl=None) |
190 | 224 |
|
191 | 225 | def release(): |
192 | | - c_lib.free_string(path_ptr) |
193 | | - c_lib.free_rust_runtime(ptr) |
| 226 | + runtime.free_str(path_ptr) |
| 227 | + runtime.release() |
194 | 228 |
|
195 | 229 | channel = ViamChannel(chan, release) |
196 | 230 | return channel |
197 | 231 |
|
198 | | - c_lib.free_rust_runtime(ptr) |
| 232 | + runtime.release() |
199 | 233 | raise ViamError(f"Unable to establish a connection to {address}") |
200 | 234 |
|
201 | 235 |
|
|
0 commit comments