Skip to content

Commit

Permalink
Exporting features to Connection class
Browse files Browse the repository at this point in the history
  • Loading branch information
igorcoding committed Dec 31, 2023
1 parent 2277115 commit 0ac72b4
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 7 deletions.
8 changes: 6 additions & 2 deletions asynctnt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def reconnect(self):
await self.disconnect()
await self.connect()

async def __aenter__(self):
async def __aenter__(self) -> "Connection":
"""
Executed on entering the async with section.
Connects to Tarantool instance.
Expand Down Expand Up @@ -606,7 +606,7 @@ def _normalize_api(self):
Api.call = Api.call16
Connection.call = Connection.call16

if self.version < (2, 10): # pragma: nocover
if not self.features.streams: # pragma: nocover

def stream_stub(_):
raise TarantoolError("streams are available only in Tarantool 2.10+")
Expand All @@ -627,6 +627,10 @@ def stream(self) -> Stream:
stream._set_db(db)
return stream

@property
def features(self) -> protocol.IProtoFeatures:
return self._protocol.features


async def connect(**kwargs) -> Connection:
"""
Expand Down
1 change: 1 addition & 0 deletions asynctnt/iproto/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ cdef class BaseProtocol(CoreProtocol):
bint _schema_fetch_in_progress
object _refetch_schema_future
Db _db
IProtoFeatures _features
req_execute_func execute

object create_future
Expand Down
14 changes: 14 additions & 0 deletions asynctnt/iproto/protocol.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,23 @@ class Protocol:
def schema_id(self) -> int: ...
@property
def schema(self) -> Schema: ...
@property
def features(self) -> IProtoFeatures: ...
def create_db(self, gen_stream_id: bool = False) -> Db: ...
def get_common_db(self) -> Db: ...
def refetch_schema(self) -> asyncio.Future: ...
def is_connected(self) -> bool: ...
def is_fully_connected(self) -> bool: ...
def get_version(self) -> tuple: ...

class IProtoFeatures:
streams: bool
transactions: bool
error_extension: bool
watchers: bool
pagination: bool
space_and_index_names: bool
watch_once: bool
dml_tuple_extension: bool
call_ret_tuple_extension: bool
call_arg_tuple_extension: bool
10 changes: 8 additions & 2 deletions asynctnt/iproto/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ cdef class BaseProtocol(CoreProtocol):
self._refetch_schema_future = None
self._db = self._create_db(<bint> False)
self.execute = self._execute_bad
self._features = IProtoFeatures.__new__(IProtoFeatures)

try:
self.create_future = self.loop.create_future
Expand Down Expand Up @@ -253,8 +254,9 @@ cdef class BaseProtocol(CoreProtocol):
return
e = f.exception()
if not e:
logger.debug('Tarantool[%s:%s] identified successfully',
self.host, self.port)
self._features = (<Response> f.result()).result_
logger.debug('Tarantool[%s:%s] iproto features available: %r',
self.host, self.port, self.features)

self.post_con_state = POST_CONNECTION_AUTH
self._post_con_state_machine()
Expand Down Expand Up @@ -515,6 +517,10 @@ cdef class BaseProtocol(CoreProtocol):
def refetch_schema(self):
return self._refetch_schema()

@property
def features(self) -> IProtoFeatures:
return self._features


class Protocol(BaseProtocol, asyncio.Protocol):
pass
Expand Down
14 changes: 14 additions & 0 deletions asynctnt/iproto/response.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cdef class Response:
bint _push_subscribe
BaseRequest request_
object _exception
object result_

readonly object _q
readonly object _push_event
Expand All @@ -66,3 +67,16 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len,
Response resp, BaseRequest req,
bint is_chunk) except -1
cdef IProtoError parse_iproto_error(const char ** b, bytes encoding)

cdef class IProtoFeatures:
cdef:
readonly bint streams
readonly bint transactions
readonly bint error_extension
readonly bint watchers
readonly bint pagination
readonly bint space_and_index_names
readonly bint watch_once
readonly bint dml_tuple_extension
readonly bint call_ret_tuple_extension
readonly bint call_arg_tuple_extension
34 changes: 33 additions & 1 deletion asynctnt/iproto/response.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ cdef class IProtoErrorStackFrame:
cdef class IProtoError:
pass

