Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLAlchemy 2: Fix failing mypy checks from development #257

Merged
merged 14 commits into from
Oct 23, 2023
42 changes: 31 additions & 11 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
"""


class DatabricksSqlAlchemyParseException(Exception):
pass


def _match_table_not_found_string(message: str) -> bool:
"""Return True if the message contains a substring indicating that a table was not found"""

Expand All @@ -31,7 +35,7 @@ def _describe_table_extended_result_to_dict_list(
"""Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""

rows_to_return = []
for row in result:
for row in result.all():
this_row = {"col_name": row.col_name, "data_type": row.data_type}
rows_to_return.append(this_row)

Expand Down Expand Up @@ -69,24 +73,33 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic
"schema": "pysql_dialect_compliance",
"table": "users"
}

Raise a DatabricksSqlAlchemyParseException if a 3L namespace isn't found
"""
pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
matches = pat.findall(input_str)

if not matches:
return None
raise DatabricksSqlAlchemyParseException(
"3L namespace not found in constraint string"
)

first_match = matches[0]
parts = first_match.split(".")

def strip_backticks(input: str):
return input.replace("`", "")

return {
"catalog": strip_backticks(parts[0]),
"schema": strip_backticks(parts[1]),
"table": strip_backticks(parts[2]),
}
try:
return {
"catalog": strip_backticks(parts[0]),
"schema": strip_backticks(parts[1]),
"table": strip_backticks(parts[2]),
}
except IndexError:
raise DatabricksSqlAlchemyParseException(
"Incomplete 3L namespace found in constraint string: " + ".".join(parts)
)


def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
Expand Down Expand Up @@ -170,10 +183,12 @@ def build_fk_dict(
else:
schema_override_dict = {}

# mypy doesn't like this method of conditionally adding a key to a dictionary
# while keeping everything immutable
complete_foreign_key_dict = {
"name": fk_name,
**base_fk_dict,
**schema_override_dict,
**schema_override_dict, # type: ignore
}

return complete_foreign_key_dict
Expand Down Expand Up @@ -234,7 +249,7 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
return output_rows


def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
one dictionary per defined constraint
"""
Expand Down Expand Up @@ -307,7 +322,11 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
"""

pat = re.compile(r"^\w+")
_raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower()

# This method assumes a valid TYPE_NAME field in the response.
# TODO: add error handling in case TGetColumnsResponse format changes

_raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() # type: ignore
_col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type]

if _raw_col_type == "decimal":
Expand All @@ -334,4 +353,5 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
"default": thrift_resp_row.COLUMN_DEF,
}

return this_column
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
return this_column # type: ignore
28 changes: 18 additions & 10 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, List, Optional, Dict, Collection, Iterable, Tuple
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
Expand Down Expand Up @@ -73,10 +73,12 @@ class DatabricksDialect(default.DefaultDialect):

# SQLAlchemy requires that a table with no primary key
# constraint return a dictionary that looks like this.
EMPTY_PK = {"constrained_columns": [], "name": None}
EMPTY_PK: Dict[str, Any] = {"constrained_columns": [], "name": None}

# SQLAlchemy requires that a table with no foreign keys
# defined return an empty list. Same for indexes.
EMPTY_FK: List
EMPTY_INDEX: List
EMPTY_FK = EMPTY_INDEX = []

@classmethod
Expand Down Expand Up @@ -139,7 +141,7 @@ def _describe_table_extended(
catalog_name: Optional[str] = None,
schema_name: Optional[str] = None,
expect_result=True,
) -> List[Dict[str, str]]:
) -> Union[List[Dict[str, str]], None]:
"""Run DESCRIBE TABLE EXTENDED on a table and return a list of dictionaries of the result.

This method is the fastest way to check for the presence of a table in a schema.
Expand All @@ -158,7 +160,7 @@ def _describe_table_extended(
stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}")

try:
result = connection.execute(stmt).all()
result = connection.execute(stmt)
except DatabaseError as e:
if _match_table_not_found_string(str(e)):
raise sqlalchemy.exc.NoSuchTableError(
Expand Down Expand Up @@ -197,9 +199,11 @@ def get_pk_constraint(
schema_name=schema,
)

raw_pk_constraints: List = get_pk_strings_from_dte_output(result)
# Type ignore is because mypy knows that self._describe_table_extended *can*
# return None (even though it never will since expect_result defaults to True)
raw_pk_constraints: List = get_pk_strings_from_dte_output(result) # type: ignore
if not any(raw_pk_constraints):
return self.EMPTY_PK
return self.EMPTY_PK # type: ignore

if len(raw_pk_constraints) > 1:
logger.warning(
Expand All @@ -212,11 +216,12 @@ def get_pk_constraint(
pk_name = first_pk_constraint.get("col_name")
pk_constraint_string = first_pk_constraint.get("data_type")

return build_pk_dict(pk_name, pk_constraint_string)
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
return build_pk_dict(pk_name, pk_constraint_string) # type: ignore

def get_foreign_keys(
self, connection, table_name, schema=None, **kw
) -> ReflectedForeignKeyConstraint:
) -> List[ReflectedForeignKeyConstraint]:
"""Return information about foreign_keys in `table_name`."""

result = self._describe_table_extended(
Expand All @@ -225,7 +230,9 @@ def get_foreign_keys(
schema_name=schema,
)

raw_fk_constraints: List = get_fk_strings_from_dte_output(result)
# Type ignore is because mypy knows that self._describe_table_extended *can*
# return None (even though it never will since expect_result defaults to True)
raw_fk_constraints: List = get_fk_strings_from_dte_output(result) # type: ignore

if not any(raw_fk_constraints):
return self.EMPTY_FK
Expand All @@ -239,7 +246,8 @@ def get_foreign_keys(
)
fk_constraints.append(this_constraint_dict)

return fk_constraints
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
return fk_constraints # type: ignore

def get_indexes(self, connection, table_name, schema=None, **kw):
"""SQLAlchemy requires this method. Databricks doesn't support indexes."""
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sqlalchemy/test/_future.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type: ignore

from enum import Enum

import pytest
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sqlalchemy/test/_regression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type: ignore

import pytest
from sqlalchemy.testing.suite import (
ArgSignatureTest,
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sqlalchemy/test/_unsupported.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type: ignore

from enum import Enum

import pytest
Expand Down
8 changes: 8 additions & 0 deletions src/databricks/sqlalchemy/test_local/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
build_fk_dict,
build_pk_dict,
match_dte_rows_by_value,
DatabricksSqlAlchemyParseException,
)


Expand Down Expand Up @@ -55,6 +56,13 @@ def test_extract_3l_namespace_from_constraint_string():
), "Failed to extract 3L namespace from constraint string"


def test_extract_3l_namespace_from_bad_constraint_string():
input = "FOREIGN KEY (`parent_user_id`) REFERENCES `pysql_dialect_compliance`.`users` (`user_id`)"

with pytest.raises(DatabricksSqlAlchemyParseException):
extract_three_level_identifier_from_constraint_string(input)


@pytest.mark.parametrize("schema", [None, "some_schema"])
def test_build_fk_dict(schema):
fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)"
Expand Down