Skip to content

Commit

Permalink
refactor(sql): make compilers usable with a base install
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 4, 2024
1 parent e213d02 commit 9e9ec9a
Show file tree
Hide file tree
Showing 44 changed files with 434 additions and 273 deletions.
2 changes: 1 addition & 1 deletion ibis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def connect(*args, **kwargs):
proxy.has_operation = backend.has_operation
proxy.name = name
proxy._from_url = backend._from_url
proxy._to_sqlglot = backend._to_sqlglot

# Add any additional methods that should be exposed at the top level
for attr in getattr(backend, "_top_level_methods", ()):
setattr(proxy, attr, getattr(backend, attr))
Expand Down
149 changes: 12 additions & 137 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydata_google_auth import cache

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.schema as sch
Expand All @@ -32,9 +33,7 @@
schema_from_bigquery_table,
)
from ibis.backends.bigquery.datatypes import BigQuerySchema
from ibis.backends.bigquery.udf.core import PythonToJavaScriptTranslator
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import BigQueryCompiler
from ibis.backends.sql.datatypes import BigQueryType

if TYPE_CHECKING:
Expand Down Expand Up @@ -150,7 +149,7 @@ def _force_quote_table(table: sge.Table) -> sge.Table:

class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema):
name = "bigquery"
compiler = BigQueryCompiler()
compiler = sc.bigquery.compiler
supports_in_memory_tables = True
supports_python_udfs = False

Expand Down Expand Up @@ -652,68 +651,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
)
return BigQuerySchema.to_ibis(job.schema)

def _to_sqlglot(
self,
expr: ir.Expr,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
**kwargs,
) -> Any:
"""Compile an Ibis expression.
Parameters
----------
expr
Ibis expression
limit
For expressions yielding result sets; retrieve at most this number
of values/rows. Overrides any limit already set on the expression.
params
Named unbound parameters
kwargs
Keyword arguments passed to the compiler
Returns
-------
Any
The output of compilation. The type of this value depends on the
backend.
"""
self._define_udf_translation_rules(expr)
sql = super()._to_sqlglot(expr, limit=limit, params=params, **kwargs)

table_expr = expr.as_table()
geocols = [
name for name, typ in table_expr.schema().items() if typ.is_geospatial()
]

query = sql.transform(
_qualify_memtable,
dataset=getattr(self._session_dataset, "dataset_id", None),
project=getattr(self._session_dataset, "project", None),
).transform(_remove_null_ordering_from_unsupported_window)

if not geocols:
return query

# if there are any geospatial columns, we have to convert them to WKB,
# so interactive mode knows how to display them
#
# by default bigquery returns data to python as WKT, and there's really
# no point in supporting both if we don't need to.
compiler = self.compiler
quoted = compiler.quoted
f = compiler.f
return sg.select(
sge.Star(
replace=[
f.st_asbinary(sg.column(col, quoted=quoted)).as_(col, quoted=quoted)
for col in geocols
]
)
).from_(query.subquery())

def raw_sql(self, query: str, params=None, page_size: int | None = None):
query_parameters = [
bigquery_param(
Expand Down Expand Up @@ -750,16 +687,16 @@ def compile(
self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
udf_sources = []
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
udf_sources.append(sql.sql(self.name, pretty=True))

sql = ";\n".join([*udf_sources, query.sql(dialect=self.name, pretty=True)])
# Hand the session dataset/project to the compiler alongside the expression.
# The dataset name is exposed as `dataset_id`; looking up "dataset" would
# always hit the getattr default and silently pass None to the compiler.
# NOTE(review): assumes `_session_dataset` is a google-cloud-bigquery
# dataset reference (or None before a session exists) -- confirm.
query = self.compiler.to_sqlglot(
    expr,
    limit=limit,
    params=params,
    session_dataset_id=getattr(self._session_dataset, "dataset_id", None),
    session_project_id=getattr(self._session_dataset, "project", None),
    **kwargs,
)
queries = util.promote_list(query)
sql = ";\n".join(query.sql(self.dialect) for query in queries)
self._log(sql)
return sql

Expand Down Expand Up @@ -1202,68 +1139,6 @@ def _clean_up_cached_table(self, name):
force=True,
)

def _get_udf_source(self, udf_node: ops.ScalarUDF):
name = type(udf_node).__name__
type_mapper = self.compiler.udf_type_mapper

body = PythonToJavaScriptTranslator(udf_node.__func__).compile()
config = udf_node.__config__
libraries = config.get("libraries", [])

signature = [
sge.ColumnDef(
this=sg.to_identifier(name, quoted=self.compiler.quoted),
kind=type_mapper.from_ibis(param.annotation.pattern.dtype),
)
for name, param in udf_node.__signature__.parameters.items()
]

lines = ['"""']

if config.get("strict", True):
lines.append('"use strict";')

lines += [
body,
"",
f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});",
'"""',
]