@cython.final
cdef class IProtoFeatures:
pass

@cython.final
@cython.freelist(REQUEST_FREELIST)
cdef class Response:
Expand All @@ -41,6 +45,7 @@ cdef class Response:
self.errmsg = None
self.error = None
self._rowcount = 0
self.result_ = None
self.body = None
self.encoding = None
self.metadata = None
Expand Down Expand Up @@ -546,6 +551,7 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len,
const char *s
list data
Field field
IProtoFeatures features

b = <const char *> buf
# mp_fprint(stdio.stdout, b)
Expand Down Expand Up @@ -635,7 +641,33 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len,
logger.debug("IProto version: %s", _decode_obj(&b, resp.encoding))

elif key == tarantool.IPROTO_FEATURES:
logger.debug("IProto features available: %s", _decode_obj(&b, resp.encoding))
features = <IProtoFeatures> IProtoFeatures.__new__(IProtoFeatures)

for item in _decode_obj(&b, resp.encoding):
if item == 0:
features.streams = 1
elif item == 1:
features.transactions = 1
elif item == 2:
features.error_extension = 1
elif item == 3:
features.watchers = 1
elif item == 4:
features.pagination = 1
elif item == 5:
features.space_and_index_names = 1
elif item == 6:
features.watch_once = 1
elif item == 7:
features.dml_tuple_extension = 1
elif item == 8:
features.call_ret_tuple_extension = 1
elif item == 9:
features.call_arg_tuple_extension = 1
else:
logger.debug("unknown iproto feature available: %d", item)

resp.result_ = features

elif key == tarantool.IPROTO_AUTH_TYPE:
logger.debug("IProto auth type: %s", _decode_obj(&b, resp.encoding))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def initialize_options(self):
self.debug = True
self.gdb_debug = True
else:
self.cython_always = False
self.cython_always = True
self.cython_annotate = None
self.cython_directives = None
self.gdb_debug = False
Expand Down
14 changes: 14 additions & 0 deletions temp/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio
import logging

import asynctnt


async def main():
logging.basicConfig(level=logging.DEBUG)
async with asynctnt.Connection(host="127.0.0.1", port=3305) as conn:
print(conn.features)


if __name__ == "__main__":
asyncio.run(main())
51 changes: 50 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from asynctnt.instance import TarantoolSyncInstance
from tests import BaseTarantoolTestCase
from tests._testbase import check_version
from tests._testbase import check_version, ensure_version


class ConnectTestCase(BaseTarantoolTestCase):
Expand Down Expand Up @@ -802,3 +802,52 @@ async def state_checker():
await conn.call("box.info")
finally:
await conn.disconnect()

async def test__features(self):
async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn:
if not check_version(
self,
conn.version,
min=(2, 10),
max=(3, 0),
min_included=True,
max_included=False,
):
return

self.assertIsNotNone(conn.features)
self.assertTrue(conn.features.streams)
self.assertTrue(conn.features.watchers)
self.assertTrue(conn.features.error_extension)
self.assertTrue(conn.features.transactions)
self.assertTrue(conn.features.pagination)

self.assertFalse(conn.features.space_and_index_names)
self.assertFalse(conn.features.watch_once)
self.assertFalse(conn.features.dml_tuple_extension)
self.assertFalse(conn.features.call_ret_tuple_extension)
self.assertFalse(conn.features.call_arg_tuple_extension)

async def test__features_3_0(self):
async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn:
if not check_version(
self,
conn.version,
min=(3, 0),
min_included=True,
max_included=False,
):
return

self.assertIsNotNone(conn.features)
self.assertTrue(conn.features.streams)
self.assertTrue(conn.features.watchers)
self.assertTrue(conn.features.error_extension)
self.assertTrue(conn.features.transactions)
self.assertTrue(conn.features.pagination)

self.assertTrue(conn.features.space_and_index_names)
self.assertTrue(conn.features.watch_once)
self.assertTrue(conn.features.dml_tuple_extension)
self.assertTrue(conn.features.call_ret_tuple_extension)
self.assertTrue(conn.features.call_arg_tuple_extension)

0 comments on commit 0ac72b4

Please sign in to comment.