From f0a7166509d8d59a89021343945590520a6a0f3b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 29 May 2023 22:44:40 -0700 Subject: [PATCH] Update Trino integration (#5) * Bump up version * update * update * update * fix * fix * fix * update fugue --- fugue_trino/_utils.py | 17 ++++++- fugue_trino/api.py | 79 +++++++++++++++++++++++++++++ fugue_trino/client.py | 24 ++++++--- fugue_trino/execution_engine.py | 2 +- fugue_trino/registry.py | 34 ++++++++++--- setup.py | 6 +-- tests/fugue_trino/_utils.py | 11 +--- tests/fugue_trino/test_api.py | 33 +++++++++++- tests/fugue_trino/test_dataframe.py | 2 +- 9 files changed, 174 insertions(+), 34 deletions(-) create mode 100644 fugue_trino/api.py diff --git a/fugue_trino/_utils.py b/fugue_trino/_utils.py index 75f6df7..09c6e0c 100644 --- a/fugue_trino/_utils.py +++ b/fugue_trino/_utils.py @@ -2,10 +2,12 @@ from typing import Any, List, Optional import fugue.api as fa +import ibis +import ibis.expr.datatypes as dt import pyarrow as pa from fugue.extensions import namespace_candidate from fugue_ibis import IbisSchema, IbisTable -from fugue_ibis._utils import ibis_to_pa_type +from fugue_ibis._utils import ibis_to_pa_type, pa_to_ibis_type from ibis.backends.trino import Backend from triad import ParamDict, Schema from triad.utils.pyarrow import TRIAD_DEFAULT_TIMESTAMP @@ -46,6 +48,11 @@ def to_schema(schema: IbisSchema) -> Schema: return Schema(fields) +def to_ibis_schema(schema: Schema) -> ibis.Schema: + fields = [(f.name, _pa_to_ibis_type(f.type)) for f in schema.fields] + return ibis.schema(fields) + + def is_select_query(s: str) -> bool: return ( re.match(r"^\s*select\s", s, re.IGNORECASE) is not None @@ -55,3 +62,11 @@ def is_select_query(s: str) -> bool: def _is_default_timestamp(tp: pa.DataType) -> bool: return pa.types.is_timestamp(tp) and str(tp.tz).lower() == "utc" + + +def _pa_to_ibis_type(tp: pa.DataType) -> dt.DataType: + if pa.types.is_timestamp(tp): + if tp.tz is None: + return dt.Timestamp(scale=6) + return dt.Timestamp(scale=6, timezone=str(tp.tz)) + return pa_to_ibis_type(tp) diff --git a/fugue_trino/api.py b/fugue_trino/api.py new file mode 100644 index 0000000..b5abb5c --- /dev/null +++ b/fugue_trino/api.py @@ -0,0 +1,79 @@ +from typing import Any, Optional + +import fugue.api as fa +from fugue import AnyDataFrame, AnyExecutionEngine + +from .client import TrinoClient +from .dataframe import TrinoDataFrame +from .execution_engine import TrinoExecutionEngine +from fugue_ibis import IbisTable, IbisDataFrame +import sqlglot + + +def describe(df: Any) -> None: + """Print the compiled SQL plus the output schema + + :param df: the input dataframe or ibis table + :return: the SQL query and the output schema + """ + if isinstance(df, IbisDataFrame): + query = df.to_sql() + elif isinstance(df, IbisTable): + query = str(df.compile()) + else: + fa.show(df) + return + query = sqlglot.transpile(query, read="trino", write="trino", pretty=True)[0] + schema = str(fa.get_schema(df)) + print(f"{query}\n\nOutput Schema: {schema}") + + +def load( + query_or_table: str, + engine: AnyExecutionEngine = None, + engine_conf: Any = None, + as_fugue: bool = False, + parallelism: Optional[int] = None, +) -> AnyDataFrame: + """Load Trino table using the current execution engine. + + :param query_or_table: the Trino query or table name + :param engine: execution engine, defaults to None (the current execution engine) + :param engine_conf: engine config, defaults to None + :param as_fugue: whether output a Fugue DataFrame, defaults to False + :param parallelism: the parallelism to load the BigQuery output, + defaults to None (determined by the current engine's parallelism) + :return: the output as a dataframe + + .. admonition:: Examples + + .. code-block:: python + + import fugue.api as fa + import fugue_trino.api as fta + + table = "some.trino.table" + + # direct load + t1 = fta.load(table) # t1 is an ibis table + t2 = fta.load(table, as_fugue=True) # t2 is a TrinoDataFrame + + # load under an engine + with fa.engine_context(spark_session): + t3 = fta.load(table) # t3 is a pyspark DataFrame + # loading parallelism will be at most 4 + t4 = fta.load(table, parallelism=4) + """ + with fa.engine_context( + engine, engine_conf=engine_conf, infer_by=["force_trino"] + ) as e: + if isinstance(e, TrinoExecutionEngine): + tb = e._client.query_to_ibis(query_or_table) + return e.to_df(tb) if as_fugue else tb + else: + client = TrinoClient.get_or_create(fa.get_current_conf()) + df = TrinoDataFrame(client.query_to_ibis(query_or_table)) + res = e.to_df(df) + if parallelism is not None and parallelism > 1: + res = fa.repartition(res, partition=parallelism) + return res if as_fugue else fa.get_native_as_df(res) diff --git a/fugue_trino/client.py b/fugue_trino/client.py index 7f9331f..0e6d11b 100644 --- a/fugue_trino/client.py +++ b/fugue_trino/client.py @@ -9,7 +9,6 @@ import numpy as np from fugue import AnyDataFrame from fugue_ibis import IbisDataFrame, IbisTable -from fugue_ibis._utils import to_ibis_schema from ibis import BaseBackend from sqlalchemy import exc as sa_exc from triad import ParamDict, SerializableRLock, assert_or_throw @@ -27,7 +26,12 @@ FUGUE_TRINO_CONF_USER, FUGUE_TRINO_ENV_PASSWORD, ) -from ._utils import get_temp_schema, is_trino_ibis_table, is_select_query +from ._utils import ( + get_temp_schema, + is_select_query, + is_trino_ibis_table, + to_ibis_schema, +) from .collections import TableName _FUGUE_TRINO_CLIENT_CONTEXT = ContextVar("_FUGUE_TRINO_CLIENT_CONTEXT", default=None) @@ -104,6 +108,8 @@ def __init__( self._trino_con = connect(host=host, port=port, user=user) self._con_lock = SerializableRLock() self._schema_backends: Dict[str, BaseBackend] = {} + self.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{self._temp_schema}") + self.sql(f"USE {catalog}.{self._temp_schema}") @property def catalog(self) -> str: @@ -152,18 +158,20 @@ def df_to_table( warnings.simplefilter("ignore", category=sa_exc.SAWarning) if isinstance(fdf, IbisDataFrame) and is_trino_ibis_table(fdf.native): obj: Any = fdf.native + query = f"CREATE OR REPLACE VIEW {tb} AS " + fdf.to_sql() + self.sql(query) else: obj = fdf.as_pandas().replace({np.nan: None}) if len(obj) == 0: obj = None else: obj = ibis.memtable(obj, schema=to_ibis_schema(fdf.schema)) - con.create_table( - tb.table, - obj, - schema=to_ibis_schema(fdf.schema), - overwrite=overwrite, - ) + con.create_table( + tb.table, + obj, + schema=to_ibis_schema(fdf.schema), + overwrite=overwrite, + ) except HttpError: # pragma: no cover pass return tb diff --git a/fugue_trino/execution_engine.py b/fugue_trino/execution_engine.py index dcf4d22..e42b815 100644 --- a/fugue_trino/execution_engine.py +++ b/fugue_trino/execution_engine.py @@ -53,7 +53,7 @@ def encode_column_name(self, name: str) -> str: return '"' + name.replace('"', '\\"') + '"' def get_temp_table_name(self) -> str: - return str(self.client.to_table_name(super().get_temp_table_name())) + return str(self.client.to_table_name(None)) def to_df(self, df: Any, schema: Any = None) -> IbisDataFrame: if isinstance(df, TrinoDataFrame): diff --git a/fugue_trino/registry.py b/fugue_trino/registry.py index a1fc307..d5baaab 100644 --- a/fugue_trino/registry.py +++ b/fugue_trino/registry.py @@ -1,10 +1,11 @@ from typing import Any, Optional, Tuple import fugue.plugins as fp +import fugue.api as fa from fugue import DataFrame, ExecutionEngine, SQLEngine, is_pandas_or -from triad import ParamDict - -from ._utils import is_trino_ibis_table, is_trino_repr +from triad import ParamDict, Schema +from fugue_ibis import IbisTable +from ._utils import is_trino_ibis_table, is_trino_repr, to_schema from .client import TrinoClient from .dataframe import TrinoDataFrame from .execution_engine import TrinoExecutionEngine, TrinoSQLEngine @@ -15,12 +16,22 @@ def _trino_to_df(query: Tuple[str, str], **kwargs: Any) -> DataFrame: return TrinoDataFrame(TrinoClient.get_current().query_to_ibis(query[1])) -@fp.parse_creator.candidate(is_trino_repr) -def _parse_trino_creator(query: Tuple[str, str]): - def _creator() -> DataFrame: - return _trino_to_df(query) +@fp.is_df.candidate(lambda df: is_trino_ibis_table(df) or is_trino_repr(df)) +def _is_trino_df(df: Any): + return True + - return _creator +@fp.get_schema.candidate(lambda df: is_trino_ibis_table(df) or is_trino_repr(df)) +def _trino_get_schema(df: Any) -> Schema: + """Get the schema of certain query or table + + :param query_or_table: the table name or query string + :return: the schema of the output + """ + if isinstance(df, IbisTable): + return to_schema(df.schema()) + client = TrinoClient.get_or_create(fa.get_current_conf()) + return to_schema(client.query_to_ibis(df[1]).schema()) @fp.parse_execution_engine.candidate( @@ -34,6 +45,13 @@ def _parse_trino(engine: str, conf: Any, **kwargs) -> ExecutionEngine: return TrinoExecutionEngine(client, _conf) +@fp.parse_execution_engine.candidate( + lambda engine, conf, **kwargs: isinstance(engine, TrinoClient) +) +def _parse_trino_client(engine: TrinoClient, conf: Any, **kwargs) -> ExecutionEngine: + return TrinoExecutionEngine(engine, conf) + + @fp.infer_execution_engine.candidate( lambda objs: is_pandas_or(objs, TrinoDataFrame) or any( diff --git a/setup.py b/setup.py index 95f9b35..57bf21b 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def get_version() -> str: install_requires=[], extras_require={ "bigquery": [ - "fugue[ibis]>=0.8.1", + "fugue[ibis]>=0.8.4", "fs-gcsfs", "pandas-gbq", "google-auth", @@ -40,10 +40,10 @@ def get_version() -> str: "ibis-framework[bigquery]", ], "trino": [ - "fugue[ibis]>=0.8.1", + "fugue[ibis]>=0.8.4", "ibis-framework[trino]", ], - "ray": ["fugue[ray]>=0.8.1"], + "ray": ["fugue[ray]>=0.8.4"], }, classifiers=[ # "3 - Alpha", "4 - Beta" or "5 - Production/Stable" diff --git a/tests/fugue_trino/_utils.py b/tests/fugue_trino/_utils.py index 43fd0f0..01466e0 100644 --- a/tests/fugue_trino/_utils.py +++ b/tests/fugue_trino/_utils.py @@ -1,7 +1,6 @@ from typing import Any, Optional from fugue import DataFrame -from triad import to_uuid from fugue_trino import TrinoClient, TrinoExecutionEngine from fugue_trino._constants import FUGUE_TRINO_CONF_TEMP_SCHEMA_DEFAULT_NAME @@ -29,12 +28,4 @@ def __init__(self, client: Optional[TrinoClient] = None, conf: Any = None): self._cache = {} def to_df(self, df: Any, schema: Any = None) -> DataFrame: - if isinstance(df, list): - key = to_uuid(df, schema) - else: - key = to_uuid(id(df), schema) - if key in self._cache: - return self._cache[key] - res = super().sql_engine.to_df(df, schema) - self._cache[key] = res - return res + return super().sql_engine.to_df(df, schema) diff --git a/tests/fugue_trino/test_api.py b/tests/fugue_trino/test_api.py index 979a04a..6b319a9 100644 --- a/tests/fugue_trino/test_api.py +++ b/tests/fugue_trino/test_api.py @@ -1,15 +1,25 @@ import fugue.api as fa -from ._utils import get_testing_client import pandas as pd +from fugue_ibis import IbisTable + +import fugue_trino.api as fta +from fugue_trino import TrinoDataFrame +from ._utils import get_testing_client +import ray +from fugue_ray import RayDataFrame +import ray.data as rd -def test_api(): + +def test_fugue_api(): def tr(df: pd.DataFrame) -> pd.DataFrame: return df with get_testing_client() as client: df1 = fa.as_fugue_df([["a", 1], ["b", 2]], schema="x:str,b:long") tb1 = str(client.df_to_table(df1)) + assert fa.get_schema(("trino", tb1)) == "x:str,b:long" + assert fa.get_schema(client.query_to_ibis(tb1)) == "x:str,b:long" df2 = fa.as_fugue_df([["a", True], ["c", False]], schema="x:str,c:bool") tb2 = str(client.df_to_table(df2)) fa.show(("trino", tb1)) @@ -20,3 +30,22 @@ def tr(df: pd.DataFrame) -> pd.DataFrame: res = fa.transform(res, tr, schema="*") assert [["a", 1, True]] == res.as_array() assert 2 == fa.count(fa.transform(("trino", tb1), tr, schema="*")) + + +def test_load(): + with get_testing_client() as client: + df1 = fa.as_fugue_df([["a", 1], ["b", 2]], schema="x:str,b:long") + tb1 = str(client.df_to_table(df1)) + with fa.engine_context(client): + df = fta.load(tb1) + assert isinstance(df, IbisTable) + df = fta.load(tb1, as_fugue=True) + assert isinstance(df, TrinoDataFrame) + + with ray.init(): + with fa.engine_context("ray"): + df = fta.load(tb1, parallelism=5) + assert isinstance(df, rd.Dataset) + df = fta.load(tb1, parallelism=5, as_fugue=True) + assert isinstance(df, RayDataFrame) + assert df.num_partitions == 5 diff --git a/tests/fugue_trino/test_dataframe.py b/tests/fugue_trino/test_dataframe.py index 059c377..902607f 100644 --- a/tests/fugue_trino/test_dataframe.py +++ b/tests/fugue_trino/test_dataframe.py @@ -14,7 +14,7 @@ @pytest.mark.skipif(sys.version_info < (3, 8), reason="< 3.8") -class BigQueryDataFrameTests(DataFrameTests.Tests): +class TrinoDataFrameTests(DataFrameTests.Tests): @classmethod def setUpClass(cls): cls._client = get_testing_client()