diff --git a/src/rockset_sqlalchemy/connection.py b/src/rockset_sqlalchemy/connection.py index bc54737..b3e698a 100644 --- a/src/rockset_sqlalchemy/connection.py +++ b/src/rockset_sqlalchemy/connection.py @@ -3,13 +3,11 @@ from .cursor import Cursor from .exceptions import ProgrammingError + class Connection(object): def __init__(self, api_server, api_key, virtual_instance=None, debug_sql=False): self._closed = False - self._client = RocksetClient( - host=api_server, - api_key=api_key - ) + self._client = RocksetClient(host=api_server, api_key=api_key) self.vi = virtual_instance self.debug_sql = debug_sql # Used for testing connectivity to Rockset. diff --git a/src/rockset_sqlalchemy/cursor.py b/src/rockset_sqlalchemy/cursor.py index 0ba8c81..9e6ecf1 100644 --- a/src/rockset_sqlalchemy/cursor.py +++ b/src/rockset_sqlalchemy/cursor.py @@ -43,18 +43,23 @@ def execute_query(client, query, vi=None, query_params={}): query=query, parameters=[ rockset.models.QueryParameter( - name=param, value=str(val), type=Cursor.__convert_to_rockset_type(val) + name=param, + value=str(val), + type=Cursor.__convert_to_rockset_type(val), ) for param, val in query_params.items() - ] + ], ) try: - return client.VirtualInstances.query_virtual_instance( - virtual_instance_id=vi, - sql=request - ) if vi else client.Queries.query(sql=request) + return ( + client.VirtualInstances.query_virtual_instance( + virtual_instance_id=vi, sql=request + ) + if vi + else client.Queries.query(sql=request) + ) except rockset.exceptions.RocksetException as e: - raise Error.map_rockset_exception(e) + raise Error.map_rockset_exception(e) def execute(self, sql, parameters=None): self.__check_cursor_opened() @@ -68,7 +73,7 @@ def execute(self, sql, parameters=None): else: new_params[k] = v parameters = new_params - + if self._connection.debug_sql: print("+++++++++++++++++++++++++++++") print(f"Query:\n{sql}") @@ -83,10 +88,7 @@ def execute(self, sql, parameters=None): ) self._response = Cursor.execute_query( - self._connection._client, - sql, - self._connection.vi, - query_params=parameters + self._connection._client, sql, self._connection.vi, query_params=parameters ) self._response_iter = iter(self._response.results) @@ -108,8 +110,9 @@ def fetchone(self): return None result = [] - - for field in self._response_to_column_fields(self._response.column_fields): + + column_fields = getattr(self._response, "column_fields", None) + for field in self._response_to_column_fields(column_fields): name = field["name"] if name in next_doc: result.append(next_doc[name]) @@ -126,7 +129,7 @@ def _response_to_column_fields(self, column_fields): schema = rockset.Document() if self._response.results and len(self._response.results) > 0: - # we only look at the first document because + # we only look at the first document because # is sqlalchemy is typically used for relational # tables with no sparse fields schema.update(self._response.results[0]) @@ -152,7 +155,6 @@ def fetchmany(self, size=None): break docs.append(doc) return docs - @property def description(self): diff --git a/src/rockset_sqlalchemy/exceptions.py b/src/rockset_sqlalchemy/exceptions.py index 0bb40f4..a1de997 100644 --- a/src/rockset_sqlalchemy/exceptions.py +++ b/src/rockset_sqlalchemy/exceptions.py @@ -1,31 +1,27 @@ import rockset from json import loads + class Error(rockset.exceptions.RocksetException): @classmethod def map_rockset_exception(cls, exc): err_body = loads(exc.body) - args = [ - err_body["message"], - exc.status, - err_body["type"] - ] + args = [err_body["message"], exc.status, err_body["type"]] exc_type = type(exc) if ( - exc_type == rockset.exceptions.ApiTypeError or - exc_type == rockset.exceptions.ApiValueError or - exc_type == rockset.exceptions.ApiAttributeError or - exc_type == rockset.exceptions.ApiKeyError or - exc_type == rockset.exceptions.NotFoundException or - exc_type == rockset.exceptions.InputException or - exc_type == rockset.exceptions.InitializationException or - exc_type == rockset.exceptions.BadRequestException - + exc_type == rockset.exceptions.ApiTypeError + or exc_type == rockset.exceptions.ApiValueError + or exc_type == rockset.exceptions.ApiAttributeError + or exc_type == rockset.exceptions.ApiKeyError + or exc_type == rockset.exceptions.NotFoundException + or exc_type == rockset.exceptions.InputException + or exc_type == rockset.exceptions.InitializationException + or exc_type == rockset.exceptions.BadRequestException ): ret = ProgrammingError(*args) elif ( - exc_type == rockset.exceptions.UnauthorizedException or - exc_type == rockset.exceptions.ForbiddenException + exc_type == rockset.exceptions.UnauthorizedException + or exc_type == rockset.exceptions.ForbiddenException ): ret = OperationalError(*args) elif exc_type == rockset.exceptions.ServiceException: diff --git a/src/rockset_sqlalchemy/sqlalchemy/dialect.py b/src/rockset_sqlalchemy/sqlalchemy/dialect.py index db5ccf4..85da2bd 100644 --- a/src/rockset_sqlalchemy/sqlalchemy/dialect.py +++ b/src/rockset_sqlalchemy/sqlalchemy/dialect.py @@ -43,12 +43,11 @@ class RocksetDialect(default.DefaultDialect): @classmethod def dbapi(cls): - """Retained for backward compatibility with SQLAlchemy 1.x. - """ + """Retained for backward compatibility with SQLAlchemy 1.x.""" import rockset_sqlalchemy return rockset_sqlalchemy - + @classmethod def import_dbapi(cls): return RocksetDialect.dbapi() @@ -57,20 +56,27 @@ def create_connect_args(self, url): kwargs = { "api_server": "https://{}".format(url.host), "api_key": url.password or url.username, - "virtual_instance": url.database + "virtual_instance": url.database, } return ([], kwargs) @reflection.cache def get_schema_names(self, connection, **kw): - return [w["name"] for w in connection.connect().connection._client.Workspaces.list()["data"]] + return [ + w["name"] + for w in connection.connect().connection._client.Workspaces.list()["data"] + ] @reflection.cache def get_table_names(self, connection, schema=None, **kw): - tables = (connection.connect().connection._client.Collections.list() - if schema is None else - connection.connect().connection._client.Collections.workspace_collections(workspace=schema))['data'] - + tables = ( + connection.connect().connection._client.Collections.list() + if schema is None + else connection.connect().connection._client.Collections.workspace_collections( + workspace=schema + ) + )["data"] + return [w["name"] for w in tables] def _get_table_columns(self, connection, table_name, schema): @@ -132,7 +138,7 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): return [] - + def has_table(self, connection, table_name, schema=None): try: self._get_table_columns(connection, table_name, schema)