Added support for generic types in SqlBackend (#272)
This PR adds the ability to use rich dataclasses like:

```python
from dataclasses import dataclass


@dataclass
class Foo:
    first: str
    second: bool | None


@dataclass
class Nested:
    foo: Foo
    mapping: dict[str, int]
    array: list[int]
```
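
Schema inference for these classes is handled by the new `StructInference` helper (added in `src/databricks/labs/lsql/structs.py` below). A rough sketch of what it produces for `Nested`; the printed strings are approximations derived from the inference code, not verbatim output:

```python
from databricks.labs.lsql.structs import StructInference

inference = StructInference()

# nested dataclasses become STRUCT<...>, dict becomes MAP<...>, list becomes ARRAY<...>
print(inference.as_ddl(Nested))
# roughly: STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>

# SqlBackend.create_table() uses the schema form to build
# CREATE TABLE IF NOT EXISTS ... (<schema>) USING DELTA
print(inference.as_schema(Nested))
# roughly: foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL
```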
nfx authored Sep 11, 2024
1 parent 5dfaaa0 commit fe7bf4d
Showing 5 changed files with 383 additions and 54 deletions.
99 changes: 46 additions & 53 deletions src/databricks/labs/lsql/backends.py
@@ -1,10 +1,10 @@
import dataclasses
import datetime
import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from types import UnionType
from typing import Any, ClassVar, Protocol, TypeVar

from databricks.labs.blueprint.commands import CommandExecutor
@@ -20,6 +20,7 @@
from databricks.sdk.service.compute import Language

from databricks.labs.lsql.core import Row, StatementExecutionExt
from databricks.labs.lsql.structs import StructInference

logger = logging.getLogger(__name__)

@@ -42,6 +43,10 @@ class SqlBackend(ABC):
execute SQL statements, fetch results from SQL statements, and save data
to tables."""

# singleton shared across all SQL backends, used to infer schema from dataclasses.
# no state is stored in this class, so it can be shared across all instances.
_STRUCTS = StructInference()

@abstractmethod
def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
raise NotImplementedError
@@ -55,44 +60,9 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
raise NotImplementedError

def create_table(self, full_name: str, klass: Dataclass):
ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA"
ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._STRUCTS.as_schema(klass)}) USING DELTA"
self.execute(ddl)

_builtin_type_mapping: ClassVar[dict[type, str]] = {
str: "STRING",
int: "LONG",
bool: "BOOLEAN",
float: "FLOAT",
}

@classmethod
def _schema_for(cls, klass: Dataclass):
fields = []
for f in dataclasses.fields(klass):
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if field_type not in cls._builtin_type_mapping:
msg = f"Cannot auto-convert {field_type}"
raise SyntaxError(msg)
not_null = " NOT NULL"
if f.default is None:
not_null = ""
spark_type = cls._builtin_type_mapping[field_type]
fields.append(f"{f.name} {spark_type}{not_null}")
return ", ".join(fields)

@classmethod
def _field_type(cls, field: dataclasses.Field):
# workaround rare (Python?) issue where f.type is the type name instead of the type itself
# this seems to happen when the dataclass is first used from a file importing it
if isinstance(field.type, str):
try:
return __builtins__[field.type]
except TypeError as e:
logger.warning(f"Could not load type {field.type}", exc_info=e)
return field.type

@classmethod
def _filter_none_rows(cls, rows, klass):
if len(rows) == 0:
@@ -177,23 +147,46 @@ def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any
data = []
for f in fields:
value = getattr(row, f.name)
field_type = cls._field_type(f)
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if value is None:
data.append("NULL")
elif field_type is bool:
data.append("TRUE" if value else "FALSE")
elif field_type is str:
value = str(value).replace("'", "''")
data.append(f"'{value}'")
elif field_type is int:
data.append(f"{value}")
else:
msg = f"unknown type: {field_type}"
raise ValueError(msg)
data.append(cls._value_to_sql(value))
return ", ".join(data)

@classmethod
def _value_to_sql(cls, value: Any) -> str:
"""Converts a Python value to a SQL string representation."""
if value is None:
return "NULL"
if isinstance(value, bool):
return "TRUE" if value else "FALSE"
if isinstance(value, int):
return f"{value}"
if isinstance(value, float):
return f"{value}"
if isinstance(value, str):
value = str(value).replace("'", "''")
return f"'{value}'"
if isinstance(value, datetime.datetime):
return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'"
if isinstance(value, datetime.date):
return f"DATE '{value.year}-{value.month}-{value.day}'"
if isinstance(value, list):
values = ", ".join(cls._value_to_sql(v) for v in value)
return f"ARRAY({values})"
if isinstance(value, dict):
map_values: list[str] = []
for k, v in value.items():
map_values.append(cls._value_to_sql(k))
map_values.append(cls._value_to_sql(v))
return f"MAP({', '.join(map_values)})"
if dataclasses.is_dataclass(value):
struct = []
for f in dataclasses.fields(value):
v = getattr(value, f.name)
sql_value = f"{cls._value_to_sql(v)} AS {f.name}"
struct.append(sql_value)
return f"STRUCT({', '.join(struct)})"
msg = f"unsupported: {value}"
raise ValueError(msg)


class StatementExecutionBackend(ExecutionBackend):
def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
@@ -273,7 +266,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
self.create_table(full_name, klass)
return
# pyspark deals well with lists of dataclass instances, as long as schema is provided
df = self._spark.createDataFrame(rows, self._schema_for(klass))
df = self._spark.createDataFrame(rows, self._STRUCTS.as_schema(klass))
df.write.saveAsTable(full_name, mode=mode)


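
The hand-rolled per-type branches in `_row_to_sql` are replaced by a recursive `_value_to_sql` classmethod that also understands dates, timestamps, lists, dicts and nested dataclasses. A rough illustration of the literals it produces, reusing the `Foo` dataclass from the description above; the commented outputs are approximations derived from the code, not captured output:

```python
import datetime

from databricks.labs.lsql.backends import StatementExecutionBackend

# _value_to_sql is defined on ExecutionBackend; calling it through the concrete
# StatementExecutionBackend subclass works because it is a classmethod.
render = StatementExecutionBackend._value_to_sql

print(render(Foo("a", None)))              # STRUCT('a' AS first, NULL AS second)
print(render({"a": 1, "b": 2}))            # MAP('a', 1, 'b', 2)
print(render([1, 2, 3]))                   # ARRAY(1, 2, 3)
print(render(datetime.date(2024, 9, 11)))  # DATE '2024-9-11'
print(render("it's"))                      # 'it''s' -- single quotes are escaped
```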
192 changes: 192 additions & 0 deletions src/databricks/labs/lsql/structs.py
@@ -0,0 +1,192 @@
import dataclasses
import datetime
import enum
import types
from dataclasses import dataclass
from typing import ClassVar, Protocol, get_args, get_type_hints


class StructInferError(TypeError):
pass


class SqlType(Protocol):
"""Represents a Spark SQL type."""

def as_sql(self) -> str: ...


@dataclass
class NullableType(SqlType):
"""Represents a nullable type."""

inner_type: SqlType

def as_sql(self) -> str:
return self.inner_type.as_sql()


@dataclass
class ArrayType(SqlType):
"""Represents an array type."""

element_type: SqlType

def as_sql(self) -> str:
return f"ARRAY<{self.element_type.as_sql()}>"


@dataclass
class MapType(SqlType):
"""Represents a map type."""

key_type: SqlType
value_type: SqlType

def as_sql(self) -> str:
return f"MAP<{self.key_type.as_sql()},{self.value_type.as_sql()}>"


@dataclass
class PrimitiveType(SqlType):
"""Represents a primitive type."""

name: str

def as_sql(self) -> str:
return self.name


@dataclass
class StructField:
"""Represents a field in a struct type."""

name: str
type: SqlType

@property
def nullable(self) -> bool:
return isinstance(self.type, NullableType)

def as_sql(self) -> str:
return f"{self.name}:{self.type.as_sql()}"


@dataclass
class StructType(SqlType):
"""Represents a struct type."""

fields: list[StructField]

def as_sql(self) -> str:
"""Returns a DDL representation of the struct type."""
fields = ",".join(f.as_sql() for f in self.fields)
return f"STRUCT<{fields}>"

def as_schema(self) -> str:
"""Returns a schema representation of the struct type."""
fields = []
for field in self.fields:
not_null = "" if field.nullable else " NOT NULL"
fields.append(f"{field.name} {field.type.as_sql()}{not_null}")
return ", ".join(fields)


class StructInference:
"""Infers Spark SQL types from Python types."""

_PRIMITIVES: ClassVar[dict[type, str]] = {
str: "STRING",
int: "LONG",
bool: "BOOLEAN",
float: "FLOAT",
datetime.date: "DATE",
datetime.datetime: "TIMESTAMP",
}

def as_ddl(self, type_ref: type) -> str:
"""Returns a DDL representation of the type."""
v = self._infer(type_ref, [])
return v.as_sql()

def as_schema(self, type_ref: type) -> str:
"""Returns a schema representation of the type."""
v = self._infer(type_ref, [])
if hasattr(v, "as_schema"):
return v.as_schema()
raise StructInferError(f"Cannot generate schema for {type_ref}")

def _infer(self, type_ref: type, path: list[str]) -> SqlType:
"""Infers the SQL type from the Python type. Raises StructInferError if the type is not supported."""
if dataclasses.is_dataclass(type_ref):
return self._infer_struct(type_ref, path)
if isinstance(type_ref, enum.EnumMeta):
return self._infer_primitive(str, path)
if type_ref in self._PRIMITIVES:
return self._infer_primitive(type_ref, path)
if type_ref is list:
raise StructInferError("Cannot determine element type of list. Rewrite as: list[XXX]")
if type_ref is set:
raise StructInferError("Cannot determine element type of set. Rewrite as: set[XXX]")
if type_ref is dict:
raise StructInferError("Cannot determine key and value types of dict. Rewrite as: dict[XXX, YYY]")
return self._infer_generic(type_ref, path)

def _infer_primitive(self, type_ref: type, path: list[str]) -> PrimitiveType:
"""Infers the primitive SQL type from the Python type. Raises StructInferError if the type is not supported."""
if type_ref in self._PRIMITIVES:
return PrimitiveType(self._PRIMITIVES[type_ref])
raise StructInferError(f'{".".join(path)}: unknown: {type_ref}')

def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
"""Infers the SQL type from the generic Python type. Uses internal APIs to handle generic types."""
# pylint: disable-next=import-outside-toplevel
from typing import ( # type: ignore[attr-defined]
_GenericAlias,
_UnionGenericAlias,
)

if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)): # type: ignore[attr-defined]
return self._infer_nullable(type_ref, path)
if isinstance(type_ref, (types.GenericAlias, _GenericAlias)): # type: ignore[attr-defined]
if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias):
return self._infer_container(type_ref, path)
prefix = ".".join(path)
if prefix:
prefix = f"{prefix}: "
raise StructInferError(f"{prefix}unsupported type: {type_ref.__name__}")

def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
"""Infers nullability from Optional[x] or `x | None` types."""
type_args = get_args(type_ref)
if len(type_args) > 2:
raise StructInferError(f'{".".join(path)}: union: too many variants: {type_args}')
first_type = self._infer(type_args[0], [*path, "(first)"])
if type_args[1] is not type(None):
msg = f'{".".join(path)}.(second): not a NoneType: {type_args[1]}'
raise StructInferError(msg)
return NullableType(first_type)

def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
"""Infers the SQL type from the generic container Python type."""
type_args = get_args(type_ref)
if not type_args:
raise StructInferError(f"Missing type arguments: {type_args} in {type_ref}")
if len(type_args) == 2:
key_type = self._infer(type_args[0], [*path, "key"])
value_type = self._infer(type_args[1], [*path, "value"])
return MapType(key_type, value_type)
        # simple assumption: anything other than two type arguments means a list
element_type = self._infer(type_args[0], path)
return ArrayType(element_type)

def _infer_struct(self, type_ref: type, path: list[str]) -> StructType:
"""Infers the struct type from the Python dataclass type."""
fields = []
for field, hint in get_type_hints(type_ref).items():
origin = getattr(hint, "__origin__", None)
if origin is ClassVar:
continue
field_type = self._infer(hint, [*path, field])
fields.append(StructField(field, field_type))
return StructType(fields)
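
The new module's single entry point is `StructInference`: `as_ddl` and `as_schema` cover nested dataclasses, optionals, enums, lists, dicts, dates and timestamps, while unparameterised containers are rejected with a `StructInferError`. A small sketch of the enum handling and the error path, based on the code above; the exact messages may differ slightly:

```python
import enum
from dataclasses import dataclass

from databricks.labs.lsql.structs import StructInference, StructInferError


class Color(enum.Enum):
    RED = "red"
    GREEN = "green"


@dataclass
class Record:
    name: str
    color: Color      # enums are stored as STRING
    tags: list[str]   # becomes ARRAY<STRING>


@dataclass
class Broken:
    payload: dict     # bare dict: key and value types are unknown


inference = StructInference()
print(inference.as_ddl(Record))
# roughly: STRUCT<name:STRING,color:STRING,tags:ARRAY<STRING>>

try:
    inference.as_ddl(Broken)
except StructInferError as err:
    print(err)  # asks you to rewrite the field as dict[XXX, YYY]
```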
36 changes: 36 additions & 0 deletions tests/integration/test_structs.py
@@ -0,0 +1,36 @@
import datetime
from dataclasses import dataclass

from databricks.labs.lsql.backends import StatementExecutionBackend


@dataclass
class Foo:
first: str
second: bool | None


@dataclass
class Nested:
foo: Foo
since: datetime.date
created: datetime.datetime
mapping: dict[str, int]
array: list[int]


def test_appends_complex_types(ws, env_or_skip, make_random) -> None:
sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
today = datetime.date.today()
now = datetime.datetime.now()
full_name = f"hive_metastore.default.t{make_random(4)}"
sql_backend.save_table(
full_name,
[
Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3]),
Nested(Foo("b", False), today, now, {"c": 3, "d": 4}, [4, 5, 6]),
],
Nested,
)
rows = list(sql_backend.fetch(f"SELECT * FROM {full_name}"))
assert len(rows) == 2