func = sge.Create(
kind="FUNCTION",
this=sge.UserDefinedFunction(
this=sg.to_identifier(name), expressions=signature, wrapped=True
),
# not exactly what I had in mind, but it works
#
# quoting is too simplistic to handle multiline strings
expression=sge.Var(this="\n".join(lines)),
exists=False,
properties=sge.Properties(
expressions=[
sge.TemporaryProperty(),
sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)),
sge.StabilityProperty(
this="IMMUTABLE" if config.get("determinism") else "VOLATILE"
),
sge.LanguageProperty(this=sg.to_identifier("js")),
]
+ [
sge.Property(
this=sg.to_identifier("library"),
value=self.compiler.f.array(*libraries),
)
]
* bool(libraries)
),
)

return func

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None:
return self._get_udf_source(udf_node)

def _register_udfs(self, expr: ir.Expr) -> None:
"""No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query."""

Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from clickhouse_connect.driver.external import ExternalData

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.config
import ibis.expr.operations as ops
Expand All @@ -26,7 +27,6 @@
from ibis.backends import BaseBackend, CanCreateDatabase
from ibis.backends.clickhouse.converter import ClickHousePandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import ClickHouseCompiler
from ibis.backends.sql.compilers.base import C

if TYPE_CHECKING:
Expand All @@ -44,7 +44,7 @@ def _to_memtable(v):

class Backend(SQLBackend, CanCreateDatabase):
name = "clickhouse"
compiler = ClickHouseCompiler()
compiler = sc.clickhouse.compiler

# ClickHouse itself does, but the client driver does not
supports_temporary_tables = False
Expand Down Expand Up @@ -732,7 +732,7 @@ def create_table(
expression = None

if obj is not None:
expression = self._to_sqlglot(obj)
expression = self.compiler.to_sqlglot(obj)
external_tables.update(self._collect_in_memory_tables(obj))

code = sge.Create(
Expand All @@ -759,7 +759,7 @@ def create_view(
database: str | None = None,
overwrite: bool = False,
) -> ir.Table:
expression = self._to_sqlglot(obj)
expression = self.compiler.to_sqlglot(obj)
src = sge.Create(
this=sg.table(name, db=database),
kind="VIEW",
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sqlglot.expressions as sge

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
Expand All @@ -23,7 +24,6 @@
from ibis import util
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import DataFusionCompiler
from ibis.backends.sql.compilers.base import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
Expand Down Expand Up @@ -68,7 +68,7 @@ class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema,
name = "datafusion"
supports_in_memory_tables = True
supports_arrays = True
compiler = DataFusionCompiler()
compiler = sc.datafusion.compiler

@property
def version(self):
Expand Down Expand Up @@ -629,16 +629,17 @@ def create_table(
# If it's a memtable, it will get registered in the pre-execute hooks
self._run_pre_execute_hooks(table)

compiler = self.compiler
relname = "_"
query = sg.select(
*(
self.compiler.cast(
compiler.cast(
sg.column(col, table=relname, quoted=quoted), dtype
).as_(col, quoted=quoted)
for col, dtype in table.schema().items()
)
).from_(
self._to_sqlglot(table).subquery(
compiler.to_sqlglot(table).subquery(
sg.to_identifier(relname, quoted=quoted)
)
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import pydruid.db
import sqlglot as sg

import ibis.backends.sql.compilers as sc
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import DruidCompiler
from ibis.backends.sql.compilers.base import STAR
from ibis.backends.sql.datatypes import DruidType

Expand All @@ -31,7 +31,7 @@

class Backend(SQLBackend):
name = "druid"
compiler = DruidCompiler()
compiler = sc.druid.compiler
supports_create_or_replace = False
supports_in_memory_tables = True

Expand Down
34 changes: 3 additions & 31 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sqlglot.expressions as sge

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as exc
import ibis.expr.operations as ops
import ibis.expr.schema as sch
Expand All @@ -26,7 +27,6 @@
from ibis.backends import CanCreateDatabase, CanCreateSchema, UrlFromPath
from ibis.backends.duckdb.converter import DuckDBPandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import DuckDBCompiler
from ibis.backends.sql.compilers.base import STAR, C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
Expand Down Expand Up @@ -68,7 +68,7 @@ def __repr__(self):

class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema, UrlFromPath):
name = "duckdb"
compiler = DuckDBCompiler()
compiler = sc.duckdb.compiler

def _define_udf_translation_rules(self, expr):
"""No-op: UDF translation rules are defined in the compiler."""
Expand All @@ -95,34 +95,6 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
query = query.sql(dialect=self.name)
return self.con.execute(query, **kwargs)

def _to_sqlglot(
self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
):
sql = super()._to_sqlglot(expr, limit=limit, params=params)

table_expr = expr.as_table()
geocols = [
name for name, typ in table_expr.schema().items() if typ.is_geospatial()
]

if not geocols:
return sql
else:
self._load_extensions(["spatial"])

compiler = self.compiler
quoted = compiler.quoted
return sg.select(
sge.Star(
replace=[
compiler.f.st_aswkb(sg.column(col, quoted=quoted)).as_(
col, quoted=quoted
)
for col in geocols
]
)
).from_(sql.subquery())

def create_table(
self,
name: str,
Expand Down Expand Up @@ -195,7 +167,7 @@ def create_table(

self._run_pre_execute_hooks(table)

query = self._to_sqlglot(table)
query = self.compiler.to_sqlglot(table)
else:
query = None

Expand Down
Loading

0 comments on commit 9e9ec9a

Please sign in to comment.