From fe7bf4dd04849e21b854eaf99cd31657e1abcf27 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:27:49 +0200 Subject: [PATCH] Added support for generic types in `SqlBackend` (#272) This PR adds the ability to use rich dataclasses like: ```python @dataclass class Foo: first: str second: bool | None @dataclass class Nested: foo: Foo mapping: dict[str, int] array: list[int] ``` --- src/databricks/labs/lsql/backends.py | 99 +++++++------- src/databricks/labs/lsql/structs.py | 192 +++++++++++++++++++++++++++ tests/integration/test_structs.py | 36 +++++ tests/unit/test_backends.py | 39 +++++- tests/unit/test_structs.py | 71 ++++++++++ 5 files changed, 383 insertions(+), 54 deletions(-) create mode 100644 src/databricks/labs/lsql/structs.py create mode 100644 tests/integration/test_structs.py create mode 100644 tests/unit/test_structs.py diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py index 6be86f7a..f8b9230b 100644 --- a/src/databricks/labs/lsql/backends.py +++ b/src/databricks/labs/lsql/backends.py @@ -1,10 +1,10 @@ import dataclasses +import datetime import logging import os import re from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence -from types import UnionType from typing import Any, ClassVar, Protocol, TypeVar from databricks.labs.blueprint.commands import CommandExecutor @@ -20,6 +20,7 @@ from databricks.sdk.service.compute import Language from databricks.labs.lsql.core import Row, StatementExecutionExt +from databricks.labs.lsql.structs import StructInference logger = logging.getLogger(__name__) @@ -42,6 +43,10 @@ class SqlBackend(ABC): execute SQL statements, fetch results from SQL statements, and save data to tables.""" + # singleton shared across all SQL backends, used to infer schema from dataclasses. + # no state is stored in this class, so it can be shared across all instances. + _STRUCTS = StructInference() + @abstractmethod def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None: raise NotImplementedError @@ -55,44 +60,9 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D raise NotImplementedError def create_table(self, full_name: str, klass: Dataclass): - ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA" + ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._STRUCTS.as_schema(klass)}) USING DELTA" self.execute(ddl) - _builtin_type_mapping: ClassVar[dict[type, str]] = { - str: "STRING", - int: "LONG", - bool: "BOOLEAN", - float: "FLOAT", - } - - @classmethod - def _schema_for(cls, klass: Dataclass): - fields = [] - for f in dataclasses.fields(klass): - field_type = cls._field_type(f) - if isinstance(field_type, UnionType): - field_type = field_type.__args__[0] - if field_type not in cls._builtin_type_mapping: - msg = f"Cannot auto-convert {field_type}" - raise SyntaxError(msg) - not_null = " NOT NULL" - if f.default is None: - not_null = "" - spark_type = cls._builtin_type_mapping[field_type] - fields.append(f"{f.name} {spark_type}{not_null}") - return ", ".join(fields) - - @classmethod - def _field_type(cls, field: dataclasses.Field): - # workaround rare (Python?) 
issue where f.type is the type name instead of the type itself - # this seems to happen when the dataclass is first used from a file importing it - if isinstance(field.type, str): - try: - return __builtins__[field.type] - except TypeError as e: - logger.warning(f"Could not load type {field.type}", exc_info=e) - return field.type - @classmethod def _filter_none_rows(cls, rows, klass): if len(rows) == 0: @@ -177,23 +147,46 @@ def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any data = [] for f in fields: value = getattr(row, f.name) - field_type = cls._field_type(f) - if isinstance(field_type, UnionType): - field_type = field_type.__args__[0] - if value is None: - data.append("NULL") - elif field_type is bool: - data.append("TRUE" if value else "FALSE") - elif field_type is str: - value = str(value).replace("'", "''") - data.append(f"'{value}'") - elif field_type is int: - data.append(f"{value}") - else: - msg = f"unknown type: {field_type}" - raise ValueError(msg) + data.append(cls._value_to_sql(value)) return ", ".join(data) + @classmethod + def _value_to_sql(cls, value: Any) -> str: + """Converts a Python value to a SQL string representation.""" + if value is None: + return "NULL" + if isinstance(value, bool): + return "TRUE" if value else "FALSE" + if isinstance(value, int): + return f"{value}" + if isinstance(value, float): + return f"{value}" + if isinstance(value, str): + value = str(value).replace("'", "''") + return f"'{value}'" + if isinstance(value, datetime.datetime): + return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'" + if isinstance(value, datetime.date): + return f"DATE '{value.year}-{value.month}-{value.day}'" + if isinstance(value, list): + values = ", ".join(cls._value_to_sql(v) for v in value) + return f"ARRAY({values})" + if isinstance(value, dict): + map_values: list[str] = [] + for k, v in value.items(): + map_values.append(cls._value_to_sql(k)) + map_values.append(cls._value_to_sql(v)) + return f"MAP({', '.join(map_values)})" + if dataclasses.is_dataclass(value): + struct = [] + for f in dataclasses.fields(value): + v = getattr(value, f.name) + sql_value = f"{cls._value_to_sql(v)} AS {f.name}" + struct.append(sql_value) + return f"STRUCT({', '.join(struct)})" + msg = f"unsupported: {value}" + raise ValueError(msg) + class StatementExecutionBackend(ExecutionBackend): def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs): @@ -273,7 +266,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D self.create_table(full_name, klass) return # pyspark deals well with lists of dataclass instances, as long as schema is provided - df = self._spark.createDataFrame(rows, self._schema_for(klass)) + df = self._spark.createDataFrame(rows, self._STRUCTS.as_schema(klass)) df.write.saveAsTable(full_name, mode=mode) diff --git a/src/databricks/labs/lsql/structs.py b/src/databricks/labs/lsql/structs.py new file mode 100644 index 00000000..b7f129b1 --- /dev/null +++ b/src/databricks/labs/lsql/structs.py @@ -0,0 +1,192 @@ +import dataclasses +import datetime +import enum +import types +from dataclasses import dataclass +from typing import ClassVar, Protocol, get_args, get_type_hints + + +class StructInferError(TypeError): + pass + + +class SqlType(Protocol): + """Represents a Spark SQL type.""" + + def as_sql(self) -> str: ... 
+ + +@dataclass +class NullableType(SqlType): + """Represents a nullable type.""" + + inner_type: SqlType + + def as_sql(self) -> str: + return self.inner_type.as_sql() + + +@dataclass +class ArrayType(SqlType): + """Represents an array type.""" + + element_type: SqlType + + def as_sql(self) -> str: + return f"ARRAY<{self.element_type.as_sql()}>" + + +@dataclass +class MapType(SqlType): + """Represents a map type.""" + + key_type: SqlType + value_type: SqlType + + def as_sql(self) -> str: + return f"MAP<{self.key_type.as_sql()},{self.value_type.as_sql()}>" + + +@dataclass +class PrimitiveType(SqlType): + """Represents a primitive type.""" + + name: str + + def as_sql(self) -> str: + return self.name + + +@dataclass +class StructField: + """Represents a field in a struct type.""" + + name: str + type: SqlType + + @property + def nullable(self) -> bool: + return isinstance(self.type, NullableType) + + def as_sql(self) -> str: + return f"{self.name}:{self.type.as_sql()}" + + +@dataclass +class StructType(SqlType): + """Represents a struct type.""" + + fields: list[StructField] + + def as_sql(self) -> str: + """Returns a DDL representation of the struct type.""" + fields = ",".join(f.as_sql() for f in self.fields) + return f"STRUCT<{fields}>" + + def as_schema(self) -> str: + """Returns a schema representation of the struct type.""" + fields = [] + for field in self.fields: + not_null = "" if field.nullable else " NOT NULL" + fields.append(f"{field.name} {field.type.as_sql()}{not_null}") + return ", ".join(fields) + + +class StructInference: + """Infers Spark SQL types from Python types.""" + + _PRIMITIVES: ClassVar[dict[type, str]] = { + str: "STRING", + int: "LONG", + bool: "BOOLEAN", + float: "FLOAT", + datetime.date: "DATE", + datetime.datetime: "TIMESTAMP", + } + + def as_ddl(self, type_ref: type) -> str: + """Returns a DDL representation of the type.""" + v = self._infer(type_ref, []) + return v.as_sql() + + def as_schema(self, type_ref: type) -> str: + """Returns a schema representation of the type.""" + v = self._infer(type_ref, []) + if hasattr(v, "as_schema"): + return v.as_schema() + raise StructInferError(f"Cannot generate schema for {type_ref}") + + def _infer(self, type_ref: type, path: list[str]) -> SqlType: + """Infers the SQL type from the Python type. Raises StructInferError if the type is not supported.""" + if dataclasses.is_dataclass(type_ref): + return self._infer_struct(type_ref, path) + if isinstance(type_ref, enum.EnumMeta): + return self._infer_primitive(str, path) + if type_ref in self._PRIMITIVES: + return self._infer_primitive(type_ref, path) + if type_ref is list: + raise StructInferError("Cannot determine element type of list. Rewrite as: list[XXX]") + if type_ref is set: + raise StructInferError("Cannot determine element type of set. Rewrite as: set[XXX]") + if type_ref is dict: + raise StructInferError("Cannot determine key and value types of dict. Rewrite as: dict[XXX, YYY]") + return self._infer_generic(type_ref, path) + + def _infer_primitive(self, type_ref: type, path: list[str]) -> PrimitiveType: + """Infers the primitive SQL type from the Python type. Raises StructInferError if the type is not supported.""" + if type_ref in self._PRIMITIVES: + return PrimitiveType(self._PRIMITIVES[type_ref]) + raise StructInferError(f'{".".join(path)}: unknown: {type_ref}') + + def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType: + """Infers the SQL type from the generic Python type. 
Uses internal APIs to handle generic types.""" + # pylint: disable-next=import-outside-toplevel + from typing import ( # type: ignore[attr-defined] + _GenericAlias, + _UnionGenericAlias, + ) + + if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)): # type: ignore[attr-defined] + return self._infer_nullable(type_ref, path) + if isinstance(type_ref, (types.GenericAlias, _GenericAlias)): # type: ignore[attr-defined] + if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias): + return self._infer_container(type_ref, path) + prefix = ".".join(path) + if prefix: + prefix = f"{prefix}: " + raise StructInferError(f"{prefix}unsupported type: {type_ref.__name__}") + + def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType: + """Infers nullability from Optional[x] or `x | None` types.""" + type_args = get_args(type_ref) + if len(type_args) > 2: + raise StructInferError(f'{".".join(path)}: union: too many variants: {type_args}') + first_type = self._infer(type_args[0], [*path, "(first)"]) + if type_args[1] is not type(None): + msg = f'{".".join(path)}.(second): not a NoneType: {type_args[1]}' + raise StructInferError(msg) + return NullableType(first_type) + + def _infer_container(self, type_ref: type, path: list[str]) -> SqlType: + """Infers the SQL type from the generic container Python type.""" + type_args = get_args(type_ref) + if not type_args: + raise StructInferError(f"Missing type arguments: {type_args} in {type_ref}") + if len(type_args) == 2: + key_type = self._infer(type_args[0], [*path, "key"]) + value_type = self._infer(type_args[1], [*path, "value"]) + return MapType(key_type, value_type) + # here we make a simple assumption that not two type arguments means a list + element_type = self._infer(type_args[0], path) + return ArrayType(element_type) + + def _infer_struct(self, type_ref: type, path: list[str]) -> StructType: + """Infers the struct type from the Python dataclass type.""" + fields = [] + for field, hint in get_type_hints(type_ref).items(): + origin = getattr(hint, "__origin__", None) + if origin is ClassVar: + continue + field_type = self._infer(hint, [*path, field]) + fields.append(StructField(field, field_type)) + return StructType(fields) diff --git a/tests/integration/test_structs.py b/tests/integration/test_structs.py new file mode 100644 index 00000000..076e3225 --- /dev/null +++ b/tests/integration/test_structs.py @@ -0,0 +1,36 @@ +import datetime +from dataclasses import dataclass + +from databricks.labs.lsql.backends import StatementExecutionBackend + + +@dataclass +class Foo: + first: str + second: bool | None + + +@dataclass +class Nested: + foo: Foo + since: datetime.date + created: datetime.datetime + mapping: dict[str, int] + array: list[int] + + +def test_appends_complex_types(ws, env_or_skip, make_random) -> None: + sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID")) + today = datetime.date.today() + now = datetime.datetime.now() + full_name = f"hive_metastore.default.t{make_random(4)}" + sql_backend.save_table( + full_name, + [ + Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3]), + Nested(Foo("b", False), today, now, {"c": 3, "d": 4}, [4, 5, 6]), + ], + Nested, + ) + rows = list(sql_backend.fetch(f"SELECT * FROM {full_name}")) + assert len(rows) == 2 diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index e521f6f6..2d95160b 100644 --- a/tests/unit/test_backends.py +++ b/tests/unit/test_backends.py @@ -1,3 +1,4 @@ +import datetime import os 
 import sys
 from dataclasses import dataclass
@@ -219,7 +220,7 @@ def test_statement_execution_backend_save_table_two_records():
     )


-def test_statement_execution_backend_save_table_in_batches_of_two(mocker):
+def test_statement_execution_backend_save_table_in_batches_of_two():
     ws = create_autospec(WorkspaceClient)

     ws.statement_execution.execute_statement.return_value = StatementResponse(
@@ -430,3 +431,39 @@ def test_mock_backend_overwrite():
         Row(first="aaa", second=True),
         Row(first="bbb", second=False),
     ]
+
+
+@dataclass
+class Nested:
+    foo: Foo
+    since: datetime.date
+    created: datetime.datetime
+    mapping: dict[str, int]
+    array: list[int]
+    some: float | None = None
+
+
+def test_supports_complex_types():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = StatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
+
+    today = datetime.date(2024, 9, 11)
+    now = datetime.datetime(2024, 9, 11, 12, 13, 14, tzinfo=datetime.timezone.utc)
+    seb.save_table(
+        "x",
+        [
+            Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3], 0.342532),
+        ],
+        Nested,
+    )
+
+    queries = [_.kwargs["statement"] for _ in ws.statement_execution.method_calls]
+    assert [
+        "CREATE TABLE IF NOT EXISTS x (foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, since DATE NOT NULL, created TIMESTAMP NOT NULL, mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL, some FLOAT) USING DELTA",
+        "INSERT INTO x (foo, since, created, mapping, array, some) VALUES (STRUCT('a' AS first, TRUE AS second), DATE '2024-9-11', TIMESTAMP '2024-09-11 12:13:14+0000', MAP('a', 1, 'b', 2), ARRAY(1, 2, 3), 0.342532)",
+    ] == queries
diff --git a/tests/unit/test_structs.py b/tests/unit/test_structs.py
new file mode 100644
index 00000000..10c90570
--- /dev/null
+++ b/tests/unit/test_structs.py
@@ -0,0 +1,71 @@
+import datetime
+from dataclasses import dataclass
+from typing import Optional
+
+import pytest
+
+from databricks.labs.lsql.structs import StructInference, StructInferError
+
+
+@dataclass
+class Foo:
+    first: str
+    second: bool | None
+
+
+@dataclass
+class Nested:
+    foo: Foo
+    mapping: dict[str, int]
+    array: list[int]
+
+
+class NotDataclass:
+    x: int
+
+
+@pytest.mark.parametrize(
+    "type_ref, ddl",
+    [
+        (int, "LONG"),
+        (int | None, "LONG"),
+        (Optional[int], "LONG"),
+        (float, "FLOAT"),
+        (str, "STRING"),
+        (bool, "BOOLEAN"),
+        (datetime.date, "DATE"),
+        (datetime.datetime, "TIMESTAMP"),
+        (list[str], "ARRAY<STRING>"),
+        (set[str], "ARRAY<STRING>"),
+        (dict[str, int], "MAP<STRING,LONG>"),
+        (dict[int, list[str]], "MAP<LONG,ARRAY<STRING>>"),
+        (Foo, "STRUCT<first:STRING,second:BOOLEAN>"),
+        (Nested, "STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>"),
+    ],
+)
+def test_struct_inference(type_ref, ddl) -> None:
+    inference = StructInference()
+    assert inference.as_ddl(type_ref) == ddl
+
+
+@pytest.mark.parametrize("type_ref", [type(None), list, set, tuple, dict, object, NotDataclass])
+def test_struct_inference_raises_on_unknown_type(type_ref) -> None:
+    inference = StructInference()
+    with pytest.raises(StructInferError):
+        inference.as_ddl(type_ref)
+
+
+@pytest.mark.parametrize(
+    "type_ref,ddl",
+    [
+        (Foo, "first STRING NOT NULL, second BOOLEAN"),
+        (
+            Nested,
+            "foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, "
+            "mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL",
+        ),
+    ],
+)
+def test_as_schema(type_ref, ddl) -> None:
+    inference = StructInference()
+    assert inference.as_schema(type_ref) == ddl
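
For reference, a minimal sketch of how the pieces in this patch fit together end to end. It reuses only names introduced above (`StructInference`, `StatementExecutionBackend`, `save_table`); the workspace client configuration, warehouse id, and table name are placeholders, not part of the change:

```python
import datetime
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient

from databricks.labs.lsql.backends import StatementExecutionBackend
from databricks.labs.lsql.structs import StructInference


@dataclass
class Foo:
    first: str
    second: bool | None


@dataclass
class Nested:
    foo: Foo
    since: datetime.date
    mapping: dict[str, int]
    array: list[int]


# Schema inference works standalone: dataclasses map to Spark SQL DDL strings.
inference = StructInference()
assert inference.as_ddl(Foo) == "STRUCT<first:STRING,second:BOOLEAN>"
assert inference.as_schema(Foo) == "first STRING NOT NULL, second BOOLEAN"

# save_table() uses the same inference for CREATE TABLE and serializes nested
# values (structs, maps, arrays, dates) into the generated INSERT statement.
backend = StatementExecutionBackend(WorkspaceClient(), "warehouse-id")  # placeholder warehouse id
backend.save_table(
    "hive_metastore.default.nested_example",  # placeholder table name
    [Nested(Foo("a", True), datetime.date(2024, 9, 11), {"a": 1}, [1, 2, 3])],
    Nested,
)
```

This mirrors what `test_supports_complex_types` and the integration test above assert against the generated SQL.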