diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 55f34f95..3d96e160 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -12,6 +12,10 @@ """ +class DatabricksSqlAlchemyParseException(Exception): + pass + + def _match_table_not_found_string(message: str) -> bool: """Return True if the message contains a substring indicating that a table was not found""" @@ -31,7 +35,7 @@ def _describe_table_extended_result_to_dict_list( """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries""" rows_to_return = [] - for row in result: + for row in result.all(): this_row = {"col_name": row.col_name, "data_type": row.data_type} rows_to_return.append(this_row) @@ -69,12 +73,16 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic "schema": "pysql_dialect_compliance", "table": "users" } + + Raise a DatabricksSqlAlchemyParseException if a 3L namespace isn't found """ pat = re.compile(r"REFERENCES\s+(.*?)\s*\(") matches = pat.findall(input_str) if not matches: - return None + raise DatabricksSqlAlchemyParseException( + "3L namespace not found in constraint string" + ) first_match = matches[0] parts = first_match.split(".") @@ -82,11 +90,16 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic def strip_backticks(input: str): return input.replace("`", "") - return { - "catalog": strip_backticks(parts[0]), - "schema": strip_backticks(parts[1]), - "table": strip_backticks(parts[2]), - } + try: + return { + "catalog": strip_backticks(parts[0]), + "schema": strip_backticks(parts[1]), + "table": strip_backticks(parts[2]), + } + except IndexError: + raise DatabricksSqlAlchemyParseException( + "Incomplete 3L namespace found in constraint string: " + ".".join(parts) + ) def _parse_fk_from_constraint_string(constraint_str: str) -> dict: @@ -170,10 +183,12 @@ def build_fk_dict( else: schema_override_dict = {} + # mypy doesn't like this method of conditionally adding a key to a dictionary + # while keeping everything immutable complete_foreign_key_dict = { "name": fk_name, **base_fk_dict, - **schema_override_dict, + **schema_override_dict, # type: ignore } return complete_foreign_key_dict @@ -234,7 +249,7 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis return output_rows -def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]: +def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]: """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries, one dictionary per defined constraint """ @@ -307,7 +322,11 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu """ pat = re.compile(r"^\w+") - _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() + + # This method assumes a valid TYPE_NAME field in the response. + # TODO: add error handling in case TGetColumnsResponse format changes + + _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() # type: ignore _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type] if _raw_col_type == "decimal": @@ -334,4 +353,5 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu "default": thrift_resp_row.COLUMN_DEF, } - return this_column + # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects + return this_column # type: ignore diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 937d0a9e..3684789e 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -1,5 +1,5 @@ import re -from typing import Any, List, Optional, Dict, Collection, Iterable, Tuple +from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple import databricks.sqlalchemy._ddl as dialect_ddl_impl import databricks.sqlalchemy._types as dialect_type_impl @@ -73,10 +73,12 @@ class DatabricksDialect(default.DefaultDialect): # SQLAlchemy requires that a table with no primary key # constraint return a dictionary that looks like this. - EMPTY_PK = {"constrained_columns": [], "name": None} + EMPTY_PK: Dict[str, Any] = {"constrained_columns": [], "name": None} # SQLAlchemy requires that a table with no foreign keys # defined return an empty list. Same for indexes. + EMPTY_FK: List + EMPTY_INDEX: List EMPTY_FK = EMPTY_INDEX = [] @classmethod @@ -139,7 +141,7 @@ def _describe_table_extended( catalog_name: Optional[str] = None, schema_name: Optional[str] = None, expect_result=True, - ) -> List[Dict[str, str]]: + ) -> Union[List[Dict[str, str]], None]: """Run DESCRIBE TABLE EXTENDED on a table and return a list of dictionaries of the result. This method is the fastest way to check for the presence of a table in a schema. @@ -158,7 +160,7 @@ def _describe_table_extended( stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}") try: - result = connection.execute(stmt).all() + result = connection.execute(stmt) except DatabaseError as e: if _match_table_not_found_string(str(e)): raise sqlalchemy.exc.NoSuchTableError( @@ -197,9 +199,11 @@ def get_pk_constraint( schema_name=schema, ) - raw_pk_constraints: List = get_pk_strings_from_dte_output(result) + # Type ignore is because mypy knows that self._describe_table_extended *can* + # return None (even though it never will since expect_result defaults to True) + raw_pk_constraints: List = get_pk_strings_from_dte_output(result) # type: ignore if not any(raw_pk_constraints): - return self.EMPTY_PK + return self.EMPTY_PK # type: ignore if len(raw_pk_constraints) > 1: logger.warning( @@ -212,11 +216,12 @@ def get_pk_constraint( pk_name = first_pk_constraint.get("col_name") pk_constraint_string = first_pk_constraint.get("data_type") - return build_pk_dict(pk_name, pk_constraint_string) + # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects + return build_pk_dict(pk_name, pk_constraint_string) # type: ignore def get_foreign_keys( self, connection, table_name, schema=None, **kw - ) -> ReflectedForeignKeyConstraint: + ) -> List[ReflectedForeignKeyConstraint]: """Return information about foreign_keys in `table_name`.""" result = self._describe_table_extended( @@ -225,7 +230,9 @@ def get_foreign_keys( schema_name=schema, ) - raw_fk_constraints: List = get_fk_strings_from_dte_output(result) + # Type ignore is because mypy knows that self._describe_table_extended *can* + # return None (even though it never will since expect_result defaults to True) + raw_fk_constraints: List = get_fk_strings_from_dte_output(result) # type: ignore if not any(raw_fk_constraints): return self.EMPTY_FK @@ -239,7 +246,8 @@ def get_foreign_keys( ) fk_constraints.append(this_constraint_dict) - return fk_constraints + # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects + return fk_constraints # type: ignore def get_indexes(self, connection, table_name, schema=None, **kw): """SQLAlchemy requires this method. Databricks doesn't support indexes.""" diff --git a/src/databricks/sqlalchemy/test/_future.py b/src/databricks/sqlalchemy/test/_future.py index 0da38634..519a4e09 100644 --- a/src/databricks/sqlalchemy/test/_future.py +++ b/src/databricks/sqlalchemy/test/_future.py @@ -1,3 +1,5 @@ +# type: ignore + from enum import Enum import pytest diff --git a/src/databricks/sqlalchemy/test/_regression.py b/src/databricks/sqlalchemy/test/_regression.py index aeeb5c3f..6342d2d5 100644 --- a/src/databricks/sqlalchemy/test/_regression.py +++ b/src/databricks/sqlalchemy/test/_regression.py @@ -1,3 +1,5 @@ +# type: ignore + import pytest from sqlalchemy.testing.suite import ( ArgSignatureTest, diff --git a/src/databricks/sqlalchemy/test/_unsupported.py b/src/databricks/sqlalchemy/test/_unsupported.py index 63932fe2..899e73e4 100644 --- a/src/databricks/sqlalchemy/test/_unsupported.py +++ b/src/databricks/sqlalchemy/test/_unsupported.py @@ -1,3 +1,5 @@ +# type: ignore + from enum import Enum import pytest diff --git a/src/databricks/sqlalchemy/test_local/test_parsing.py b/src/databricks/sqlalchemy/test_local/test_parsing.py index ab82613e..f17814f9 100644 --- a/src/databricks/sqlalchemy/test_local/test_parsing.py +++ b/src/databricks/sqlalchemy/test_local/test_parsing.py @@ -6,6 +6,7 @@ build_fk_dict, build_pk_dict, match_dte_rows_by_value, + DatabricksSqlAlchemyParseException, ) @@ -55,6 +56,13 @@ def test_extract_3l_namespace_from_constraint_string(): ), "Failed to extract 3L namespace from constraint string" +def test_extract_3l_namespace_from_bad_constraint_string(): + input = "FOREIGN KEY (`parent_user_id`) REFERENCES `pysql_dialect_compliance`.`users` (`user_id`)" + + with pytest.raises(DatabricksSqlAlchemyParseException): + extract_three_level_identifier_from_constraint_string(input) + + @pytest.mark.parametrize("schema", [None, "some_schema"]) def test_build_fk_dict(schema): fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)"