Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Commit

Permalink
Make db a parameter of serialize() (#208)
Browse files Browse the repository at this point in the history
Previously, `db` is an attribute of class `Expr` and its value
cannot be decided when init a new object. As a result, when the
`Expr` is serialized, we need to explicitly bind it to a db to
create functions or types. However, in some cases we forgot to
do that, which lead to errros.

To avoid forgotting this, this patch moves db to a parameter of
`serialize()` and pass db to it each time we call it.

This patch fixes the following bugs in `serialize()`:

1. Database is not found when trying to create a function or data
   type.
2. Star notation `*` is not serialized correctly in column fileds
   when expanding a composite type result.
3. `count(*)` returns wrong result when dataframe contains `None`.

---------

Co-authored-by: Ruxue Zeng <[email protected]>
  • Loading branch information
xuebinsu and ruxuez authored Aug 4, 2023
1 parent d3b317a commit 5086dac
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 252 deletions.
24 changes: 12 additions & 12 deletions doc/source/notebooks/abalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@
"import numpy as np\n",
"import pickle\n",
"\n",
"\n",
"@gp.create_column_function\n",
"def linreg_func(length: List[float], shucked_weight: List[float], rings: List[int]) -> LinregType:\n",
" X = np.array([length, shucked_weight]).T\n",
Expand Down Expand Up @@ -560,9 +561,8 @@
"# ) a\n",
"# ) DISTRIBUTED BY (sex);\n",
"\n",
"linreg_fitted = (\n",
" abalone_train.group_by(\"sex\")\n",
" .apply(lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True)\n",
"linreg_fitted = abalone_train.group_by(\"sex\").apply(\n",
" lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True\n",
")"
]
},
Expand Down Expand Up @@ -800,7 +800,7 @@
"linreg_test_fit = linreg_fitted.inner_join(\n",
" abalone_test,\n",
" cond=lambda t1, t2: t1[\"sex\"] == t2[\"sex\"],\n",
" self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"]\n",
" self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"],\n",
")"
]
},
Expand Down Expand Up @@ -836,12 +836,11 @@
"\n",
"\n",
"linreg_pred = linreg_test_fit.assign(\n",
" rings_pred=lambda t:\n",
" linreg_pred_func(\n",
" t[\"serialized_linreg_model\"],\n",
" t[\"length\"],\n",
" t[\"shucked_weight\"],\n",
" ),\n",
" rings_pred=lambda t: linreg_pred_func(\n",
" t[\"serialized_linreg_model\"],\n",
" t[\"length\"],\n",
" t[\"shucked_weight\"],\n",
" ),\n",
")[[\"id\", \"sex\", \"rings\", \"rings_pred\"]]"
]
},
Expand Down Expand Up @@ -946,6 +945,7 @@
"# , r2_score float8\n",
"# );\n",
"\n",
"\n",
"@dataclasses.dataclass\n",
"class linreg_eval_type:\n",
" mae: float\n",
Expand Down Expand Up @@ -974,6 +974,7 @@
"source": [
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"\n",
"\n",
"@gp.create_column_function\n",
"def linreg_eval(y_actual: List[float], y_pred: List[float]) -> linreg_eval_type:\n",
" mae = mean_absolute_error(y_actual, y_pred)\n",
Expand Down Expand Up @@ -1066,8 +1067,7 @@
"# ) a\n",
"\n",
"\n",
"linreg_pred.group_by(\"sex\").apply(\n",
" lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)"
"linreg_pred.group_by(\"sex\").apply(lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)"
]
}
],
Expand Down
12 changes: 3 additions & 9 deletions doc/source/notebooks/basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -787,14 +787,8 @@
"t_join = t1.join(\n",
" t2,\n",
" on=\"val\",\n",
" self_columns = {\n",
" \"id\": \"t1_id\",\n",
" \"val\": \"t1_val\"\n",
" },\n",
" other_columns = {\n",
" \"id\": \"t2_id\",\n",
" \"val\": \"t2_val\"\n",
" }\n",
" self_columns={\"id\": \"t1_id\", \"val\": \"t1_val\"},\n",
" other_columns={\"id\": \"t2_id\", \"val\": \"t2_val\"},\n",
")\n",
"t_join"
]
Expand Down Expand Up @@ -1075,7 +1069,7 @@
" numbers.assign(is_even=lambda t: t[\"val\"] % 2 == 0)\n",
" .group_by(\"is_even\")\n",
" .apply(lambda t: F.sum(t[\"val\"]))\n",
")\n"
")"
]
}
],
Expand Down
Loading

0 comments on commit 5086dac

Please sign in to comment.