Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cursor.py failing in SELECT * queries #16

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/rockset_sqlalchemy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from .cursor import Cursor
from .exceptions import ProgrammingError


class Connection(object):
def __init__(self, api_server, api_key, virtual_instance=None, debug_sql=False):
self._closed = False
self._client = RocksetClient(
host=api_server,
api_key=api_key
)
self._client = RocksetClient(host=api_server, api_key=api_key)
self.vi = virtual_instance
self.debug_sql = debug_sql
# Used for testing connectivity to Rockset.
Expand Down
34 changes: 18 additions & 16 deletions src/rockset_sqlalchemy/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,23 @@ def execute_query(client, query, vi=None, query_params={}):
query=query,
parameters=[
rockset.models.QueryParameter(
name=param, value=str(val), type=Cursor.__convert_to_rockset_type(val)
name=param,
value=str(val),
type=Cursor.__convert_to_rockset_type(val),
)
for param, val in query_params.items()
]
],
)
try:
return client.VirtualInstances.query_virtual_instance(
virtual_instance_id=vi,
sql=request
) if vi else client.Queries.query(sql=request)
return (
client.VirtualInstances.query_virtual_instance(
virtual_instance_id=vi, sql=request
)
if vi
else client.Queries.query(sql=request)
)
except rockset.exceptions.RocksetException as e:
raise Error.map_rockset_exception(e)
raise Error.map_rockset_exception(e)

def execute(self, sql, parameters=None):
self.__check_cursor_opened()
Expand All @@ -68,7 +73,7 @@ def execute(self, sql, parameters=None):
else:
new_params[k] = v
parameters = new_params

if self._connection.debug_sql:
print("+++++++++++++++++++++++++++++")
print(f"Query:\n{sql}")
Expand All @@ -83,10 +88,7 @@ def execute(self, sql, parameters=None):
)

self._response = Cursor.execute_query(
self._connection._client,
sql,
self._connection.vi,
query_params=parameters
self._connection._client, sql, self._connection.vi, query_params=parameters
)
self._response_iter = iter(self._response.results)

Expand All @@ -108,8 +110,9 @@ def fetchone(self):
return None

result = []

for field in self._response_to_column_fields(self._response.column_fields):

column_fields = getattr(self._response, "column_fields", None)
for field in self._response_to_column_fields(column_fields):
name = field["name"]
if name in next_doc:
result.append(next_doc[name])
Expand All @@ -126,7 +129,7 @@ def _response_to_column_fields(self, column_fields):

schema = rockset.Document()
if self._response.results and len(self._response.results) > 0:
# we only look at the first document because
# we only look at the first document because
# is sqlalchemy is typically used for relational
# tables with no sparse fields
schema.update(self._response.results[0])
Expand All @@ -152,7 +155,6 @@ def fetchmany(self, size=None):
break
docs.append(doc)
return docs


@property
def description(self):
Expand Down
28 changes: 12 additions & 16 deletions src/rockset_sqlalchemy/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import rockset
from json import loads


class Error(rockset.exceptions.RocksetException):
@classmethod
def map_rockset_exception(cls, exc):
err_body = loads(exc.body)
args = [
err_body["message"],
exc.status,
err_body["type"]
]
args = [err_body["message"], exc.status, err_body["type"]]
exc_type = type(exc)
if (
exc_type == rockset.exceptions.ApiTypeError or
exc_type == rockset.exceptions.ApiValueError or
exc_type == rockset.exceptions.ApiAttributeError or
exc_type == rockset.exceptions.ApiKeyError or
exc_type == rockset.exceptions.NotFoundException or
exc_type == rockset.exceptions.InputException or
exc_type == rockset.exceptions.InitializationException or
exc_type == rockset.exceptions.BadRequestException

exc_type == rockset.exceptions.ApiTypeError
or exc_type == rockset.exceptions.ApiValueError
or exc_type == rockset.exceptions.ApiAttributeError
or exc_type == rockset.exceptions.ApiKeyError
or exc_type == rockset.exceptions.NotFoundException
or exc_type == rockset.exceptions.InputException
or exc_type == rockset.exceptions.InitializationException
or exc_type == rockset.exceptions.BadRequestException
):
ret = ProgrammingError(*args)
elif (
exc_type == rockset.exceptions.UnauthorizedException or
exc_type == rockset.exceptions.ForbiddenException
exc_type == rockset.exceptions.UnauthorizedException
or exc_type == rockset.exceptions.ForbiddenException
):
ret = OperationalError(*args)
elif exc_type == rockset.exceptions.ServiceException:
Expand Down
26 changes: 16 additions & 10 deletions src/rockset_sqlalchemy/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ class RocksetDialect(default.DefaultDialect):

@classmethod
def dbapi(cls):
"""Retained for backward compatibility with SQLAlchemy 1.x.
"""
"""Retained for backward compatibility with SQLAlchemy 1.x."""
import rockset_sqlalchemy

return rockset_sqlalchemy

@classmethod
def import_dbapi(cls):
return RocksetDialect.dbapi()
Expand All @@ -57,20 +56,27 @@ def create_connect_args(self, url):
kwargs = {
"api_server": "https://{}".format(url.host),
"api_key": url.password or url.username,
"virtual_instance": url.database
"virtual_instance": url.database,
}
return ([], kwargs)

@reflection.cache
def get_schema_names(self, connection, **kw):
return [w["name"] for w in connection.connect().connection._client.Workspaces.list()["data"]]
return [
w["name"]
for w in connection.connect().connection._client.Workspaces.list()["data"]
]

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
tables = (connection.connect().connection._client.Collections.list()
if schema is None else
connection.connect().connection._client.Collections.workspace_collections(workspace=schema))['data']

tables = (
connection.connect().connection._client.Collections.list()
if schema is None
else connection.connect().connection._client.Collections.workspace_collections(
workspace=schema
)
)["data"]

return [w["name"] for w in tables]

def _get_table_columns(self, connection, table_name, schema):
Expand Down Expand Up @@ -132,7 +138,7 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
return []

def has_table(self, connection, table_name, schema=None):
try:
self._get_table_columns(connection, table_name, schema)
Expand Down