From 7d1e2487742127afe58db6a33322496ede2d57d8 Mon Sep 17 00:00:00 2001 From: Aarav Borthakur Date: Tue, 8 Aug 2023 15:44:50 -0700 Subject: [PATCH] Fix dialect bugs --- src/rockset_sqlalchemy/sqlalchemy/dialect.py | 33 +++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/rockset_sqlalchemy/sqlalchemy/dialect.py b/src/rockset_sqlalchemy/sqlalchemy/dialect.py index db5ccf4..e6be99d 100644 --- a/src/rockset_sqlalchemy/sqlalchemy/dialect.py +++ b/src/rockset_sqlalchemy/sqlalchemy/dialect.py @@ -1,4 +1,6 @@ -from sqlalchemy import exc, types, util +from rockset.exceptions import NotFoundException + +from sqlalchemy import exc from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler @@ -63,25 +65,25 @@ def create_connect_args(self, url): @reflection.cache def get_schema_names(self, connection, **kw): - return [w["name"] for w in connection.connect().connection._client.Workspaces.list()["data"]] + return [w["name"] for w in connection._dbapi_connection.connection._client.Workspaces.list()["data"]] @reflection.cache def get_table_names(self, connection, schema=None, **kw): - tables = (connection.connect().connection._client.Collections.list() + tables = (connection._dbapi_connection.connection._client.Collections.list() if schema is None else - connection.connect().connection._client.Collections.workspace_collections(workspace=schema))['data'] + connection._dbapi_connection.connection._client.Collections.workspace_collections(workspace=schema))['data'] return [w["name"] for w in tables] def _get_table_columns(self, connection, table_name, schema): - schema = self.identifier_preparer.quote_identifier(schema) + schema = self.identifier_preparer.quote_identifier(schema or "commons") table_name = self.identifier_preparer.quote_identifier(table_name) # Get a single row and determine the schema from that. # This assumes the whole collection has a fixed schema of course. q = f"SELECT * FROM {schema}.{table_name} LIMIT 1" try: - cursor = connection.connect().connection.cursor() + cursor = connection._dbapi_connection.connection.cursor() cursor.execute(q) fields = cursor.description if not fields: @@ -105,6 +107,13 @@ def _get_table_columns(self, connection, table_name, schema): } ) except Exception as e: + try: + connection._dbapi_connection.connection._client.Collections.get( + collection=table_name, + workspace=schema + ) + except NotFoundException: + raise exc.NoSuchTableError(e) # TODO: more graceful handling of exceptions. raise e return columns @@ -116,24 +125,24 @@ def get_columns(self, connection, table_name, schema=None, **kw): return self._get_table_columns(connection, table_name, schema) @reflection.cache - def get_view_names(self, connection, schema=None, **kw): + def get_view_names(self, connection, schema="commons", **kw): # TODO: implement this. return [] @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + def get_foreign_keys(self, connection, table_name, schema="commons", **kw): # Rockset does not have foreign keys. return [] @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): + def get_pk_constraint(self, connection, table_name, schema="commons", **kw): return {"constrained_columns": ["_id"], "name": "_id_pk"} @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): + def get_indexes(self, connection, table_name, schema="commons", **kw): return [] - - def has_table(self, connection, table_name, schema=None): + + def has_table(self, connection, table_name, schema="commons"): try: self._get_table_columns(connection, table_name, schema) return True