Skip to content

Commit

Permalink
Reformat files with black to allow tests to pass
Browse files Browse the repository at this point in the history
Reviewers: haneeshr
  • Loading branch information
mpatou committed Mar 4, 2024
1 parent 5f08916 commit 754d356
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 46 deletions.
6 changes: 2 additions & 4 deletions src/rockset_sqlalchemy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from .cursor import Cursor
from .exceptions import ProgrammingError


class Connection(object):
def __init__(self, api_server, api_key, virtual_instance=None, debug_sql=False):
self._closed = False
self._client = RocksetClient(
host=api_server,
api_key=api_key
)
self._client = RocksetClient(host=api_server, api_key=api_key)
self.vi = virtual_instance
self.debug_sql = debug_sql
# Used for testing connectivity to Rockset.
Expand Down
33 changes: 17 additions & 16 deletions src/rockset_sqlalchemy/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,23 @@ def execute_query(client, query, vi=None, query_params={}):
query=query,
parameters=[
rockset.models.QueryParameter(
name=param, value=str(val), type=Cursor.__convert_to_rockset_type(val)
name=param,
value=str(val),
type=Cursor.__convert_to_rockset_type(val),
)
for param, val in query_params.items()
]
],
)
try:
return client.VirtualInstances.query_virtual_instance(
virtual_instance_id=vi,
sql=request
) if vi else client.Queries.query(sql=request)
return (
client.VirtualInstances.query_virtual_instance(
virtual_instance_id=vi, sql=request
)
if vi
else client.Queries.query(sql=request)
)
except rockset.exceptions.RocksetException as e:
raise Error.map_rockset_exception(e)
raise Error.map_rockset_exception(e)

def execute(self, sql, parameters=None):
self.__check_cursor_opened()
Expand All @@ -68,7 +73,7 @@ def execute(self, sql, parameters=None):
else:
new_params[k] = v
parameters = new_params

if self._connection.debug_sql:
print("+++++++++++++++++++++++++++++")
print(f"Query:\n{sql}")
Expand All @@ -83,10 +88,7 @@ def execute(self, sql, parameters=None):
)

self._response = Cursor.execute_query(
self._connection._client,
sql,
self._connection.vi,
query_params=parameters
self._connection._client, sql, self._connection.vi, query_params=parameters
)
self._response_iter = iter(self._response.results)

Expand All @@ -108,8 +110,8 @@ def fetchone(self):
return None

result = []
for field in self._response_to_column_fields(self._response.column_fields):

for field in self._response_to_column_fields(self._response.column_fields):
name = field["name"]
if name in next_doc:
result.append(next_doc[name])
Expand All @@ -126,7 +128,7 @@ def _response_to_column_fields(self, column_fields):

schema = rockset.Document()
if self._response.results and len(self._response.results) > 0:
# we only look at the first document because
# we only look at the first document because
# is sqlalchemy is typically used for relational
# tables with no sparse fields
schema.update(self._response.results[0])
Expand All @@ -152,7 +154,6 @@ def fetchmany(self, size=None):
break
docs.append(doc)
return docs


@property
def description(self):
Expand Down
28 changes: 12 additions & 16 deletions src/rockset_sqlalchemy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import rockset
from json import loads


class Error(rockset.exceptions.RocksetException):
@classmethod
def map_rockset_exception(cls, exc):
err_body = loads(exc.body)
args = [
err_body["message"],
exc.status,
err_body["type"]
]
args = [err_body["message"], exc.status, err_body["type"]]
exc_type = type(exc)
if (
exc_type == rockset.exceptions.ApiTypeError or
exc_type == rockset.exceptions.ApiValueError or
exc_type == rockset.exceptions.ApiAttributeError or
exc_type == rockset.exceptions.ApiKeyError or
exc_type == rockset.exceptions.NotFoundException or
exc_type == rockset.exceptions.InputException or
exc_type == rockset.exceptions.InitializationException or
exc_type == rockset.exceptions.BadRequestException

exc_type == rockset.exceptions.ApiTypeError
or exc_type == rockset.exceptions.ApiValueError
or exc_type == rockset.exceptions.ApiAttributeError
or exc_type == rockset.exceptions.ApiKeyError
or exc_type == rockset.exceptions.NotFoundException
or exc_type == rockset.exceptions.InputException
or exc_type == rockset.exceptions.InitializationException
or exc_type == rockset.exceptions.BadRequestException
):
ret = ProgrammingError(*args)
elif (
exc_type == rockset.exceptions.UnauthorizedException or
exc_type == rockset.exceptions.ForbiddenException
exc_type == rockset.exceptions.UnauthorizedException
or exc_type == rockset.exceptions.ForbiddenException
):
ret = OperationalError(*args)
elif exc_type == rockset.exceptions.ServiceException:
Expand Down
26 changes: 16 additions & 10 deletions src/rockset_sqlalchemy/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ class RocksetDialect(default.DefaultDialect):

@classmethod
def dbapi(cls):
"""Retained for backward compatibility with SQLAlchemy 1.x.
"""
"""Retained for backward compatibility with SQLAlchemy 1.x."""
import rockset_sqlalchemy

return rockset_sqlalchemy

@classmethod
def import_dbapi(cls):
return RocksetDialect.dbapi()
Expand All @@ -57,20 +56,27 @@ def create_connect_args(self, url):
kwargs = {
"api_server": "https://{}".format(url.host),
"api_key": url.password or url.username,
"virtual_instance": url.database
"virtual_instance": url.database,
}
return ([], kwargs)

@reflection.cache
def get_schema_names(self, connection, **kw):
return [w["name"] for w in connection.connect().connection._client.Workspaces.list()["data"]]
return [
w["name"]
for w in connection.connect().connection._client.Workspaces.list()["data"]
]

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
tables = (connection.connect().connection._client.Collections.list()
if schema is None else
connection.connect().connection._client.Collections.workspace_collections(workspace=schema))['data']

tables = (
connection.connect().connection._client.Collections.list()
if schema is None
else connection.connect().connection._client.Collections.workspace_collections(
workspace=schema
)
)["data"]

return [w["name"] for w in tables]

def _get_table_columns(self, connection, table_name, schema):
Expand Down Expand Up @@ -132,7 +138,7 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
return []

def has_table(self, connection, table_name, schema=None):
try:
self._get_table_columns(connection, table_name, schema)
Expand Down

0 comments on commit 754d356

Please sign in to comment.