diff --git a/doc/source/notebooks/abalone.ipynb b/doc/source/notebooks/abalone.ipynb
index 78e87f4e..aaf2070b 100644
--- a/doc/source/notebooks/abalone.ipynb
+++ b/doc/source/notebooks/abalone.ipynb
@@ -504,6 +504,7 @@
     "import numpy as np\n",
     "import pickle\n",
     "\n",
+    "\n",
     "@gp.create_column_function\n",
     "def linreg_func(length: List[float], shucked_weight: List[float], rings: List[int]) -> LinregType:\n",
     "    X = np.array([length, shucked_weight]).T\n",
@@ -560,9 +561,8 @@
     "# ) a\n",
     "# ) DISTRIBUTED BY (sex);\n",
     "\n",
-    "linreg_fitted = (\n",
-    "    abalone_train.group_by(\"sex\")\n",
-    "    .apply(lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True)\n",
+    "linreg_fitted = abalone_train.group_by(\"sex\").apply(\n",
+    "    lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True\n",
     ")"
    ]
   },
@@ -800,7 +800,7 @@
     "linreg_test_fit = linreg_fitted.inner_join(\n",
     "    abalone_test,\n",
     "    cond=lambda t1, t2: t1[\"sex\"] == t2[\"sex\"],\n",
-    "    self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"]\n",
+    "    self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"],\n",
     ")"
    ]
   },
@@ -836,12 +836,11 @@
     "\n",
     "\n",
     "linreg_pred = linreg_test_fit.assign(\n",
-    "    rings_pred=lambda t:\n",
-    "    linreg_pred_func(\n",
-    "        t[\"serialized_linreg_model\"],\n",
-    "        t[\"length\"],\n",
-    "        t[\"shucked_weight\"],\n",
-    "    ),\n",
+    "    rings_pred=lambda t: linreg_pred_func(\n",
+    "        t[\"serialized_linreg_model\"],\n",
+    "        t[\"length\"],\n",
+    "        t[\"shucked_weight\"],\n",
+    "    ),\n",
     ")[[\"id\", \"sex\", \"rings\", \"rings_pred\"]]"
    ]
   },
@@ -946,6 +945,7 @@
     "# , r2_score float8\n",
     "# );\n",
     "\n",
+    "\n",
     "@dataclasses.dataclass\n",
     "class linreg_eval_type:\n",
     "    mae: float\n",
@@ -974,6 +974,7 @@
    "source": [
     "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
     "\n",
+    "\n",
     "@gp.create_column_function\n",
     "def linreg_eval(y_actual: List[float], y_pred: List[float]) -> linreg_eval_type:\n",
     "    mae = mean_absolute_error(y_actual, y_pred)\n",
@@ -1066,8 +1067,7 @@
     "# ) a\n",
     "\n",
     "\n",
-    "linreg_pred.group_by(\"sex\").apply(\n",
-    "    lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)"
+    "linreg_pred.group_by(\"sex\").apply(lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)"
    ]
   }
  ],
diff --git a/doc/source/notebooks/basic.ipynb b/doc/source/notebooks/basic.ipynb
index 839f71f9..c75257ae 100644
--- a/doc/source/notebooks/basic.ipynb
+++ b/doc/source/notebooks/basic.ipynb
@@ -787,14 +787,8 @@
     "t_join = t1.join(\n",
     "    t2,\n",
     "    on=\"val\",\n",
-    "    self_columns = {\n",
-    "        \"id\": \"t1_id\",\n",
-    "        \"val\": \"t1_val\"\n",
-    "    },\n",
-    "    other_columns = {\n",
-    "        \"id\": \"t2_id\",\n",
-    "        \"val\": \"t2_val\"\n",
-    "    }\n",
+    "    self_columns={\"id\": \"t1_id\", \"val\": \"t1_val\"},\n",
+    "    other_columns={\"id\": \"t2_id\", \"val\": \"t2_val\"},\n",
     ")\n",
     "t_join"
    ]
@@ -1075,7 +1069,7 @@
     "    numbers.assign(is_even=lambda t: t[\"val\"] % 2 == 0)\n",
     "    .group_by(\"is_even\")\n",
     "    .apply(lambda t: F.sum(t[\"val\"]))\n",
-    ")\n"
+    ")"
    ]
   }
  ],
diff --git a/doc/source/notebooks/pandas.ipynb b/doc/source/notebooks/pandas.ipynb
index 64ee3c3b..61caa6a5 100644
--- a/doc/source/notebooks/pandas.ipynb
+++ b/doc/source/notebooks/pandas.ipynb
@@ -31,7 +31,7 @@
     "import pandas as pd\n",
     "import greenplumpython as gp\n",
     "\n",
-    "gp\n"
+    "gp"
    ]
   },
   {
@@ -80,7 +80,7 @@
    ],
    "source": [
     "students = [(\"alice\", 18), (\"bob\", 19), (\"carol\", 19)]\n",
-    "students\n"
+    "students"
    ]
   },
   {
@@ -154,7 +154,7 @@
    ],
    "source": [
     "pd_df = pd.DataFrame.from_records(students, columns=[\"name\", \"age\"])\n",
-    "pd_df\n"
+    "pd_df"
    ]
   },
   {
@@ -210,7 +210,7 @@
    "source": [
     "db = gp.database(\"postgresql://localhost/gpadmin\")\n",
     "gp_df = gp.DataFrame.from_rows(students, column_names=[\"name\", \"age\"], db=db)\n",
-    "gp_df\n"
+    "gp_df"
    ]
   },
   {
@@ -271,7 +271,7 @@
    }
   ],
   "source": [
-    "gp_df.save_as(\"student\", column_names=[\"name\", \"age\"], temp=True)\n"
+    "gp_df.save_as(\"student\", column_names=[\"name\", \"age\"], temp=True)"
   ]
  },
  {
@@ -287,7 +287,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "pd_df.to_csv(\"/tmp/student.csv\")\n"
+    "pd_df.to_csv(\"/tmp/student.csv\")"
   ]
  },
  {
@@ -352,7 +352,7 @@
   ],
   "source": [
    "student = db.create_dataframe(table_name=\"student\")\n",
-    "student\n"
+    "student"
   ]
  },
  {
@@ -429,7 +429,7 @@
   }
  ],
  "source": [
-    "pd.read_csv(\"/tmp/student.csv\")\n"
+    "pd.read_csv(\"/tmp/student.csv\")"
  ]
 },
 {
@@ -460,7 +460,7 @@
  ],
  "source": [
   "for row in gp_df:\n",
-    "    print(row[\"name\"], row[\"age\"])\n"
+    "    print(row[\"name\"], row[\"age\"])"
  ]
 },
 {
@@ -487,7 +487,7 @@
  ],
  "source": [
   "for row in pd_df.iterrows():\n",
-    "    print(row[1][\"name\"], row[1][\"age\"])\n"
+    "    print(row[1][\"name\"], row[1][\"age\"])"
  ]
 },
 {
@@ -584,7 +584,7 @@
  }
 ],
 "source": [
-    "pd_df[[\"name\", \"age\"]]\n"
+    "pd_df[[\"name\", \"age\"]]"
 ]
},
{
@@ -638,7 +638,7 @@
  }
 ],
 "source": [
-    "student[[\"name\", \"age\"]]\n"
+    "student[[\"name\", \"age\"]]"
 ]
},
{
@@ -679,7 +679,7 @@
  }
 ],
 "source": [
-    "pd_df[\"name\"]\n"
+    "pd_df[\"name\"]"
 ]
},
{
@@ -706,7 +706,7 @@
  }
 ],
 "source": [
-    "gp_df[\"name\"]\n"
+    "gp_df[\"name\"]"
 ]
},
{
@@ -788,7 +788,7 @@
  }
 ],
 "source": [
-    "pd_df[lambda df: df[\"name\"] == \"alice\"]\n"
+    "pd_df[lambda df: df[\"name\"] == \"alice\"]"
 ]
},
{
@@ -832,7 +832,7 @@
  }
 ],
 "source": [
-    "student[lambda t: t[\"name\"] == \"alice\"]\n"
+    "student[lambda t: t[\"name\"] == \"alice\"]"
 ]
},
{
@@ -898,7 +898,7 @@
  }
 ],
 "source": [
-    "student[:2]\n"
+    "student[:2]"
 ]
},
{
@@ -965,7 +965,7 @@
  }
 ],
 "source": [
-    "pd_df[:2]\n"
+    "pd_df[:2]"
 ]
},
{
@@ -1062,7 +1062,7 @@
  }
 ],
 "source": [
-    "pd_df.sort_values([\"age\", \"name\"], ascending=[False, False])\n"
+    "pd_df.sort_values([\"age\", \"name\"], ascending=[False, False])"
 ]
},
{
@@ -1116,7 +1116,7 @@
  }
 ],
 "source": [
-    "student.order_by(\"age\", ascending=False).order_by(\"name\", ascending=False)[:]\n"
+    "student.order_by(\"age\", ascending=False).order_by(\"name\", ascending=False)[:]"
 ]
},
{
@@ -1219,7 +1219,7 @@
   "import datetime\n",
   "\n",
   "this_year = datetime.date.today().year\n",
-    "pd_df.assign(year_of_birth=lambda df: -df[\"age\"] + this_year)\n"
+    "pd_df.assign(year_of_birth=lambda df: -df[\"age\"] + this_year)"
 ]
},
{
@@ -1277,7 +1277,7 @@
  }
 ],
 "source": [
-    "student.assign(year_of_birth=lambda t: -t[\"age\"] + this_year)\n"
+    "student.assign(year_of_birth=lambda t: -t[\"age\"] + this_year)"
 ]
},
{
@@ -1297,9 +1297,10 @@
 "source": [
   "from hashlib import sha256\n",
   "\n",
+    "\n",
   "@gp.create_function\n",
   "def hash_name(name: str) -> str:\n",
-    "    return sha256(name.encode(\"utf-8\")).hexdigest()\n"
+    "    return sha256(name.encode(\"utf-8\")).hexdigest()"
 ]
},
{
@@ -1359,7 +1360,7 @@
  }
 ],
 "source": [
-    "student.assign(name_=lambda t: hash_name(t[\"name\"]))\n"
+    "student.assign(name_=lambda t: hash_name(t[\"name\"]))"
 ]
},
{
@@ -1430,7 +1431,7 @@
   "    return Student(name=sha256(name.encode(\"utf-8\")).hexdigest(), age=age)\n",
   "\n",
   "\n",
-    "student.apply(lambda t: gp.create_function(hide_name)(t[\"name\"], t[\"age\"]), expand=True)\n"
+    "student.apply(lambda t: gp.create_function(hide_name)(t[\"name\"], t[\"age\"]), expand=True)"
 ]
},
{
@@ -1503,11 +1504,7 @@
  }
 ],
 "source": [
-    "pd_df.apply(\n",
-    "    lambda df: asdict(hide_name(df[\"name\"], df[\"age\"])),\n",
-    "    axis=1, \n",
-    "    result_type=\"expand\"\n",
-    ")\n"
+    "pd_df.apply(lambda df: asdict(hide_name(df[\"name\"], df[\"age\"])), axis=1, result_type=\"expand\")"
 ]
},
{
@@ -1555,7 +1552,7 @@
 "source": [
   "import numpy as np\n",
   "\n",
-    "pd_df.groupby(\"age\").apply(lambda df: np.count_nonzero(df[\"name\"]))\n"
+    "pd_df.groupby(\"age\").apply(lambda df: np.count_nonzero(df[\"name\"]))"
 ]
},
{
@@ -1606,7 +1603,7 @@
 "source": [
   "count = gp.aggregate_function(\"count\")\n",
   "\n",
-    "student.group_by(\"age\").apply(lambda t: count(t[\"name\"]))\n"
+    "student.group_by(\"age\").apply(lambda t: count(t[\"name\"]))"
 ]
},
{
@@ -1681,7 +1678,7 @@
  }
 ],
 "source": [
-    "pd_df.drop_duplicates(\"age\")\n"
+    "pd_df.drop_duplicates(\"age\")"
 ]
},
{
@@ -1730,7 +1727,7 @@
  }
 ],
 "source": [
-    "student.distinct_on(\"age\")\n"
+    "student.distinct_on(\"age\")"
 ]
},
{
@@ -1774,7 +1771,7 @@
  }
 ],
 "source": [
-    "student.apply(lambda t: count.distinct(t[\"age\"]))\n"
+    "student.apply(lambda t: count.distinct(t[\"age\"]))"
 ]
},
{
@@ -1873,7 +1870,7 @@
  }
 ],
 "source": [
-    "pd_df.merge(pd_df, on=\"age\", suffixes=(\"\", \"_2\"))\n"
+    "pd_df.merge(pd_df, on=\"age\", suffixes=(\"\", \"_2\"))"
 ]
},
{
@@ -1943,7 +1940,7 @@
  }
 ],
 "source": [
-    "student.join(student, on=\"age\", other_columns={\"name\": \"name_2\"})\n"
+    "student.join(student, on=\"age\", other_columns={\"name\": \"name_2\"})"
 ]
},
{
@@ -2029,7 +2026,7 @@
   "num_1 = pd.DataFrame({\"val\": [1, 3, 5, 7, 9]})\n",
   "num_2 = pd.DataFrame({\"val\": [2, 4, 6, 8, 10]})\n",
   "\n",
-    "num_1[num_2[\"val\"] % 2 == 0] # Even numbers?\n"
+    "num_1[num_2[\"val\"] % 2 == 0] # Even numbers?"
 ]
},
{
diff --git a/greenplumpython/builtins/functions.py b/greenplumpython/builtins/functions.py
index 8f2481ab..469a8e05 100644
--- a/greenplumpython/builtins/functions.py
+++ b/greenplumpython/builtins/functions.py
@@ -29,7 +29,8 @@ def count(
     """
     if arg is None:
-        return FunctionExpr(aggregate_function(name="count"), ("*",))
+        no_arg: tuple[()] = tuple()
+        return FunctionExpr(aggregate_function(name="count"), no_arg)
     return FunctionExpr(aggregate_function(name="count"), (arg,))
@@ -162,4 +163,6 @@ def generate_series(start: Any, stop: Any, step: Optional[Any] = None) -> Functi
     (10 rows)
     """
-    return FunctionExpr(function(name="generate_series"), (start, stop, step))
+    return FunctionExpr(
+        function(name="generate_series"), (start, stop, step) if step is not None else (start, stop)
+    )
diff --git a/greenplumpython/col.py b/greenplumpython/col.py
index 8e3659db..a7c0220f 100644
--- a/greenplumpython/col.py
+++ b/greenplumpython/col.py
@@ -3,7 +3,7 @@
 from greenplumpython.db import Database
 from greenplumpython.expr import Expr
-from greenplumpython.type import Type
+from greenplumpython.type import DataType
 
 if TYPE_CHECKING:
     from greenplumpython.dataframe import DataFrame
@@ -28,8 +28,12 @@ def __init__(
         self._column = column
         super().__init__(column._dataframe)
 
-    def _serialize(self) -> str:
-        return f'({self._column._serialize()})."{self._field_name}"'
+    def _serialize(self, db: Optional[Database] = None) -> str:
+        return (
+            f'({self._column._serialize(db=db)})."{self._field_name}"'
+            if self._field_name != "*"
+            else f"({self._column._serialize(db=db)}).*"
+        )
 
 
 class Column(Expr):
@@ -44,9 +48,9 @@ def __init__(self, name: str, dataframe: "DataFrame") -> None:
         """:meta private:"""
         super().__init__(dataframe=dataframe)
         self._name = name
-        self._type: Optional[Type] = None  # TODO: Add type inference
+        self._type: Optional[DataType] = None  # TODO: Add type inference
 
-    def _serialize(self) -> str:
+    def _serialize(self, db: Optional[Database] = None) -> str:
         assert self._dataframe is not None
         # Quote both dataframe name and column name to avoid SQL injection.
         return (
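A note on the `builtins/functions.py` change above: `count()` with no argument now passes an empty argument tuple, which the updated `FunctionExpr._serialize` renders as `count(*)`, and `generate_series` omits the third argument entirely when `step` is `None` instead of emitting a SQL `NULL`. A minimal sketch of the visible behavior, mirroring the tests added at the end of this patch (the connection URI is an assumption):

    import greenplumpython as gp
    from greenplumpython.builtins.functions import count, generate_series

    db = gp.database("postgresql://localhost/gpadmin")  # assumed URI

    # count() now serializes as count(*), so rows whose column is NULL are counted too.
    df = db.create_dataframe(columns={"val": [1, None]})
    for row in df.apply(lambda _: count(), column_name="count"):
        assert row["count"] == 2  # mirrors test_count_none below

    # generate_series(1, 10) now emits two arguments instead of (1, 10, NULL).
    for row in db.apply(lambda: generate_series(1, 10), column_name="n"):
        assert 1 <= row["n"] <= 10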
diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py
index 28aed86e..26042522 100644
--- a/greenplumpython/dataframe.py
+++ b/greenplumpython/dataframe.py
@@ -54,7 +54,7 @@
 from greenplumpython.col import Column, Expr
 from greenplumpython.db import Database
-from greenplumpython.expr import _serialize
+from greenplumpython.expr import _serialize_to_expr
 from greenplumpython.group import DataFrameGroupingSet
 from greenplumpython.order import DataFrameOrdering
 from greenplumpython.row import Row
@@ -99,7 +99,7 @@ def _(self, predicate: Callable[["DataFrame"], Expr]):
 
     @_getitem.register(list)
     def _(self, column_names: List[str]) -> "DataFrame":
-        targets_str = [_serialize(self[col]) for col in column_names]
+        targets_str = [_serialize_to_expr(self[col], db=self._db) for col in column_names]
         return DataFrame(
             f"""
                 SELECT {','.join(targets_str)}
@@ -333,11 +333,14 @@ def where(self, predicate: Callable[["DataFrame"], "Expr"]) -> "DataFrame":
         """
         v = predicate(self)
+        assert isinstance(v, Expr), "Predicate must be an expression."
         assert v._dataframe == self, "Predicate must based on current dataframe"
         parents = [self]
         if v._other_dataframe is not None and self._name != v._other_dataframe._name:
             parents.append(v._other_dataframe)
-        return DataFrame(f"SELECT * FROM {self._name} WHERE {v._serialize()}", parents=parents)
+        return DataFrame(
+            f"SELECT * FROM {self._name} WHERE {v._serialize(db=self._db)}", parents=parents
+        )
 
     def apply(
         self,
@@ -433,8 +436,8 @@
         # explicitly.
         return (
             func(self)
-            ._bind(dataframe=self, db=self._db)
-            .apply(expand=expand, column_name=column_name)
+            ._bind(dataframe=self)
+            .apply(expand=expand, column_name=column_name, db=self._db)
         )
 
     def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame":
@@ -488,8 +491,7 @@ def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame":
             if v._other_dataframe is not None and v._other_dataframe._name != self._name:
                 if v._other_dataframe._name not in other_parents:
                     other_parents[v._other_dataframe._name] = v._other_dataframe
-            v = v._bind(db=self._db)
-            targets.append(f"{_serialize(v)} AS {k}")
+            targets.append(f"{_serialize_to_expr(v, db=self._db)} AS {k}")
         return DataFrame(
             f"SELECT *, {','.join(targets)} FROM {self._name}",
             parents=[self] + list(other_parents.values()),
@@ -621,25 +623,27 @@ def join(
         ], "Unsupported join type"
         assert cond is None or on is None, 'Cannot specify "cond" and "using" together'
 
-        def _bind(
-            t: DataFrame, db: Database, columns: Union[Dict[str, Optional[str]], Set[str]]
-        ) -> List[str]:
+        def _bind(t: DataFrame, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]:
             target_list: List[str] = []
             for k in columns:
-                col: Column = t[k]._bind(db=db)
+                col: Column = t[k]
                 v = columns[k] if isinstance(columns, dict) else None
-                target_list.append(col._serialize() + (f' AS "{v}"' if v is not None else ""))
+                target_list.append(
+                    col._serialize(db=t._db) + (f' AS "{v}"' if v is not None else "")
+                )
             return target_list
 
         other_temp = other if self._name != other._name else DataFrame(query="")
         other_clause = (
             other._name if self._name != other._name else other._name + " AS " + other_temp._name
         )
-        target_list = _bind(self, db=self._db, columns=self_columns) + _bind(
-            other_temp, db=other._db, columns=other_columns
-        )
+        target_list = _bind(self, columns=self_columns) + _bind(other_temp, columns=other_columns)
         # ON clause in SQL uses argument `cond`.
-        sql_on_clause = f"ON {cond(self, other_temp)._serialize()}" if cond is not None else ""
+        if cond is not None:
+            assert isinstance(cond(self, other_temp), Expr), "Join predicate must be an expression."
+        sql_on_clause = (
+            f"ON {cond(self, other_temp)._serialize(db=self._db)}" if cond is not None else ""
+        )
         join_column_names = (
             (f'"{on}"' if isinstance(on, str) else ",".join([f'"{name}"' for name in on]))
             if on is not None
@@ -1057,7 +1061,7 @@ def distinct_on(self, *column_names: str) -> "DataFrame":
             will randomly pick one of them for the name column. Use "[['age']]" to make sure the
             result is stable.
         """
-        cols = [Column(name, self)._serialize() for name in column_names]
+        cols: List[str] = [self[name]._serialize(db=self._db) for name in column_names]
         return DataFrame(
             f"SELECT DISTINCT ON ({','.join(cols)}) * FROM {self._name}",
             parents=[self],
@@ -1136,7 +1140,10 @@
             column_names = first_row.keys()
         assert column_names is not None, "Column names of the DataFrame is unknown."
         rows_string = ",".join(
-            [f"({','.join(_serialize(datum) for datum in row)})" for row in row_tuples]
+            [
+                f"({','.join(_serialize_to_expr(datum, db=db) for datum in row)})"
+                for row in row_tuples
+            ]
         )
         column_names = [f'"{name}"' for name in column_names]
         columns_string = f"({','.join(column_names)})"
@@ -1174,6 +1181,6 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF
         (3 rows)
         """
         columns_string = ",".join(
-            [f'unnest({_serialize(list(v))}) AS "{k}"' for k, v in columns.items()]
+            [f'unnest({_serialize_to_expr(list(v), db=db)}) AS "{k}"' for k, v in columns.items()]
         )
         return DataFrame(f"SELECT {columns_string}", db=db)
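The assertions added to `where` and `join` above turn a silent mistake into an immediate failure: the predicate callable must return an `Expr` built from the DataFrame's columns, not a plain Python value. A hedged sketch of the difference (the connection URI is an assumption):

    import greenplumpython as gp

    db = gp.database("postgresql://localhost/gpadmin")  # assumed URI
    df = db.create_dataframe(columns={"a": [1, 2, 3]})

    # OK: t["a"] < 3 builds a BinaryExpr, which where() serializes into the WHERE clause.
    small = df.where(lambda t: t["a"] < 3)

    # Fails fast with "Predicate must be an expression.": the lambda returns a
    # Python bool, so no valid SQL predicate can be derived from it.
    # df.where(lambda t: True)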
rows_string = ",".join( - [f"({','.join(_serialize(datum) for datum in row)})" for row in row_tuples] + [ + f"({','.join(_serialize_to_expr(datum, db=db) for datum in row)})" + for row in row_tuples + ] ) column_names = [f'"{name}"' for name in column_names] columns_string = f"({','.join(column_names)})" @@ -1174,6 +1181,6 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF (3 rows) """ columns_string = ",".join( - [f'unnest({_serialize(list(v))}) AS "{k}"' for k, v in columns.items()] + [f'unnest({_serialize_to_expr(list(v), db=db)}) AS "{k}"' for k, v in columns.items()] ) return DataFrame(f"SELECT {columns_string}", db=db) diff --git a/greenplumpython/db.py b/greenplumpython/db.py index 6badf2ce..cc38c2a3 100644 --- a/greenplumpython/db.py +++ b/greenplumpython/db.py @@ -178,7 +178,7 @@ def apply( ----- (1 row) """ - return func()._bind(db=self).apply(expand=expand, column_name=column_name) + return func().apply(expand=expand, column_name=column_name, db=self) def assign(self, **new_columns: Callable[[], Any]) -> "DataFrame": """ @@ -206,18 +206,15 @@ def assign(self, **new_columns: Callable[[], Any]) -> "DataFrame": (1 row) """ from greenplumpython.dataframe import DataFrame - from greenplumpython.expr import Expr, _serialize + from greenplumpython.expr import Expr, _serialize_to_expr from greenplumpython.func import FunctionExpr targets: List[str] = [] for k, f in new_columns.items(): v: Any = f() if isinstance(v, Expr): - v._bind(db=self) assert v._dataframe is None, "New column should not depend on any dataframe." - if isinstance(v, FunctionExpr): - v = v._bind(db=self) - targets.append(f"{_serialize(v)} AS {k}") + targets.append(f"{_serialize_to_expr(v, db=self)} AS {k}") return DataFrame(f"SELECT {','.join(targets)}", db=self) diff --git a/greenplumpython/expr.py b/greenplumpython/expr.py index dc5222e6..051a7b73 100644 --- a/greenplumpython/expr.py +++ b/greenplumpython/expr.py @@ -35,7 +35,7 @@ def _bind( def __hash__(self) -> int: # noqa: D105 - return hash(self._serialize()) + return hash(self._serialize(db=None)) def __and__(self, other: Any) -> "BinaryExpr": """ @@ -495,9 +495,9 @@ def like(self, pattern: str) -> "BinaryExpr": def __str__(self) -> str: """Return string statement of Expr.""" - return self._serialize() + return self._serialize(db=None) - def _serialize(self) -> str: + def _serialize(self, db: Optional[Database] = None) -> str: raise NotImplementedError() # NOTE: We cannot use __contains__() because the return value will always @@ -546,20 +546,20 @@ def in_(self, container: Union["Expr", List[Any]]) -> "InExpr": from psycopg2.extensions import adapt # type: ignore -def _serialize(value: Any) -> str: +def _serialize_to_expr(obj: Any, db: Optional[Database] = None) -> str: # noqa: D400 """ :meta private: - Converts a value to UTF-8 encoded str to be used in a SQL statement + Converts any Python object to a SQL expression. Note: It is OK to consider UTF-8 only since all `strs` are encoded in UTF-8 in Python 3 and Python 2 is EOL officially. 
""" - if isinstance(value, Expr): - return value._serialize() - return adapt(value).getquoted().decode("utf-8") # type: ignore + if isinstance(obj, Expr): + return obj._serialize(db=db) + return adapt(obj).getquoted().decode("utf-8") # type: ignore class BinaryExpr(Expr): @@ -629,11 +629,9 @@ def __init__( """ self._init(operator, left, right) - def _serialize(self) -> str: - from greenplumpython.expr import _serialize - - left_str = _serialize(self._left) - right_str = _serialize(self._right) + def _serialize(self, db: Optional[Database] = None) -> str: + left_str = _serialize_to_expr(self._left, db=db) + right_str = _serialize_to_expr(self._right, db=db) return f"({left_str} {self._operator} {right_str})" @@ -653,8 +651,8 @@ def __init__( self.operator = operator self.right = right - def _serialize(self) -> str: - right_str = str(self.right) + def _serialize(self, db: Optional[Database] = None) -> str: + right_str = _serialize_to_expr(self.right, db=db) return f"{self.operator} ({right_str})" @@ -674,7 +672,7 @@ def __init__( self._item = item self._container = container - def _serialize(self) -> str: + def _serialize(self, db: Optional[Database] = None) -> str: if isinstance(self._container, Expr): assert self._other_dataframe is not None, "DataFrame of container is unknown." # Using either IN or = any() will violate @@ -684,10 +682,10 @@ def _serialize(self) -> str: if isinstance(self._container, Expr) and self._other_dataframe is not None: return ( f"(EXISTS (SELECT FROM {self._other_dataframe._name}" - f" WHERE ({self._container._serialize()} = {self._item._serialize()})))" + f" WHERE ({self._container._serialize(db=db)} = {self._item._serialize(db=db)})))" ) return ( - f'(EXISTS (SELECT FROM unnest({_serialize(self._container)}) AS "{container_name}"' - f' WHERE ("{container_name}" = {self._item._serialize()})))' + f'(EXISTS (SELECT FROM unnest({_serialize_to_expr(self._container, db=db)}) AS "{container_name}"' + f' WHERE ("{container_name}" = {self._item._serialize(db=db)})))' ) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index acf9e453..a497c9c4 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -17,9 +17,9 @@ from greenplumpython.col import Column from greenplumpython.dataframe import DataFrame from greenplumpython.db import Database -from greenplumpython.expr import Expr, _serialize +from greenplumpython.expr import Expr, _serialize_to_expr from greenplumpython.group import DataFrameGroupingSet -from greenplumpython.type import to_pg_type +from greenplumpython.type import _serialize_to_type class FunctionExpr(Expr): @@ -57,46 +57,35 @@ def _bind( self, group_by: Optional[DataFrameGroupingSet] = None, dataframe: Optional[DataFrame] = None, - db: Optional[Database] = None, ): # noqa D400 """:meta private:""" - f = FunctionExpr( + return FunctionExpr( self._func, self._args, group_by=group_by, dataframe=dataframe, distinct=self._distinct, ) - f._db = ( - db - if db is not None - else dataframe._db - if dataframe is not None - else group_by._dataframe._db - if group_by is not None - else self._db - ) - assert f._db is not None - return f - def _serialize(self) -> str: + def _serialize(self, db: Optional[Database] = None) -> str: # noqa D400 """:meta private:""" - assert self._db is not None, "Database is required to create function." 
- self._function._create_in_db(self._db) + if db is not None: + self._function._create_in_db(db) distinct = "DISTINCT" if self._distinct else "" - for arg in self._args: - if isinstance(arg, Expr): - arg._db = self._db args_string = ( - ",".join([_serialize(arg) for arg in self._args if arg is not None]) + ",".join([_serialize_to_expr(arg, db=db) for arg in self._args]) if any(self._args) else "" + if not isinstance(self._func, AggregateFunction) + else "*" ) return f"{self._function._qualified_name_str}({distinct} {args_string})" - def apply(self, expand: bool = False, column_name: Optional[str] = None) -> DataFrame: + def apply( + self, expand: bool = False, column_name: Optional[str] = None, db: Optional[Database] = None + ) -> DataFrame: # noqa D400 """ :meta private: @@ -107,7 +96,7 @@ def apply(self, expand: bool = False, column_name: Optional[str] = None) -> Data assert not ( expand and column_name is not None ), "Cannot assign single column name when expanding multi-valued results." - self._function._create_in_db(self._db) + self._function._create_in_db(db=db) from_clause = f"FROM {self._dataframe._name}" if self._dataframe is not None else "" group_by_clause = self._group_by._clause() if self._group_by is not None else "" if expand and column_name is None: @@ -119,20 +108,20 @@ def apply(self, expand: bool = False, column_name: Optional[str] = None) -> Data # unique. This can be mitigated after we implement dataframe column # inference by raising an error when the function gets called. grouping_cols = ( - [Column(name, self._dataframe)._serialize() for name in grouping_col_names] + [Column(name, self._dataframe)._serialize(db=None) for name in grouping_col_names] if grouping_col_names is not None and len(grouping_col_names) != 0 else None ) - orig_func_dataframe = DataFrame( + unexpanded_dataframe = DataFrame( " ".join( [ - f"SELECT {str(self)} {'AS ' + column_name if column_name is not None else ''}", + f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}", ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", from_clause, group_by_clause, ] ), - db=self._db, + db=db, parents=parents, ) # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a @@ -149,24 +138,25 @@ def apply(self, expand: bool = False, column_name: Optional[str] = None) -> Data # SELECT (result).* FROM func_call; # ``` rebased_grouping_cols = ( - [Column(name, orig_func_dataframe)._serialize() for name in grouping_col_names] + [_serialize_to_expr(unexpanded_dataframe[name], db=db) for name in grouping_col_names] if grouping_col_names is not None else None ) - results = ( - "*" + result_cols = ( + _serialize_to_expr(unexpanded_dataframe["*"], db=db) if not expand - else f"({column_name}).*" + else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db) + # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())` if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0 - else f"({orig_func_dataframe._name}).*" + else f"({unexpanded_dataframe._name}).*" if not expand - else f"{','.join(rebased_grouping_cols)}, ({column_name}).*" + else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}" ) return DataFrame( - f"SELECT {str(results)} FROM {orig_func_dataframe._name}", - db=self._db, - parents=[orig_func_dataframe], + f"SELECT {result_cols} FROM {unexpanded_dataframe._name}", + db=db, + parents=[unexpanded_dataframe], ) @property @@ -182,10 +172,11 @@ 
class ArrayFunctionExpr(FunctionExpr): It will array aggregate all the columns given by the user. """ - def _serialize(self) -> str: + def _serialize(self, db: Optional[Database] = None) -> str: # noqa D400 """:meta private:""" - self._function._create_in_db(self._db) + if db is not None: + self._function._create_in_db(db) args_string_list = [] args_string = "" grouping_col_names = self._group_by._flatten() if self._group_by is not None else None @@ -200,11 +191,11 @@ def _serialize(self) -> str: continue if isinstance(self._args[i], Expr): if grouping_cols is None or self._args[i] not in grouping_cols: - s = f"array_agg({str(self._args[i])})" # type: ignore + s = f"array_agg({_serialize_to_expr(self._args[i], db=db)})" # type: ignore else: - s = str(self._args[i]) # type: ignore + s = _serialize_to_expr(self._args[i], db=db) # type: ignore else: - s = _serialize(self._args[i]) + s = _serialize_to_expr(self._args[i], db=db) args_string_list.append(s) args_string = ",".join(args_string_list) return f"{self._function._qualified_name_str}({args_string})" @@ -213,18 +204,15 @@ def _bind( self, group_by: Optional[DataFrameGroupingSet] = None, dataframe: Optional[DataFrame] = None, - db: Optional[Database] = None, ): # noqa D400 """:meta private:""" - array_f = ArrayFunctionExpr( + return ArrayFunctionExpr( self._func, self._args, group_by=group_by if group_by else self._group_by, dataframe=dataframe, ) - array_f._db = db if db is not None else self._db - return array_f # The parent class for all database functions. @@ -319,14 +307,14 @@ def _create_in_db(self, db: Database) -> None: func_sig = inspect.signature(self._wrapped_func) func_args = ",".join( [ - f"{param.name} {to_pg_type(param.annotation, db=db)}" + f"{param.name} {_serialize_to_type(param.annotation, db=db)}" for param in func_sig.parameters.values() ] ) func_arg_names = ",".join( [f"{param.name}={param.name}" for param in func_sig.parameters.values()] ) - return_type = to_pg_type(func_sig.return_annotation, db=db, for_return=True) + return_type = _serialize_to_type(func_sig.return_annotation, db=db, for_return=True) func_pickled: bytes = dill.dumps(self._wrapped_func) _, func_name = self._qualified_name # Modify the AST of the wrapped function to minify dependency: (1-3) @@ -478,14 +466,17 @@ def _create_in_db(self, db: Database) -> None: param_list = iter(sig.parameters.values()) state_param = next(param_list) args_string = ",".join( - [f"{param.name} {to_pg_type(param.annotation, db=db)}" for param in param_list] + [ + f"{param.name} {_serialize_to_type(param.annotation, db=db)}" + for param in param_list + ] ) # -- Creation of UDA in Greenplum db._execute( ( f"CREATE AGGREGATE {self._qualified_name_str} ({args_string}) (\n" f" SFUNC = {self.transition_function._qualified_name_str},\n" - f" STYPE = {to_pg_type(state_param.annotation, db=db)}\n" + f" STYPE = {_serialize_to_type(state_param.annotation, db=db)}\n" f");\n" ), has_results=False, diff --git a/greenplumpython/group.py b/greenplumpython/group.py index fe47cacc..3192f2d7 100644 --- a/greenplumpython/group.py +++ b/greenplumpython/group.py @@ -1,17 +1,8 @@ """Definitions for the result of grouping :class:`~dataframe.DataFrame`.""" -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - List, - MutableSet, - Optional, - Set, -) - -from greenplumpython.expr import Expr, _serialize +from typing import TYPE_CHECKING, Any, Callable, List, Optional + +from greenplumpython.expr import Expr, _serialize_to_expr if TYPE_CHECKING: from greenplumpython.dataframe 
import DataFrame @@ -114,10 +105,12 @@ def apply( ----------------------- (2 rows) """ - return ( - func(self._dataframe) - ._bind(group_by=self, db=self._dataframe._db) - .apply(expand=expand, column_name=column_name) + from greenplumpython.func import FunctionExpr + + v: FunctionExpr = func(self._dataframe) + assert isinstance(v, FunctionExpr), "Can only apply functions." + return v._bind(group_by=self).apply( + expand=expand, column_name=column_name, db=self._dataframe._db ) def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame": @@ -164,8 +157,7 @@ def assign(self, **new_columns: Callable[["DataFrame"], Any]) -> "DataFrame": assert ( v._dataframe is None or v._dataframe == self._dataframe ), "Newly included columns must be based on the current dataframe" - v = v._bind(db=self._dataframe._db) - targets.append(f"{_serialize(v)} AS {k}") + targets.append(f"{_serialize_to_expr(v, db=self._dataframe._db)} AS {k}") return DataFrame( f"SELECT {','.join(targets)} FROM {self._dataframe._name} {self._clause()}", parents=[self._dataframe], diff --git a/greenplumpython/order.py b/greenplumpython/order.py index 1dfc35e0..a5180214 100644 --- a/greenplumpython/order.py +++ b/greenplumpython/order.py @@ -126,7 +126,7 @@ def _clause(self) -> str: [ " ".join( [ - Column(self._column_name_list[i], self._dataframe)._serialize(), + Column(self._column_name_list[i], self._dataframe)._serialize(db=None), "" if self._ascending_list[i] is None else "ASC" diff --git a/greenplumpython/type.py b/greenplumpython/type.py index d3ea1d32..63d063d8 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -1,9 +1,19 @@ # noqa: D100 -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Set, + Tuple, + Union, + get_type_hints, +) from uuid import uuid4 from greenplumpython.db import Database -from greenplumpython.expr import Expr, _serialize +from greenplumpython.expr import Expr, _serialize_to_expr if TYPE_CHECKING: from greenplumpython.dataframe import DataFrame @@ -43,46 +53,12 @@ def __init__(self, obj: object, qualified_type_name: str) -> None: self._obj = obj self._qualified_type_name = qualified_type_name - def _serialize(self) -> str: - obj_str = _serialize(self._obj) + def _serialize(self, db: Optional[Database] = None) -> str: + obj_str = _serialize_to_expr(self._obj, db=db) return f"({obj_str}::{self._qualified_type_name})" - def _bind( - self, - dataframe: Optional["DataFrame"] = None, - db: Optional[Database] = None, - column_name: str = None, - ) -> "Expr": - # noqa D102 - self._db = db - if isinstance(self._obj, Expr): - self._obj = self._obj._bind( - dataframe=dataframe, - db=db, - ) - return self - - def apply( - self, expand: bool = False, column_name: Optional[str] = None, row_id: Optional[str] = None - ) -> "DataFrame": - # noqa D102 - from greenplumpython.dataframe import DataFrame - - if expand and column_name is None: - column_name = "func_" + uuid4().hex - return DataFrame( - f""" - SELECT {(row_id + ',') if row_id is not None else ''} - {self._serialize()} - {'AS ' + column_name if column_name is not None else ''} - {('FROM ' + self._obj._dataframe._name) if isinstance(self._obj, Expr) and self._obj._dataframe is not None else ""} - """, - db=self._db, - parents=[self._obj._dataframe], - ) - -class Type: +class DataType: """ Represents a type of values in a :class:`~dataframe.DataFrame`. 
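The rename from `Type` to `DataType` above keeps the public accessor `type_()` (and the callable-cast behavior) intact, while `TypeCast` loses its ad-hoc `_bind`/`apply` in favor of the uniform `db=` parameter on `_serialize`. A small usage sketch, assuming `type_` is re-exported at package level as the tests suggest and that a database is reachable at the URI below:

    import greenplumpython as gp

    db = gp.database("postgresql://localhost/gpadmin")  # assumed URI
    float8 = gp.type_("float8")  # a DataType handle for an existing database type

    df = db.create_dataframe(columns={"a": [1, 2, 3]})
    # Calling the DataType wraps the column in a TypeCast, serialized as (...::float8).
    casted = df.assign(a_float=lambda t: float8(t["a"]))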
@@ -95,7 +71,7 @@
     case, a type annotation object is provided such as the defined\
     :code:`class`.
 
-    A :class:`~type.Type` object is callable. when called, it casts the object in
+    A :class:`~type.DataType` object is callable. When called, it casts the object in
     the argument to the mapped type in database.
     """
@@ -140,7 +116,7 @@ def _create_in_db(self, db: Database):
         if len(members) == 0:
             raise Exception(f"Failed to get annotations for type {self._annotation}")
         att_type_str = ",\n".join(
-            [f"{name} {to_pg_type(type_t, db)}" for name, type_t in members.items()]
+            [f"{name} {_serialize_to_type(type_t, db)}" for name, type_t in members.items()]
         )
         db._execute(
             f'CREATE TYPE "{schema}"."{self._name}" AS (\n' f"{att_type_str}\n" f");",
@@ -165,26 +141,26 @@ def __call__(self, obj: Any) -> TypeCast:
 
     @property
     def _qualified_name(self) -> Tuple[Optional[str], str]:
         """
-        Return the schema name and name of :class:`~type.Type`.
+        Return the schema name and name of :class:`~type.DataType`.
 
         Returns:
-            Tuple[str, str]: schema name and :class:`~type.Type`'s name.
+            Tuple[str, str]: schema name and :class:`~type.DataType`'s name.
         """
         return self._schema, self._name
 
 
 # -- Map between Python and Greenplum primitive types
-_defined_types: Dict[Optional[type], Type] = {
-    None: Type(name="void"),
-    int: Type(name="int4"),
-    float: Type(name="float8"),
-    bool: Type(name="bool"),
-    str: Type(name="text"),
-    bytes: Type(name="bytea"),
+_defined_types: Dict[Optional[type], DataType] = {
+    None: DataType(name="void"),
+    int: DataType(name="int4"),
+    float: DataType(name="float8"),
+    bool: DataType(name="bool"),
+    str: DataType(name="text"),
+    bytes: DataType(name="bytea"),
 }
 
 
-def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = None) -> Type:
+def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = None) -> DataType:
     """
     Get access to a type predefined in database.
@@ -196,12 +172,12 @@ def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = Non
     Returns:
         The predefined type as a :class:`~type.Type` object.
     """
-    return Type(name, schema=schema, modifier=modifier)
+    return DataType(name, schema=schema, modifier=modifier)
 
 
-def to_pg_type(
-    annotation: Optional[type],
-    db: Optional[Database] = None,
+def _serialize_to_type(
+    annotation: Union[DataType, type],
+    db: Database,
     for_return: bool = False,
 ) -> str:
     # noqa: D400
@@ -219,20 +195,20 @@
     Returns:
         str: name of type in SQL
     """
-    if annotation is not None and hasattr(annotation, "__origin__"):
+    if hasattr(annotation, "__origin__"):
         # The `or` here is to make the function work on Python 3.6.
         # Python 3.6 is the default Python version on CentOS 7 and Ubuntu 18.04
         if annotation.__origin__ == list or annotation.__origin__ == List:
             args: Tuple[type, ...] = annotation.__args__
             if for_return:
-                return f"SETOF {to_pg_type(args[0], db)}"  # type: ignore
+                return f"SETOF {_serialize_to_type(args[0], db)}"  # type: ignore
             if args[0] in _defined_types:
-                return f"{to_pg_type(args[0], db)}[]"  # type: ignore
+                return f"{_serialize_to_type(args[0], db)}[]"  # type: ignore
             raise NotImplementedError()
     else:
         assert db is not None, "Database is required to create type"
         if annotation not in _defined_types:
             type_name = "type_" + uuid4().hex
-            _defined_types[annotation] = Type(name=type_name, annotation=annotation)
+            _defined_types[annotation] = DataType(name=type_name, annotation=annotation)
             _defined_types[annotation]._create_in_db(db)
         return _defined_types[annotation]._qualified_name_str
diff --git a/tests/test_func.py b/tests/test_func.py
index d8d5fff7..2d27a903 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -4,7 +4,7 @@
 import pytest
 
 import greenplumpython as gp
-from greenplumpython.builtins.functions import generate_series
+from greenplumpython.builtins.functions import count, generate_series
 from greenplumpython.func import AggregateFunction, NormalFunction
 from tests import db
 
@@ -800,15 +800,36 @@ def test_func_non_default_schema(db: gp.Database):
     assert sorted([row["abs"] for row in results2]) == list(range(1, 11))
 
 
-def test_func_nested_create(db: gp.Database):
-    @gp.create_function
-    def add_one(x: int) -> int:
-        return x + 1
+@gp.create_function
+def add_one(x: int) -> int:
+    return x + 1
 
-    @gp.create_function
-    def add_two(x: int) -> int:
-        return x + 2
 
+@gp.create_function
+def add_two(x: int) -> int:
+    return x + 2
+
+
+def test_func_nested_create(db: gp.Database):
     result = db.apply(lambda: add_two(add_one(1)), column_name="val")
     for row in result:
         assert row["val"] == 1 + 1 + 2
+
+
+def test_count_none(db: gp.Database):
+    for row in db.create_dataframe(columns={"none": [1, None]}).apply(
+        lambda _: count(), column_name="count"
+    ):
+        assert row["count"] == 2
+
+
+def test_func_in_binary_expr(db: gp.Database):
+    result = db.assign(val=lambda: add_two(1) + add_one(1))
+    for row in result:
+        assert row["val"] == (1 + 2) + (1 + 1)
+
+
+def test_func_in_where(db: gp.Database):
+    df = db.create_dataframe(columns={"a": [1]})
+    result = df.where(lambda t: add_two(t["a"]) < 5)
+    assert len(list(result)) == 1
diff --git a/tests/test_type.py b/tests/test_type.py
index a85e0aaf..c7fc2ed6 100644
--- a/tests/test_type.py
+++ b/tests/test_type.py
@@ -3,7 +3,7 @@
 import pytest
 
 import greenplumpython as gp
-from greenplumpython.type import to_pg_type
+from greenplumpython.type import _serialize_to_type
 from tests import db
 
 
@@ -42,9 +42,8 @@ def test_type_cast_func_result(db: gp.Database):
     def func(a: int, b: int) -> int:
         return a + b
 
-    results_app = df.apply(
-        lambda t: float8(func(t["a"], t["b"])),
-        column_name="float8",
+    results_app = df.assign(
+        float8=lambda t: float8(func(t["a"], t["b"])),
     )
     assert sorted([row["float8"] for row in results_app]) == list(range(0, 20, 2))
 
@@ -77,7 +76,7 @@ class Person:
         _first_name: str
         _last_name: str
 
-    type_name = to_pg_type(Person, db=db)
+    type_name = _serialize_to_type(Person, db=db)
     assert isinstance(type_name, str)
 
 
@@ -89,5 +88,5 @@ def __init__(self, _first_name: str, _last_name: str) -> None:
         self._last_name = _last_name
 
     with pytest.raises(Exception) as exc_info:
-        to_pg_type(Person, db=db)
+        _serialize_to_type(Person, db=db)
     assert "Failed to get annotations" in str(exc_info.value)
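Taken together, this patch renames the internal serialization surface; downstream code that imported these private helpers would track the change roughly as follows (a summary sketch, not an exhaustive list):

    # Before                                      # After
    from greenplumpython.type import Type         # from greenplumpython.type import DataType
    from greenplumpython.type import to_pg_type   # from greenplumpython.type import _serialize_to_type
    from greenplumpython.expr import _serialize   # from greenplumpython.expr import _serialize_to_expr

    # Expr._serialize() and FunctionExpr.apply() now take an explicit db=...
    # argument; the Database is no longer stashed on expressions via _bind(db=...).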