From 0ac72b4feb41f82639ed323ded7b158bd7f31e56 Mon Sep 17 00:00:00 2001 From: igorcoding Date: Sat, 30 Dec 2023 21:37:38 +0300 Subject: [PATCH] Exporting features to Connection class --- asynctnt/connection.py | 8 ++++-- asynctnt/iproto/protocol.pxd | 1 + asynctnt/iproto/protocol.pyi | 14 ++++++++++ asynctnt/iproto/protocol.pyx | 10 +++++-- asynctnt/iproto/response.pxd | 14 ++++++++++ asynctnt/iproto/response.pyx | 34 +++++++++++++++++++++++- setup.py | 2 +- temp/demo.py | 14 ++++++++++ tests/test_connect.py | 51 +++++++++++++++++++++++++++++++++++- 9 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 temp/demo.py diff --git a/asynctnt/connection.py b/asynctnt/connection.py index 12e4571..7791be9 100644 --- a/asynctnt/connection.py +++ b/asynctnt/connection.py @@ -444,7 +444,7 @@ async def reconnect(self): await self.disconnect() await self.connect() - async def __aenter__(self): + async def __aenter__(self) -> "Connection": """ Executed on entering the async with section. Connects to Tarantool instance. @@ -606,7 +606,7 @@ def _normalize_api(self): Api.call = Api.call16 Connection.call = Connection.call16 - if self.version < (2, 10): # pragma: nocover + if not self.features.streams: # pragma: nocover def stream_stub(_): raise TarantoolError("streams are available only in Tarantool 2.10+") @@ -627,6 +627,10 @@ def stream(self) -> Stream: stream._set_db(db) return stream + @property + def features(self) -> protocol.IProtoFeatures: + return self._protocol.features + async def connect(**kwargs) -> Connection: """ diff --git a/asynctnt/iproto/protocol.pxd b/asynctnt/iproto/protocol.pxd index ce42e94..8dc1726 100644 --- a/asynctnt/iproto/protocol.pxd +++ b/asynctnt/iproto/protocol.pxd @@ -69,6 +69,7 @@ cdef class BaseProtocol(CoreProtocol): bint _schema_fetch_in_progress object _refetch_schema_future Db _db + IProtoFeatures _features req_execute_func execute object create_future diff --git a/asynctnt/iproto/protocol.pyi b/asynctnt/iproto/protocol.pyi index 10ee16a..f6e3ae1 100644 --- a/asynctnt/iproto/protocol.pyi +++ b/asynctnt/iproto/protocol.pyi @@ -170,9 +170,23 @@ class Protocol: def schema_id(self) -> int: ... @property def schema(self) -> Schema: ... + @property + def features(self) -> IProtoFeatures: ... def create_db(self, gen_stream_id: bool = False) -> Db: ... def get_common_db(self) -> Db: ... def refetch_schema(self) -> asyncio.Future: ... def is_connected(self) -> bool: ... def is_fully_connected(self) -> bool: ... def get_version(self) -> tuple: ... + +class IProtoFeatures: + streams: bool + transactions: bool + error_extension: bool + watchers: bool + pagination: bool + space_and_index_names: bool + watch_once: bool + dml_tuple_extension: bool + call_ret_tuple_extension: bool + call_arg_tuple_extension: bool diff --git a/asynctnt/iproto/protocol.pyx b/asynctnt/iproto/protocol.pyx index 8f479dd..c2a1da1 100644 --- a/asynctnt/iproto/protocol.pyx +++ b/asynctnt/iproto/protocol.pyx @@ -99,6 +99,7 @@ cdef class BaseProtocol(CoreProtocol): self._refetch_schema_future = None self._db = self._create_db( False) self.execute = self._execute_bad + self._features = IProtoFeatures.__new__(IProtoFeatures) try: self.create_future = self.loop.create_future @@ -253,8 +254,9 @@ cdef class BaseProtocol(CoreProtocol): return e = f.exception() if not e: - logger.debug('Tarantool[%s:%s] identified successfully', - self.host, self.port) + self._features = ( f.result()).result_ + logger.debug('Tarantool[%s:%s] iproto features available: %r', + self.host, self.port, self.features) self.post_con_state = POST_CONNECTION_AUTH self._post_con_state_machine() @@ -515,6 +517,10 @@ cdef class BaseProtocol(CoreProtocol): def refetch_schema(self): return self._refetch_schema() + @property + def features(self) -> IProtoFeatures: + return self._features + class Protocol(BaseProtocol, asyncio.Protocol): pass diff --git a/asynctnt/iproto/response.pxd b/asynctnt/iproto/response.pxd index 8bb59eb..130ffc9 100644 --- a/asynctnt/iproto/response.pxd +++ b/asynctnt/iproto/response.pxd @@ -41,6 +41,7 @@ cdef class Response: bint _push_subscribe BaseRequest request_ object _exception + object result_ readonly object _q readonly object _push_event @@ -66,3 +67,16 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, Response resp, BaseRequest req, bint is_chunk) except -1 cdef IProtoError parse_iproto_error(const char ** b, bytes encoding) + +cdef class IProtoFeatures: + cdef: + readonly bint streams + readonly bint transactions + readonly bint error_extension + readonly bint watchers + readonly bint pagination + readonly bint space_and_index_names + readonly bint watch_once + readonly bint dml_tuple_extension + readonly bint call_ret_tuple_extension + readonly bint call_arg_tuple_extension diff --git a/asynctnt/iproto/response.pyx b/asynctnt/iproto/response.pyx index 57ac06b..1b56cc9 100644 --- a/asynctnt/iproto/response.pyx +++ b/asynctnt/iproto/response.pyx @@ -25,6 +25,10 @@ cdef class IProtoErrorStackFrame: cdef class IProtoError: pass +@cython.final +cdef class IProtoFeatures: + pass + @cython.final @cython.freelist(REQUEST_FREELIST) cdef class Response: @@ -41,6 +45,7 @@ cdef class Response: self.errmsg = None self.error = None self._rowcount = 0 + self.result_ = None self.body = None self.encoding = None self.metadata = None @@ -546,6 +551,7 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, const char *s list data Field field + IProtoFeatures features b = buf # mp_fprint(stdio.stdout, b) @@ -635,7 +641,33 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, logger.debug("IProto version: %s", _decode_obj(&b, resp.encoding)) elif key == tarantool.IPROTO_FEATURES: - logger.debug("IProto features available: %s", _decode_obj(&b, resp.encoding)) + features = IProtoFeatures.__new__(IProtoFeatures) + + for item in _decode_obj(&b, resp.encoding): + if item == 0: + features.streams = 1 + elif item == 1: + features.transactions = 1 + elif item == 2: + features.error_extension = 1 + elif item == 3: + features.watchers = 1 + elif item == 4: + features.pagination = 1 + elif item == 5: + features.space_and_index_names = 1 + elif item == 6: + features.watch_once = 1 + elif item == 7: + features.dml_tuple_extension = 1 + elif item == 8: + features.call_ret_tuple_extension = 1 + elif item == 9: + features.call_arg_tuple_extension = 1 + else: + logger.debug("unknown iproto feature available: %d", item) + + resp.result_ = features elif key == tarantool.IPROTO_AUTH_TYPE: logger.debug("IProto auth type: %s", _decode_obj(&b, resp.encoding)) diff --git a/setup.py b/setup.py index 8337d1c..1189531 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def initialize_options(self): self.debug = True self.gdb_debug = True else: - self.cython_always = False + self.cython_always = True self.cython_annotate = None self.cython_directives = None self.gdb_debug = False diff --git a/temp/demo.py b/temp/demo.py new file mode 100644 index 0000000..f06aea4 --- /dev/null +++ b/temp/demo.py @@ -0,0 +1,14 @@ +import asyncio +import logging + +import asynctnt + + +async def main(): + logging.basicConfig(level=logging.DEBUG) + async with asynctnt.Connection(host="127.0.0.1", port=3305) as conn: + print(conn.features) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_connect.py b/tests/test_connect.py index 6bf62b4..80c9a4a 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -10,7 +10,7 @@ ) from asynctnt.instance import TarantoolSyncInstance from tests import BaseTarantoolTestCase -from tests._testbase import check_version +from tests._testbase import check_version, ensure_version class ConnectTestCase(BaseTarantoolTestCase): @@ -802,3 +802,52 @@ async def state_checker(): await conn.call("box.info") finally: await conn.disconnect() + + async def test__features(self): + async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn: + if not check_version( + self, + conn.version, + min=(2, 10), + max=(3, 0), + min_included=True, + max_included=False, + ): + return + + self.assertIsNotNone(conn.features) + self.assertTrue(conn.features.streams) + self.assertTrue(conn.features.watchers) + self.assertTrue(conn.features.error_extension) + self.assertTrue(conn.features.transactions) + self.assertTrue(conn.features.pagination) + + self.assertFalse(conn.features.space_and_index_names) + self.assertFalse(conn.features.watch_once) + self.assertFalse(conn.features.dml_tuple_extension) + self.assertFalse(conn.features.call_ret_tuple_extension) + self.assertFalse(conn.features.call_arg_tuple_extension) + + async def test__features_3_0(self): + async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn: + if not check_version( + self, + conn.version, + min=(3, 0), + min_included=True, + max_included=False, + ): + return + + self.assertIsNotNone(conn.features) + self.assertTrue(conn.features.streams) + self.assertTrue(conn.features.watchers) + self.assertTrue(conn.features.error_extension) + self.assertTrue(conn.features.transactions) + self.assertTrue(conn.features.pagination) + + self.assertTrue(conn.features.space_and_index_names) + self.assertTrue(conn.features.watch_once) + self.assertTrue(conn.features.dml_tuple_extension) + self.assertTrue(conn.features.call_ret_tuple_extension) + self.assertTrue(conn.features.call_arg_tuple_extension)