Skip to content

Commit

Permalink
Update Trino integration (#5)
Browse files Browse the repository at this point in the history
* Bump up version

* update

* update

* update

* fix

* fix

* fix

* update fugue
  • Loading branch information
goodwanghan authored May 30, 2023
1 parent f01c4eb commit f0a7166
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 34 deletions.
17 changes: 16 additions & 1 deletion fugue_trino/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any, List, Optional

import fugue.api as fa
import ibis
import ibis.expr.datatypes as dt
import pyarrow as pa
from fugue.extensions import namespace_candidate
from fugue_ibis import IbisSchema, IbisTable
from fugue_ibis._utils import ibis_to_pa_type
from fugue_ibis._utils import ibis_to_pa_type, pa_to_ibis_type
from ibis.backends.trino import Backend
from triad import ParamDict, Schema
from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP
Expand Down Expand Up @@ -46,6 +48,11 @@ def to_schema(schema: IbisSchema) -> Schema:
return Schema(fields)


def to_ibis_schema(schema: Schema) -> ibis.Schema:
    """Build the ibis schema that corresponds to a triad ``Schema``.

    Every field keeps its name; its pyarrow type is translated with
    ``_pa_to_ibis_type``.

    :param schema: the schema to convert
    :return: the equivalent ``ibis.Schema``
    """
    pairs = []
    for field in schema.fields:
        pairs.append((field.name, _pa_to_ibis_type(field.type)))
    return ibis.schema(pairs)


def is_select_query(s: str) -> bool:
return (
re.match(r"^\s*select\s", s, re.IGNORECASE) is not None
Expand All @@ -55,3 +62,11 @@ def is_select_query(s: str) -> bool:

def _is_default_timestamp(tp: pa.DataType) -> bool:
    """Tell whether ``tp`` is a timestamp type whose timezone is UTC.

    The timezone is compared case-insensitively through its string form.

    :param tp: the pyarrow type to inspect
    :return: True only for timestamp types with a UTC timezone
    """
    if not pa.types.is_timestamp(tp):
        return False
    return str(tp.tz).lower() == "utc"


def _pa_to_ibis_type(tp: pa.DataType) -> dt.DataType:
    """Translate a pyarrow type into an ibis datatype.

    Timestamp types become ``dt.Timestamp`` with ``scale=6``, carrying the
    timezone (as a string) when the pyarrow type has one. Every other type
    is delegated to ``pa_to_ibis_type`` from ``fugue_ibis``.

    :param tp: the pyarrow type to convert
    :return: the corresponding ibis datatype
    """
    if not pa.types.is_timestamp(tp):
        return pa_to_ibis_type(tp)
    if tp.tz is None:
        return dt.Timestamp(scale=6)
    return dt.Timestamp(scale=6, timezone=str(tp.tz))
79 changes: 79 additions & 0 deletions fugue_trino/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, Optional

import fugue.api as fa
from fugue import AnyDataFrame, AnyExecutionEngine

from .client import TrinoClient
from .dataframe import TrinoDataFrame
from .execution_engine import TrinoExecutionEngine
from fugue_ibis import IbisTable, IbisDataFrame
import sqlglot


def describe(df: Any) -> None:
    """Print the compiled SQL and the output schema of an ibis-backed input.

    For an ``IbisDataFrame`` or an ``IbisTable`` the SQL is pretty-printed
    in the Trino dialect and followed by the output schema. Any other input
    is simply displayed with ``fa.show``.

    :param df: the input dataframe or ibis table
    :return: nothing; the result is printed
    """
    is_frame = isinstance(df, IbisDataFrame)
    if not is_frame and not isinstance(df, IbisTable):
        # Not backed by ibis: there is no SQL to show, so display the data.
        fa.show(df)
        return
    raw_sql = df.to_sql() if is_frame else str(df.compile())
    formatted = sqlglot.transpile(
        raw_sql, read="trino", write="trino", pretty=True
    )
    query = formatted[0]
    schema = str(fa.get_schema(df))
    print(f"{query}\n\nOutput Schema: {schema}")


def load(
    query_or_table: str,
    engine: AnyExecutionEngine = None,
    engine_conf: Any = None,
    as_fugue: bool = False,
    parallelism: Optional[int] = None,
) -> AnyDataFrame:
    """Load a Trino table or query result using the current execution engine.

    :param query_or_table: the Trino query or table name
    :param engine: execution engine, defaults to None (the current
        execution engine)
    :param engine_conf: engine config, defaults to None
    :param as_fugue: whether to output a Fugue DataFrame, defaults to False
    :param parallelism: the parallelism used to load the Trino output,
        defaults to None. It is only applied when the resolved engine is
        not a ``TrinoExecutionEngine`` and the value is greater than 1
    :return: the output as a dataframe

    .. admonition:: Examples

        .. code-block:: python

            import fugue.api as fa
            import fugue_trino.api as fta

            table = "some.trino.table"

            # direct load
            t1 = fta.load(table)  # t1 is an ibis table
            t2 = fta.load(table, as_fugue=True)  # t2 is a TrinoDataFrame

            # load under an engine
            with fa.engine_context(spark_session):
                t3 = fta.load(table)  # t3 is a pyspark DataFrame
                # loading parallelism will be at most 4
                t4 = fta.load(table, parallelism=4)
    """
    with fa.engine_context(
        engine, engine_conf=engine_conf, infer_by=["force_trino"]
    ) as e:
        if isinstance(e, TrinoExecutionEngine):
            # Already on Trino: the query stays lazy as an ibis table.
            table = e._client.query_to_ibis(query_or_table)
            if as_fugue:
                return e.to_df(table)
            return table
        # Any other engine: read through a Trino client, then hand the
        # data over to that engine.
        trino_client = TrinoClient.get_or_create(fa.get_current_conf())
        source = TrinoDataFrame(trino_client.query_to_ibis(query_or_table))
        res = e.to_df(source)
        if parallelism is not None and parallelism > 1:
            res = fa.repartition(res, partition=parallelism)
        if as_fugue:
            return res
        return fa.get_native_as_df(res)
24 changes: 16 additions & 8 deletions fugue_trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
from fugue import AnyDataFrame
from fugue_ibis import IbisDataFrame, IbisTable
from fugue_ibis._utils import to_ibis_schema
from ibis import BaseBackend
from sqlalchemy import exc as sa_exc
from triad import ParamDict, SerializableRLock, assert_or_throw
Expand All @@ -27,7 +26,12 @@
FUGUE_TRINO_CONF_USER,
FUGUE_TRINO_ENV_PASSWORD,
)
from ._utils import get_temp_schema, is_trino_ibis_table, is_select_query
from ._utils import (
get_temp_schema,
is_select_query,
is_trino_ibis_table,
to_ibis_schema,
)
from .collections import TableName

_FUGUE_TRINO_CLIENT_CONTEXT = ContextVar("_FUGUE_TRINO_CLIENT_CONTEXT", default=None)
Expand Down Expand Up @@ -104,6 +108,8 @@ def __init__(
self._trino_con = connect(host=host, port=port, user=user)
self._con_lock = SerializableRLock()
self._schema_backends: Dict[str, BaseBackend] = {}
self.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{self._temp_schema}")
self.sql(f"USE {catalog}.{self._temp_schema}")

@property
def catalog(self) -> str:
Expand Down Expand Up @@ -152,18 +158,20 @@ def df_to_table(
warnings.simplefilter("ignore", category=sa_exc.SAWarning)
if isinstance(fdf, IbisDataFrame) and is_trino_ibis_table(fdf.native):
obj: Any = fdf.native
query = f"CREATE OR REPLACE VIEW {tb} AS " + fdf.to_sql()
self.sql(query)
else:
obj = fdf.as_pandas().replace({np.nan: None})
if len(obj) == 0:
obj = None
else:
obj = ibis.memtable(obj, schema=to_ibis_schema(fdf.schema))
con.create_table(
tb.table,
obj,
schema=to_ibis_schema(fdf.schema),
overwrite=overwrite,
)
con.create_table(
tb.table,
obj,
schema=to_ibis_schema(fdf.schema),
overwrite=overwrite,
)
except HttpError: # pragma: no cover
pass
return tb
Expand Down
2 changes: 1 addition & 1 deletion fugue_trino/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def encode_column_name(self, name: str) -> str:
return '"' + name.replace('"', '\\"') + '"'

def get_temp_table_name(self) -> str:
return str(self.client.to_table_name(super().get_temp_table_name()))
return str(self.client.to_table_name(None))

def to_df(self, df: Any, schema: Any = None) -> IbisDataFrame:
if isinstance(df, TrinoDataFrame):
Expand Down
34 changes: 26 additions & 8 deletions fugue_trino/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Optional, Tuple

import fugue.plugins as fp
import fugue.api as fa
from fugue import DataFrame, ExecutionEngine, SQLEngine, is_pandas_or
from triad import ParamDict

from ._utils import is_trino_ibis_table, is_trino_repr
from triad import ParamDict, Schema
from fugue_ibis import IbisTable
from ._utils import is_trino_ibis_table, is_trino_repr, to_schema
from .client import TrinoClient
from .dataframe import TrinoDataFrame
from .execution_engine import TrinoExecutionEngine, TrinoSQLEngine
Expand All @@ -15,12 +16,22 @@ def _trino_to_df(query: Tuple[str, str], **kwargs: Any) -> DataFrame:
return TrinoDataFrame(TrinoClient.get_current().query_to_ibis(query[1]))


@fp.parse_creator.candidate(is_trino_repr)
def _parse_trino_creator(query: Tuple[str, str]):
def _creator() -> DataFrame:
return _trino_to_df(query)
@fp.is_df.candidate(lambda df: is_trino_ibis_table(df) or is_trino_repr(df))
def _is_trino_df(df: Any):
"""Report objects accepted by the candidate predicate as dataframes.
The predicate only lets through values for which ``is_trino_ibis_table``
or ``is_trino_repr`` is true, so this handler always answers ``True``.
"""
return True


return _creator
@fp.get_schema.candidate(lambda df: is_trino_ibis_table(df) or is_trino_repr(df))
def _trino_get_schema(df: Any) -> Schema:
    """Get the schema of a Trino ibis table or a Trino representation.

    An ``IbisTable`` is inspected directly. Any other accepted input is
    indexed as ``df[1]``, and that element is resolved to an ibis table
    through a ``TrinoClient`` built from the current Fugue conf.

    :param df: a Trino ibis table, or an indexable object whose second
        element is the table name or query string
    :return: the schema of the output
    """
    if isinstance(df, IbisTable):
        table = df
    else:
        trino_client = TrinoClient.get_or_create(fa.get_current_conf())
        table = trino_client.query_to_ibis(df[1])
    return to_schema(table.schema())


@fp.parse_execution_engine.candidate(
Expand All @@ -34,6 +45,13 @@ def _parse_trino(engine: str, conf: Any, **kwargs) -> ExecutionEngine:
return TrinoExecutionEngine(client, _conf)


@fp.parse_execution_engine.candidate(
lambda engine, conf, **kwargs: isinstance(engine, TrinoClient)
)
def _parse_trino_client(engine: TrinoClient, conf: Any, **kwargs) -> ExecutionEngine:
"""Wrap an existing ``TrinoClient`` in a ``TrinoExecutionEngine``.
Selected when the ``engine`` argument is itself a ``TrinoClient``.
``conf`` is passed to the engine unchanged; ``kwargs`` are not used.
"""
return TrinoExecutionEngine(engine, conf)


@fp.infer_execution_engine.candidate(
lambda objs: is_pandas_or(objs, TrinoDataFrame)
or any(
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ def get_version() -> str:
install_requires=[],
extras_require={
"bigquery": [
"fugue[ibis]>=0.8.1",
"fugue[ibis]>=0.8.4",
"fs-gcsfs",
"pandas-gbq",
"google-auth",
"db-dtypes",
"ibis-framework[bigquery]",
],
"trino": [
"fugue[ibis]>=0.8.1",
"fugue[ibis]>=0.8.4",
"ibis-framework[trino]",
],
"ray": ["fugue[ray]>=0.8.1"],
"ray": ["fugue[ray]>=0.8.4"],
},
classifiers=[
# "3 - Alpha", "4 - Beta" or "5 - Production/Stable"
Expand Down
11 changes: 1 addition & 10 deletions tests/fugue_trino/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Optional

from fugue import DataFrame
from triad import to_uuid

from fugue_trino import TrinoClient, TrinoExecutionEngine
from fugue_trino._constants import FUGUE_TRINO_CONF_TEMP_SCHEMA_DEFAULT_NAME
Expand Down Expand Up @@ -29,12 +28,4 @@ def __init__(self, client: Optional[TrinoClient] = None, conf: Any = None):
self._cache = {}

def to_df(self, df: Any, schema: Any = None) -> DataFrame:
if isinstance(df, list):
key = to_uuid(df, schema)
else:
key = to_uuid(id(df), schema)
if key in self._cache:
return self._cache[key]
res = super().sql_engine.to_df(df, schema)
self._cache[key] = res
return res
return super().sql_engine.to_df(df, schema)
33 changes: 31 additions & 2 deletions tests/fugue_trino/test_api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import fugue.api as fa
from ._utils import get_testing_client
import pandas as pd
from fugue_ibis import IbisTable

import fugue_trino.api as fta
from fugue_trino import TrinoDataFrame

from ._utils import get_testing_client
import ray
from fugue_ray import RayDataFrame
import ray.data as rd

def test_api():

def test_fugue_api():
def tr(df: pd.DataFrame) -> pd.DataFrame:
return df

with get_testing_client() as client:
df1 = fa.as_fugue_df([["a", 1], ["b", 2]], schema="x:str,b:long")
tb1 = str(client.df_to_table(df1))
assert fa.get_schema(("trino", tb1)) == "x:str,b:long"
assert fa.get_schema(client.query_to_ibis(tb1)) == "x:str,b:long"
df2 = fa.as_fugue_df([["a", True], ["c", False]], schema="x:str,c:bool")
tb2 = str(client.df_to_table(df2))
fa.show(("trino", tb1))
Expand All @@ -20,3 +30,22 @@ def tr(df: pd.DataFrame) -> pd.DataFrame:
res = fa.transform(res, tr, schema="*")
assert [["a", 1, True]] == res.as_array()
assert 2 == fa.count(fa.transform(("trino", tb1), tr, schema="*"))


def test_load():
"""Check the return types of ``fta.load`` under a Trino client and under ray."""
with get_testing_client() as client:
# Stage a small two-row table for the loads below to read.
df1 = fa.as_fugue_df([["a", 1], ["b", 2]], schema="x:str,b:long")
tb1 = str(client.df_to_table(df1))
# With the client as the engine: ibis table by default, TrinoDataFrame
# when as_fugue=True.
with fa.engine_context(client):
df = fta.load(tb1)
assert isinstance(df, IbisTable)
df = fta.load(tb1, as_fugue=True)
assert isinstance(df, TrinoDataFrame)

# With the ray engine: native ray Dataset by default, RayDataFrame when
# as_fugue=True, repartitioned to the requested parallelism.
with ray.init():
with fa.engine_context("ray"):
df = fta.load(tb1, parallelism=5)
assert isinstance(df, rd.Dataset)
df = fta.load(tb1, parallelism=5, as_fugue=True)
assert isinstance(df, RayDataFrame)
assert df.num_partitions == 5
2 changes: 1 addition & 1 deletion tests/fugue_trino/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.mark.skipif(sys.version_info < (3, 8), reason="< 3.8")
class BigQueryDataFrameTests(DataFrameTests.Tests):
class TrinoDataFrameTests(DataFrameTests.Tests):
@classmethod
def setUpClass(cls):
cls._client = get_testing_client()
Expand Down

0 comments on commit f0a7166

Please sign in to comment.