From 056fb02adcf1b8b969ae285d5fe09ea32e1d1b2d Mon Sep 17 00:00:00 2001 From: ipr0 <8950668+ipr0@users.noreply.github.com> Date: Sun, 12 Apr 2020 18:14:38 -0700 Subject: [PATCH 1/3] Python driver implements DB-API. --- Vagrantfile | 3 +- WORKSPACE | 13 +- driver/python/BUILD | 85 ++++++++++ driver/python/__init__.py | 75 +++++++++ driver/python/connection.py | 192 +++++++++++++++++++++++ driver/python/cursor.py | 216 ++++++++++++++++++++++++++ driver/python/dbtypes.py | 83 ++++++++++ driver/python/exceptions.py | 58 +++++++ driver/python/test/connection_test.py | 68 ++++++++ driver/python/test/cursor_test.py | 50 ++++++ driver/python/test/pylib.py | 92 +++++++++++ driver/python/test/test_data.py | 52 +++++++ server/BUILD | 13 ++ sfdb/BUILD | 13 ++ 14 files changed, 1010 insertions(+), 3 deletions(-) create mode 100644 driver/python/BUILD create mode 100644 driver/python/__init__.py create mode 100644 driver/python/connection.py create mode 100644 driver/python/cursor.py create mode 100644 driver/python/dbtypes.py create mode 100644 driver/python/exceptions.py create mode 100644 driver/python/test/connection_test.py create mode 100644 driver/python/test/cursor_test.py create mode 100644 driver/python/test/pylib.py create mode 100644 driver/python/test/test_data.py diff --git a/Vagrantfile b/Vagrantfile index 7a3cec9..2ac3aa6 100644 --- a/Vagrantfile +++ b/Vagrantfile @@ -21,7 +21,8 @@ Vagrant.configure("2") do |config| config.vm.provision "shell", inline: <<-SHELL apt-get -qqy update - apt-get -qqy install make zip unzip git pkg-config libssl-dev zlib1g-dev + apt-get -qqy install make zip unzip git pkg-config libssl-dev zlib1g-dev + apt-get -qqy install python-dev python3-dev python3.7 echo "[vagrant provisioning] Downloading Bazel..." wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.29.1/bazel-0.29.1-installer-linux-x86_64.sh diff --git a/WORKSPACE b/WORKSPACE index e8beb48..b8f0749 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -21,8 +21,7 @@ http_archive( git_repository( name = "com_github_grpc_grpc", remote = "https://github.com/grpc/grpc.git", - commit = "08fd59f039c7cf62614ab7741b3f34527af103c7", - shallow_since = "1562093080 -0700", + tag = "v1.24.3", ) git_repository( @@ -66,6 +65,16 @@ git_repository( shallow_since = "1560490505 +0000", ) +# Python dependencies +load("@upb//bazel:workspace_deps.bzl", "upb_deps") +upb_deps() + +load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") +apple_rules_dependencies() + +load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") +apple_support_dependencies() + # GoLang main Bazel tools http_archive( name = "io_bazel_rules_go", diff --git a/driver/python/BUILD b/driver/python/BUILD new file mode 100644 index 0000000..90c4137 --- /dev/null +++ b/driver/python/BUILD @@ -0,0 +1,85 @@ +# bazel test //driver/python:all_tests --test_output=streamed --runs_per_test=2 + +py_library( + name = "py_db_api", + srcs = [ + "__init__.py", + "connection.py", + "cursor.py", + "exceptions.py", + "dbtypes.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + imports = ["."], + deps = [ + "//server:grpc_sfdb_service_py_pb2", + "//server:grpc_sfdb_service_py_pb2_grpc", + "//sfdb:api_py_pb2", + "//sfdb:api_pb2_grpc", + ], +) + +py_library( + name = "test/py_test_lib", + srcs = [ + "test/pylib.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], +) + +py_test( + name = "test/connection_test", + srcs = [ + "test/connection_test.py", + "test/pylib.py", + "test/test_data.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":py_db_api", + ], + data = ["//sfdb:sfdb"], +) + +py_test( + name = "test/cursor_test", + srcs = [ + "test/cursor_test.py", + "test/pylib.py", + "test/test_data.py", + ], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":py_db_api", + ], + data = ["//sfdb:sfdb"], +) + +test_suite( + name = "all_tests", +) + +# python toolchain definition, +# see https://github.com/bazelbuild/bazel/issues/7899 +load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair") +py_runtime( + name = "py3_runtime", + interpreter_path = "/usr/bin/python3.7", + python_version = "PY3", +) + +py_runtime_pair( + name = "py_runtime_pair", + py3_runtime = ":py3_runtime", +) + +toolchain( + name = "py_toolchain", + #target_compatible_with = [...], # optional platform constraints + toolchain = ":py_runtime_pair", + toolchain_type = "@bazel_tools//tools/python:toolchain_type", +) \ No newline at end of file diff --git a/driver/python/__init__.py b/driver/python/__init__.py new file mode 100644 index 0000000..20880b7 --- /dev/null +++ b/driver/python/__init__.py @@ -0,0 +1,75 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python DB API v2.0 (DB-API): +https://www.python.org/dev/peps/pep-0249 """ + + +from connection import connect +from connection import Connection +from cursor import Cursor +from exceptions import Warning +from exceptions import Error +from exceptions import InterfaceError +from exceptions import DatabaseError +from exceptions import DataError +from exceptions import OperationalError +from exceptions import IntegrityError +from exceptions import InternalError +from exceptions import ProgrammingError +from exceptions import NotSupportedError +from dbtypes import Binary +from dbtypes import Date +from dbtypes import DateFromTicks +from dbtypes import Timestamp +from dbtypes import TimestampFromTicks +from dbtypes import BINARY +from dbtypes import DATETIME +from dbtypes import NUMBER +from dbtypes import ROWID +from dbtypes import STRING + +apilevel = '2.0' + +paramstyle = 'pyformat' + +__all__ = [ + 'apilevel', + 'paramstyle', + 'connect', + 'Connection', + 'Cursor', + 'Warning', + 'Error', + 'InterfaceError', + 'DatabaseError', + 'DataError', + 'OperationalError', + 'IntegrityError', + 'InternalError', + 'ProgrammingError', + 'NotSupportedError', + 'Binary', + 'Date', + 'DateFromTicks', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks', + 'BINARY', + 'DATETIME', + 'NUMBER', + 'ROWID', + 'STRING', +] diff --git a/driver/python/connection.py b/driver/python/connection.py new file mode 100644 index 0000000..e89fb1d --- /dev/null +++ b/driver/python/connection.py @@ -0,0 +1,192 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Connection for the Google SFDB DB-API.""" + +import grpc +import logging +import sys +import platform +import exceptions +from cursor import Cursor +from logging import getLogger + +from server import grpc_service_pb2_grpc +from server import grpc_service_pb2 + +TIMEOUT = 5 +PYTHON_VERSION = ".".join([str(v) for v in sys.version_info[:3]]) +PLATFORM = platform.platform() +_LOGGER = getLogger(__name__) + + +class Connection(object): + """DB-API connection to SFDB""" + + def __init__(self, host_port, dbname, params, creds): + """Creates and maintains gRPC channel. + + Args: + host_port(tuple(str, str)): hostname:port + dbname(str): database name to connect (not implemented in SFDB) + params(*str): connection parameters + creds(tuple(str, str)): username:password + + Returns: + Connection object + + Raises: + InterfaceError: if there's any gRPC channel related exception + """ + self.addr, self.port = host_port # (ipv4:port|ipv6:port|hostname:port) + self._creds = creds + self._dbname = dbname + self._params = params + self._timeout = TIMEOUT + self._grpc_channel = None + self._stub = None + self._errors = {} + + _LOGGER.debug( + u"SFDB connector for Python" + u"Version: %s, Platform: %s", + PYTHON_VERSION, PLATFORM) + self._connect() + + def _connect(self): + try: + self._grpc_channel = grpc.insecure_channel( + "{}:{}".format(self.addr, self.port)) + self._stub = grpc_service_pb2_grpc.SfdbServiceStub( + self._grpc_channel) + _LOGGER.debug("gRPC channel and stub initialized") + except grpc.RpcError as rpc_error: + err_msg = "{}: {}".format( + rpc_error.code().name, rpc_error.details()) + _LOGGER.error("gRPC channel error %s", err_msg) + raise exceptions.InterfaceError(err_msg) + + def close(self): + """Close the connection now. + + Connection will not be usable and will raise exception.""" + self._stub = None + self._grpc_channel.close() + _LOGGER.info("gRPC connection closed.") + + def commit(self): + """SFDB does not support transactions.""" + pass + + def rollback(self): + """SFDB does not support transactions.""" + pass + + def cursor(self, cursor_class=Cursor): + """Creates cursor object using the connection. + + Each statement should create a new cursor object.""" + _LOGGER.debug("Creating cursor.") + if not self._grpc_channel: + raise exceptions.InterfaceError() + return cursor_class(self) + + def cmd_query(self, req): + """Executes SQL protobuf over gRPC. + + Args: + req: request protobuf + + Returns: + responce protubuf + + Raises: + InterfaceError if there any errors. + """ + try: + resp = self._stub.ExecSql(req) + except grpc.RpcError as rpc_error: + resp = str(rpc_error) + self._errors = {"status_code": rpc_error.code().name, + "details": rpc_error.details()} + _LOGGER.error("Error during request: %s", self._errors) + raise exceptions.InterfaceError(self._errors) + return resp + + def _redirect(self, new_sock): + try: + self.addr, self.port, _ = new_sock.split(":") + except ValueError as e: + raise exceptions.InterfaceError("Invalid redirect address:", e) + self._connect() + + def __repr__(self): + return str(list(map(str, (self.addr, + self.port, + self._dbname, + self._params)))) + + +def connect(conn_str): + """Construct a DB-API connection to SFDB. + + Args: + conn_str(str): string containing connection attributes + + Returns: + Connection object + + Raises: + None + """ + addr, dbname, params, creds = _parse_conn_string(conn_str) + return Connection(addr, dbname=dbname, params=params, creds=creds) + + +def _parse_conn_string(conn_str): + """Parses connection string to Connection. + + Example of connection string: + [username[:password]@]address/dbname[?param1=value1&...¶mN=valueN] + returns: creds dict, address tuple, dbname, params dict + + Args: + conn_str(str): string containing connection attributes + + Returns: + tuple containing parsed connection attributes + """ + conn_str = conn_str.strip() + if not conn_str: + raise exceptions.InterfaceError("Empty connection string.") + if r'@' in conn_str: + creds_str, conn_str = conn_str.split(r'@') + if r':' in creds_str: + username, password = creds_str.split(r':') + else: + username, password = creds_str, None + creds = {username: password} + else: + creds = None + if r'?' in conn_str: + conn_str, params_str = conn_str.split(r'?') + params = {} + for param in params_str.split(r'&'): + k, v = param.split(r'=') + params[k] = v + else: + params = None + conn_str, dbname = conn_str.split(r'/') + sock = tuple(conn_str.split(r':')) + return sock, dbname, params, creds diff --git a/driver/python/cursor.py b/driver/python/cursor.py new file mode 100644 index 0000000..d541613 --- /dev/null +++ b/driver/python/cursor.py @@ -0,0 +1,216 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cursor for SFDB DB-API. + +Used to execute single query. No transaction support. +""" + +import collections +import json +import exceptions +import logging + +from server import grpc_service_pb2_grpc +from server import grpc_service_pb2_grpc +from sfdb import api_pb2 + +from google.protobuf import descriptor as _descriptor +from google.protobuf import reflection as _reflection +from google.protobuf import message + +_LOGGER = logging.getLogger(__name__) +RETRIES = 3 + + +class Cursor(object): + """DB-API Cursor to SFDB. + + Args: + connection (driver.python.Connection): + A DB-API connection to SFDB. + """ + + def __init__(self, connection): + self._connection = connection + self.description = None + # Per PEP 249: The attribute is -1 in case no .execute*() has been + # performed on the cursor or the rowcount of the last operation + # cannot be determined by the interface. + self.rowcount = -1 + # Per PEP 249: The arraysize attribute defaults to 1, meaning to fetch + # a single row at a time. However, we deviate from that, and set the + # default to None, allowing the backend to automatically determine the + # most appropriate size. + self.arraysize = None + self._query_data = None + self._errors = None + + @property + def connection(self): + return self._connection + + def _set_rowcount(self, query_results): + """Set the rowcount from query results. + + Normally, this sets rowcount to the number of rows returned by the + query, but if it was a DML statement, it sets rowcount to the number + of modified rows. + + Args: + query_results (Dictionary): + Results of a query. + """ + total_rows = 0 + num_dml_affected_rows = query_results.num_dml_affected_rows + + if query_results.total_rows is not None and query_results.total_rows > 0: + total_rows = query_results.total_rows + if num_dml_affected_rows is not None and num_dml_affected_rows > 0: + total_rows = num_dml_affected_rows + self.rowcount = total_rows + + def execute(self, operation, parameters=None): + """Prepare and execute a database operation. + + Args: + operation (str): SFDB query string. + + parameters (Union[Mapping[str, Any], Sequence[Any]]): + (Optional) dictionary or sequence of parameter values. + """ + _LOGGER.debug('Operation:\n%s', operation) + self._query_data = None + req = api_pb2.ExecSqlRequest() + req.sql = operation + for _ in range(RETRIES): + try: + resp = self._connection.cmd_query(req) + except exceptions.InterfaceError as e: + self._errors = e.args[0] + _LOGGER.error('Failed to execute query:\n%s', e) + return False + _LOGGER.debug('Response:\n%s', resp) + if resp.status == api_pb2.ExecSqlResponse.ERROR: + _LOGGER.error( + "Server returned an error executing:\n%s", operation) + self._errors = {"FAILED": operation} + if resp.status == api_pb2.ExecSqlResponse.REDIRECT and resp.redirect: + _LOGGER.debug("Redirecting to: %s", resp.redirect) + self._connection._redirect(resp.redirect) + continue + if resp.rows: + descriptors = self._get_descriptors(resp) + self.rowcount = self._set_data(descriptors, resp.rows) + else: + self.rowcount = 0 + return True + _LOGGER.error("Failed after %d retries.", RETRIES) + return False + + def executemany(self, operation, seq_of_parameters): + """Prepare and execute a database operation multiple times. + + Args: + operation (str): SFDB query string. + + seq_of_parameters (Union[Sequence[Mapping[str, Any], \ + Sequence[Any]]]): + Sequence of many sets of parameter values. + """ + for parameters in seq_of_parameters: + self.execute(operation, parameters) + + def _get_descriptors(self, resp): + """Get descriptors from gRPC response. + + Iterates over descriptors' protobufs in response. Converts descriptor + protobuf to descriptor and save in dict. + """ + msg_descs = {} + prefix = r'/' # row.type_url format is "/ Date: Mon, 13 Apr 2020 16:47:25 -0700 Subject: [PATCH 2/3] Python CLI. --- client/python/BUILD | 32 +++++++++++++++ client/python/__init__.py | 1 + client/python/main.py | 26 ++++++++++++ client/python/optionparser.py | 50 ++++++++++++++++++++++++ client/python/sfdb_cli.py | 63 ++++++++++++++++++++++++++++++ client/python/test/sfdbcli_test.py | 34 ++++++++++++++++ driver/python/BUILD | 7 ++-- 7 files changed, 209 insertions(+), 4 deletions(-) create mode 100644 client/python/BUILD create mode 100644 client/python/__init__.py create mode 100644 client/python/main.py create mode 100644 client/python/optionparser.py create mode 100644 client/python/sfdb_cli.py create mode 100644 client/python/test/sfdbcli_test.py diff --git a/client/python/BUILD b/client/python/BUILD new file mode 100644 index 0000000..c5df03b --- /dev/null +++ b/client/python/BUILD @@ -0,0 +1,32 @@ +py_binary( + name = "main", + srcs = ["main.py",], + python_version = "PY3", + deps = [ + "//driver/python:py_db_api", + ":lib", + ] +) + +py_library( + name = "lib", + srcs = [ + "__init__.py", + "optionparser.py", + "sfdb_cli.py", + ], + visibility = ["//visibility:public"], + imports = ["."], +) + +py_test( + name = "test/sfdbcli_test", + srcs = ["test/sfdbcli_test.py",], + python_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//driver/python:py_db_api", + "//driver/python:test/py_test_lib", + ], + data = ["//sfdb:sfdb"], +) \ No newline at end of file diff --git a/client/python/__init__.py b/client/python/__init__.py new file mode 100644 index 0000000..b8023d8 --- /dev/null +++ b/client/python/__init__.py @@ -0,0 +1 @@ +__version__ = '0.0.1' diff --git a/client/python/main.py b/client/python/main.py new file mode 100644 index 0000000..b068e53 --- /dev/null +++ b/client/python/main.py @@ -0,0 +1,26 @@ +import sys +from client.python.optionparser import create_parser +from client.python.sfdb_cli import SfdbCli + + +def run_cli_with(options): + if options.query: + options.interactive_mode = False + sfdbcli = SfdbCli(options) + try: + sfdbcli.connect_to_database() + cursor = sfdbcli.execute_query(str(options.query)) + print(cursor.json) + print(f'Rows affected: {cursor.rowcount}', file=sys.stderr) + finally: + sfdbcli.shutdown() + + +def main(): + sfdbcli_options_parser = create_parser() + sfdbcli_options = sfdbcli_options_parser.parse_args(sys.argv[1:]) + run_cli_with(sfdbcli_options) + + +if __name__ == '__main__': + main() diff --git a/client/python/optionparser.py b/client/python/optionparser.py new file mode 100644 index 0000000..5fe06e9 --- /dev/null +++ b/client/python/optionparser.py @@ -0,0 +1,50 @@ +import argparse +import os + +from client.python import __version__ +from sfdb_cli import LOG_LEVEL_MAP + +SFDB_CLI_SERVER = u'SFDB_CLI_SERVER' + + +def create_parser(): + args_parser = argparse.ArgumentParser( + prog=u'sfdb-cli', + description=u'SFDB CLI. v.{}'.format(__version__) + ) + + args_parser.add_argument( + u'-S', u'--server', + dest='server', + default=os.environ.get(SFDB_CLI_SERVER, None), + metavar='', + help=u'server:port instance to connect e.g. -S \'localhost:27910\'' + ) + + args_parser.add_argument( + u'-Q', u'--query', + dest='query', + default=False, + required=True, + metavar='', + help=u'Executes a query outputting results to STDOUT and exits.' + ) + + args_parser.add_argument( + u'--log_level', + dest='log_level', + default='INFO', + metavar='', + choices=list(LOG_LEVEL_MAP.keys()), + help=u'Log Level.' + ) + + args_parser.add_argument( + u'--log_file', + dest='log_file', + default=None, + metavar='', + help='Path to file to save logs.' + ) + + return args_parser diff --git a/client/python/sfdb_cli.py b/client/python/sfdb_cli.py new file mode 100644 index 0000000..2ef5252 --- /dev/null +++ b/client/python/sfdb_cli.py @@ -0,0 +1,63 @@ +import os +import sys +import logging + +from driver.python.connection import connect + +LOG_LEVEL_MAP = { + 'ERROR': logging.ERROR, + 'WARN': logging.WARN, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG, +} + + +class SfdbCli(object): + + default_prompt = r'\d' + + def __init__(self, options): + self.init_logging(options.log_level, options.log_file) + self.logger = logging.getLogger("sfdbcli.SfdbCli") + self.server = options.server + self.query = options.query + + def init_logging(self, log_level, log_file=None): + formatter = logging.Formatter( + '%(asctime)s (%(process)d/%(threadName)s) ' + '%(name)s %(levelname)s - %(message)s') + root_logger = logging.getLogger('') + root_logger.setLevel(LOG_LEVEL_MAP[log_level]) + + console_handler = logging.StreamHandler() + console_handler.setLevel(LOG_LEVEL_MAP[log_level]) + console_handler.setFormatter(formatter) + + root_logger.addHandler(console_handler) + if log_file: + logging_path, log_filename = os.path.split(log_file) + if os.path.isdir(logging_path) and log_filename: + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setLevel(LOG_LEVEL_MAP[log_level]) + file_handler.setFormatter(formatter) + + root_logger.addHandler(file_handler) + root_logger.info(f'Initalized logging in: {log_file}') + else: + root_logger.error(f'Invalid logging path: {logging_path}') + root_logger.debug('Initialized sfdbcli logging.') + + def connect_to_database(self): + self.logger.debug("Connecting to database...") + self.conn = connect('/'.join([self.server, 'Test'])) + + def shutdown(self): + self.logger.debug("Shutting down...") + self.conn = None + + def execute_query(self, query_text): + """Process a query string and outputs to STDOUT/file""" + self.logger.debug(f'Executing query: \"{query_text}\"') + cursor = self.conn.cursor() + cursor.execute(query_text) + return cursor diff --git a/client/python/test/sfdbcli_test.py b/client/python/test/sfdbcli_test.py new file mode 100644 index 0000000..3fefdbb --- /dev/null +++ b/client/python/test/sfdbcli_test.py @@ -0,0 +1,34 @@ +import unittest + +from driver.python.connection import connect +from test.pylib import TestSFDB +from driver.python.test.test_data import QUERIES + + +class CLITest(unittest.TestCase): + _db = None + + @classmethod + def setUpClass(cls): + cls._db = TestSFDB() + # spin up additional two instances for proper SFDB cluster init + TestSFDB() + TestSFDB() + for i in TestSFDB.instances: + i.start() + + @classmethod + def tearDownClass(cls): + TestSFDB.shutdown_all() + + def test_query(self): + # test assumes 5 rows exist in database + conn = connect("localhost:27910/test") + cur = conn.cursor() + for q, r in QUERIES.items(): + cur.execute(q) + self.assertEqual(r, cur.rowcount) + + +if __name__ == '__main__': + unittest.main() diff --git a/driver/python/BUILD b/driver/python/BUILD index 90c4137..f2256bb 100644 --- a/driver/python/BUILD +++ b/driver/python/BUILD @@ -24,6 +24,7 @@ py_library( name = "test/py_test_lib", srcs = [ "test/pylib.py", + "test/test_data.py", ], srcs_version = "PY3", visibility = ["//visibility:public"], @@ -33,13 +34,12 @@ py_test( name = "test/connection_test", srcs = [ "test/connection_test.py", - "test/pylib.py", - "test/test_data.py", ], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ ":py_db_api", + ":test/py_test_lib", ], data = ["//sfdb:sfdb"], ) @@ -48,13 +48,12 @@ py_test( name = "test/cursor_test", srcs = [ "test/cursor_test.py", - "test/pylib.py", - "test/test_data.py", ], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ ":py_db_api", + ":test/py_test_lib", ], data = ["//sfdb:sfdb"], ) From 74b04d947dd51df3aa9bdede6ae0a5f0e1f1bd1d Mon Sep 17 00:00:00 2001 From: ipr0 <8950668+ipr0@users.noreply.github.com> Date: Tue, 14 Apr 2020 12:55:32 -0700 Subject: [PATCH 3/3] Refactored request handler. --- client/python/main.py | 3 +- driver/python/BUILD | 4 +- driver/python/connection.py | 8 ++-- driver/python/cursor.py | 69 ++++++++++++++++++------------- driver/python/test/cursor_test.py | 2 +- 5 files changed, 48 insertions(+), 38 deletions(-) diff --git a/client/python/main.py b/client/python/main.py index b068e53..d92d0ce 100644 --- a/client/python/main.py +++ b/client/python/main.py @@ -10,7 +10,8 @@ def run_cli_with(options): try: sfdbcli.connect_to_database() cursor = sfdbcli.execute_query(str(options.query)) - print(cursor.json) + if cursor.json: + print(cursor.json) print(f'Rows affected: {cursor.rowcount}', file=sys.stderr) finally: sfdbcli.shutdown() diff --git a/driver/python/BUILD b/driver/python/BUILD index f2256bb..5b12dbe 100644 --- a/driver/python/BUILD +++ b/driver/python/BUILD @@ -1,5 +1,3 @@ -# bazel test //driver/python:all_tests --test_output=streamed --runs_per_test=2 - py_library( name = "py_db_api", srcs = [ @@ -67,7 +65,7 @@ test_suite( load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair") py_runtime( name = "py3_runtime", - interpreter_path = "/usr/bin/python3.7", + interpreter_path = "/usr/bin/python3", python_version = "PY3", ) diff --git a/driver/python/connection.py b/driver/python/connection.py index e89fb1d..9368f16 100644 --- a/driver/python/connection.py +++ b/driver/python/connection.py @@ -58,10 +58,8 @@ def __init__(self, host_port, dbname, params, creds): self._stub = None self._errors = {} - _LOGGER.debug( - u"SFDB connector for Python" - u"Version: %s, Platform: %s", - PYTHON_VERSION, PLATFORM) + _LOGGER.debug(u"SFDB connector. Python Version: %s, Platform: %s", + PYTHON_VERSION, PLATFORM) self._connect() def _connect(self): @@ -112,7 +110,7 @@ def cmd_query(self, req): responce protubuf Raises: - InterfaceError if there any errors. + InterfaceError if there are any errors. """ try: resp = self._stub.ExecSql(req) diff --git a/driver/python/cursor.py b/driver/python/cursor.py index d541613..11addd9 100644 --- a/driver/python/cursor.py +++ b/driver/python/cursor.py @@ -56,6 +56,11 @@ def __init__(self, connection): self.arraysize = None self._query_data = None self._errors = None + self._handler = { + api_pb2.ExecSqlResponse.OK: self._handle_OK, + api_pb2.ExecSqlResponse.ERROR: self._handle_ERROR, + api_pb2.ExecSqlResponse.REDIRECT: self._handle_REDIRECT, + } @property def connection(self): @@ -81,6 +86,36 @@ def _set_rowcount(self, query_results): total_rows = num_dml_affected_rows self.rowcount = total_rows + def _handle(self, request): + for _ in range(RETRIES): + try: + response = self._connection.cmd_query(request) + except exceptions.InterfaceError: + break + result = self._handler[response.status](response) + if result: + return True + return False + + def _handle_OK(self, response): + _LOGGER.debug("Handling OK: %s", response) + self.rowcount = 0 + if response.rows: + descriptors = self._get_descriptors(response) + self.rowcount = self._set_data(descriptors, response.rows) + return True + + def _handle_REDIRECT(self, response): + _LOGGER.debug("Redirecting to: %s", response.redirect) + if response.redirect: + self._connection._redirect(response.redirect) + return False + + def _handle_ERROR(self, response): + _LOGGER.error("Server error") + self._errors = {"FAILED": "In operation."} + return False + def execute(self, operation, parameters=None): """Prepare and execute a database operation. @@ -90,34 +125,11 @@ def execute(self, operation, parameters=None): parameters (Union[Mapping[str, Any], Sequence[Any]]): (Optional) dictionary or sequence of parameter values. """ - _LOGGER.debug('Operation:\n%s', operation) + _LOGGER.debug('Operation: %s', operation) self._query_data = None - req = api_pb2.ExecSqlRequest() - req.sql = operation - for _ in range(RETRIES): - try: - resp = self._connection.cmd_query(req) - except exceptions.InterfaceError as e: - self._errors = e.args[0] - _LOGGER.error('Failed to execute query:\n%s', e) - return False - _LOGGER.debug('Response:\n%s', resp) - if resp.status == api_pb2.ExecSqlResponse.ERROR: - _LOGGER.error( - "Server returned an error executing:\n%s", operation) - self._errors = {"FAILED": operation} - if resp.status == api_pb2.ExecSqlResponse.REDIRECT and resp.redirect: - _LOGGER.debug("Redirecting to: %s", resp.redirect) - self._connection._redirect(resp.redirect) - continue - if resp.rows: - descriptors = self._get_descriptors(resp) - self.rowcount = self._set_data(descriptors, resp.rows) - else: - self.rowcount = 0 - return True - _LOGGER.error("Failed after %d retries.", RETRIES) - return False + request = api_pb2.ExecSqlRequest() + request.sql = operation + return self._handle(request) def executemany(self, operation, seq_of_parameters): """Prepare and execute a database operation multiple times. @@ -187,7 +199,8 @@ def _get_row_dict(self, descriptor, row): @property def json(self): - return json.dumps(self._query_data) + if self._query_data: + return json.dumps(self._query_data) # The following methods not implemented, although expected by DB-API 2.0 def callproc(self, procname): diff --git a/driver/python/test/cursor_test.py b/driver/python/test/cursor_test.py index 74e6f2f..09986c5 100644 --- a/driver/python/test/cursor_test.py +++ b/driver/python/test/cursor_test.py @@ -37,7 +37,7 @@ def test_select(self): def test_no_table(self): self.cur.execute("SELECT id FROM Test;") - self.assertDictEqual({'FAILED': 'SELECT id FROM Test;'}, + self.assertDictEqual({'FAILED': 'In operation.'}, self.cur._errors) def test_json(self):