diff --git a/src/databricks/sqlalchemy/_parse.py b/src/databricks/sqlalchemy/_parse.py index 4bc7d65e..55f34f95 100644 --- a/src/databricks/sqlalchemy/_parse.py +++ b/src/databricks/sqlalchemy/_parse.py @@ -258,7 +258,7 @@ def get_pk_strings_from_dte_output( return output -# The keys of this dictionary are the values we expect to see in a +# The keys of this dictionary are the values we expect to see in a # TGetColumnsRequest's .TYPE_NAME attribute. # These are enumerated in ttypes.py as class TTypeId. # TODO: confirm that all types in TTypeId are included here. @@ -283,17 +283,40 @@ def get_pk_strings_from_dte_output( } +def parse_numeric_type_precision_and_scale(type_name_str): + """Return an intantiated sqlalchemy Numeric() type that preserves the precision and scale indicated + in the output from TGetColumnsRequest. + + type_name_str + The value of TGetColumnsReq.TYPE_NAME. + + If type_name_str is "DECIMAL(18,5) returns sqlalchemy.types.Numeric(18,5) + """ + + pattern = re.compile(r"DECIMAL\((\d+,\d+)\)") + match = re.search(pattern, type_name_str) + precision_and_scale = match.group(1) + precision, scale = tuple(precision_and_scale.split(",")) + + return sqlalchemy.types.Numeric(int(precision), int(scale)) + + def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn: """Returns a dictionary of the ReflectedColumn schema parsed from a single of the result of a TGetColumnsRequest thrift RPC - - Currently doesn't preserve decimal precision """ pat = re.compile(r"^\w+") _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower() _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type] + if _raw_col_type == "decimal": + final_col_type = parse_numeric_type_precision_and_scale( + thrift_resp_row.TYPE_NAME + ) + else: + final_col_type = _col_type + # See comments about autoincrement in test_suite.py # Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations # the autoincrement must be manually declared with an Identity() construct in SQLAlchemy @@ -306,7 +329,7 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu # key in this dictionary. this_column = { "name": thrift_resp_row.COLUMN_NAME, - "type": _col_type, + "type": final_col_type, "nullable": bool(thrift_resp_row.NULLABLE), "default": thrift_resp_row.COLUMN_DEF, }