From 7d005cdb998aca6b319166b7a8a75f583d8d1659 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 25 Jul 2024 15:40:16 -0700 Subject: [PATCH 01/12] skeleton of TimeSeriesCrossValidator --- python/tempo/ml.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 python/tempo/ml.py diff --git a/python/tempo/ml.py b/python/tempo/ml.py new file mode 100644 index 00000000..8164d4d0 --- /dev/null +++ b/python/tempo/ml.py @@ -0,0 +1,13 @@ +from typing import List, Tuple + +from pyspark.ml.tuning import CrossValidator +from pyspark.sql import DataFrame + + +class TimeSeriesCrossValidator(CrossValidator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: + pass + From 8388a1045f5c1b5efeb0d62b80e9f24b6aab1ad4 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 25 Jul 2024 20:09:00 -0700 Subject: [PATCH 02/12] basic implementation --- python/tempo/ml.py | 59 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/python/tempo/ml.py b/python/tempo/ml.py index 8164d4d0..5674b9c7 100644 --- a/python/tempo/ml.py +++ b/python/tempo/ml.py @@ -1,13 +1,68 @@ from typing import List, Tuple +from functools import reduce -from pyspark.ml.tuning import CrossValidator from pyspark.sql import DataFrame +from pyspark.sql.window import Window, WindowSpec +from pyspark.sql import functions as sfn + +from pyspark.ml.param import Param, Params +from pyspark.ml.tuning import CrossValidator +TMP_SPLIT_COL = "__tmp_split_col" + class TimeSeriesCrossValidator(CrossValidator): + # some additional parameters + timeSeriesCol: Param[str] = Param( + Params._dummy(), + "timeSeriesCol", + "The name of the time series column" + ) + seriesIdCols: Param[List[str]] = Param( + Params._dummy(), + "seriesIdCols", + "The name of the series id columns" + ) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + def getTimeSeriesCol(self) -> str: + return self.getOrDefault(self.timeSeriesCol) + + def getSeriesIdCols(self) -> List[str]: + return self.getOrDefault(self.seriesIdCols) + + def setTimeSeriesCol(self, value: str) -> "TimeSeriesCrossValidator": + return self._set(timeSeriesCol=value) + + def setSeriesIdCols(self, value: List[str]) -> "TimeSeriesCrossValidator": + return self._set(seriesIdCols=value) + + def _get_split_win(self) -> WindowSpec: + win = Window.orderBy(self.getTimeSeriesCol()) + series_id_cols = self.getSeriesIdCols() + if series_id_cols and len(series_id_cols) > 0: + win = win.partitionBy(*series_id_cols) + return win + def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: - pass + nFolds = self.getOrDefault(self.numFolds) + nSplits = nFolds+1 + + # split the data into nSplits subsets by timeseries order + split_df = dataset.withColumn(TMP_SPLIT_COL, + sfn.ntile(nSplits).over(self._get_split_win())) + all_splits = [split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL) + for i in range(1, nSplits+1)] + assert len(all_splits) == nSplits + + # compose the k folds by including all previous splits in the training set, + # and the next split in the test set + kFolds = [(reduce(lambda a, b: a.union(b), all_splits[:i+1]), all_splits[i+1]) + for i in range(nFolds)] + assert len(kFolds) == nFolds + for tv in kFolds: + assert len(tv) == 2 + return kFolds From 11a00c8968cbf1228c19555643d98a40c3587e06 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Tue, 29 Oct 2024 15:06:39 -0700 Subject: [PATCH 
03/12] added gap param to tscv class --- python/tempo/ml.py | 53 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/python/tempo/ml.py b/python/tempo/ml.py index 5674b9c7..12035003 100644 --- a/python/tempo/ml.py +++ b/python/tempo/ml.py @@ -5,27 +5,43 @@ from pyspark.sql.window import Window, WindowSpec from pyspark.sql import functions as sfn -from pyspark.ml.param import Param, Params +from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.tuning import CrossValidator TMP_SPLIT_COL = "__tmp_split_col" +TMP_GAP_COL = "__tmp_gap_row" class TimeSeriesCrossValidator(CrossValidator): # some additional parameters timeSeriesCol: Param[str] = Param( Params._dummy(), "timeSeriesCol", - "The name of the time series column" + "The name of the time series column", + typeConverter=TypeConverters.toString ) seriesIdCols: Param[List[str]] = Param( Params._dummy(), "seriesIdCols", - "The name of the series id columns" + "The name of the series id columns", + typeConverter=TypeConverters.toListString + ) + gap: Param[int] = Param( + Params._dummy(), + "gap", + "The gap between training and test set", + typeConverter=TypeConverters.toInt ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, + *, + timeSeriesCol: str = "event_ts", + seriesIdCols: List[str] = [], + gap: int = 0) -> None: + super(TimeSeriesCrossValidator, self).__init__() + self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0) + kwargs = self._input_kwargs + self._set(**kwargs) def getTimeSeriesCol(self) -> str: return self.getOrDefault(self.timeSeriesCol) @@ -33,14 +49,23 @@ def getTimeSeriesCol(self) -> str: def getSeriesIdCols(self) -> List[str]: return self.getOrDefault(self.seriesIdCols) + def getGap(self) -> int: + return self.getOrDefault(self.gap) + def setTimeSeriesCol(self, value: str) -> "TimeSeriesCrossValidator": return self._set(timeSeriesCol=value) def setSeriesIdCols(self, value: List[str]) -> "TimeSeriesCrossValidator": return self._set(seriesIdCols=value) - def _get_split_win(self) -> WindowSpec: - win = Window.orderBy(self.getTimeSeriesCol()) + def setGap(self, value: int) -> "TimeSeriesCrossValidator": + return self._set(gap=value) + + def _get_split_win(self, desc: bool = False) -> WindowSpec: + ts_col_expr = sfn.col(self.getTimeSeriesCol()) + if desc: + ts_col_expr = ts_col_expr.desc() + win = Window.orderBy(ts_col_expr) series_id_cols = self.getSeriesIdCols() if series_id_cols and len(series_id_cols) > 0: win = win.partitionBy(*series_id_cols) @@ -65,4 +90,18 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: for tv in kFolds: assert len(tv) == 2 + # trim out a gap from the training datasets, if specified + gap = self.getOrDefault(self.gap) + if gap > 0: + order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()] + # trim each training dataset by the specified gap + kFolds = [((train_df.withColumn(TMP_GAP_COL, + sfn.row_number().over(self._get_split_win(desc=True))) + .where(sfn.col(TMP_GAP_COL) > gap) + .drop(TMP_GAP_COL) + .orderBy(*order_cols)), + test_df) + for (train_df, test_df) in kFolds] + + # return the k folds (training, test) datasets return kFolds From c83a97edcf57c7f97b9a79c2cc9191d465099f33 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 30 Oct 2024 12:37:37 -0700 Subject: [PATCH 04/12] updating init method --- python/tempo/ml.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tempo/ml.py 
b/python/tempo/ml.py index 12035003..1c9f14ac 100644 --- a/python/tempo/ml.py +++ b/python/tempo/ml.py @@ -34,14 +34,13 @@ class TimeSeriesCrossValidator(CrossValidator): ) def __init__(self, - *, timeSeriesCol: str = "event_ts", seriesIdCols: List[str] = [], - gap: int = 0) -> None: - super(TimeSeriesCrossValidator, self).__init__() + gap: int = 0, + **other_kwargs) -> None: + super(TimeSeriesCrossValidator, self).__init__(**other_kwargs) self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0) - kwargs = self._input_kwargs - self._set(**kwargs) + self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap) def getTimeSeriesCol(self) -> str: return self.getOrDefault(self.timeSeriesCol) From e17ffced20204be1b6e438cee41e60fcea3a0983 Mon Sep 17 00:00:00 2001 From: tnixon Date: Thu, 31 Oct 2024 21:38:57 +0000 Subject: [PATCH 05/12] example notebook showing how to make use of timeseries split cross-validator --- examples/TimeSeries Split Cross-Validation.py | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 examples/TimeSeries Split Cross-Validation.py diff --git a/examples/TimeSeries Split Cross-Validation.py b/examples/TimeSeries Split Cross-Validation.py new file mode 100644 index 00000000..9fc01711 --- /dev/null +++ b/examples/TimeSeries Split Cross-Validation.py @@ -0,0 +1,154 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Set up dataset + +# COMMAND ---------- + +# MAGIC %sh +# MAGIC +# MAGIC wget -P /tmp/finserv/ https://pages.databricks.com/rs/094-YMS-629/images/ASOF_Trades.csv; + +# COMMAND ---------- + +# MAGIC %sh +# MAGIC +# MAGIC ls /tmp/finserv/ + +# COMMAND ---------- + +# MAGIC %sh head -n 30 /tmp/finserv/ASOF_Trades.csv + +# COMMAND ---------- + +data_dir = "/tmp/finserv" +trade_schema = """ + symbol string, + event_ts timestamp, + mod_dt date, + trade_pr double +""" + +trades_df = (spark.read.format("csv") + .schema(trade_schema) + .option("header", "true") + .option("delimiter", ",") + .load(f"{data_dir}/ASOF_Trades.csv")) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Prepare Data +# MAGIC +# MAGIC Aggregate trades into 15-minute aggregates + +# COMMAND ---------- + +import pyspark.sql.functions as sfn + +bars_df = (trades_df + .where(sfn.col("symbol").isNotNull()) + .groupBy(sfn.col("symbol"), + sfn.window("event_ts", "15 minutes")) + .agg( + sfn.first_value("trade_pr").alias("open"), + sfn.min("trade_pr").alias("low"), + sfn.max("trade_pr").alias("high"), + sfn.last_value("trade_pr").alias("close"), + sfn.count("trade_pr").alias("num_trades")) + .select("symbol", + sfn.col("window.start").alias("event_ts"), + "open", "high", "low", "close", "num_trades") + .orderBy("symbol", "event_ts")) + +# COMMAND ---------- + +display(bars_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Rolling 1.5 hour window +# MAGIC +# MAGIC Build a feature-vector by rolling the 6 previous rows (1.5 hours) into a feature vector to predict current prices. 
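+# MAGIC
+# MAGIC Note that the window below, rowsBetween(-6, -1), ends at the row *before* the current one: each feature vector is built strictly from the six preceding bars, so the current bar's own prices never leak into its features.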
+ +# COMMAND ---------- + +from pyspark.sql import Window + +six_step_win = Window.partitionBy("symbol").orderBy("event_ts").rowsBetween(-6, -1) + +six_step_rolling = (bars_df + .withColumn("prev_open", sfn.collect_list("open").over(six_step_win)) + .withColumn("prev_high", sfn.collect_list("high").over(six_step_win)) + .withColumn("prev_low", sfn.collect_list("low").over(six_step_win)) + .withColumn("prev_close", sfn.collect_list("close").over(six_step_win)) + .withColumn("prev_n", sfn.collect_list("num_trades").over(six_step_win)) + .where(sfn.array_size("prev_n") >= sfn.lit(6))) + +# COMMAND ---------- + +display(six_step_rolling) + +# COMMAND ---------- + +import pyspark.ml.functions as mlfn + +features_df = (six_step_rolling + .withColumn("features", + mlfn.array_to_vector(sfn.concat("prev_open", "prev_high", + "prev_low", "prev_close", + "prev_n"))) + .drop("prev_open", "prev_high", "prev_low", "prev_close", "prev_n")) + +# COMMAND ---------- + +display(features_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Cross-validate a model +# MAGIC But we need to split based on time... + +# COMMAND ---------- + +from pyspark.ml.tuning import ParamGridBuilder +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.evaluation import RegressionEvaluator + +from tempo.ml import TimeSeriesCrossValidator + +# parameters +params = ParamGridBuilder().build() + +# set up model +target_col = "close" +gbt = GBTRegressor(labelCol=target_col, featuresCol="features") + +# set up evaluator +eval = RegressionEvaluator(labelCol=target_col, predictionCol="prediction", metricName="rmse") + +# set up cross-validator +param_grid = ParamGridBuilder().build() +tscv = TimeSeriesCrossValidator(estimator=gbt, + evaluator=eval, + estimatorParamMaps=param_grid, + collectSubModels=True, + timeSeriesCol="event_ts", + seriesIdCols=["symbol"]) + +# COMMAND ---------- + +import mlflow +import mlflow.spark + +mlflow.spark.autolog() + +with mlflow.start_run() as run: + cvModel = tscv.fit(features_df) + best_gbt_mdl = cvModel.bestModel + mlflow.spark.log_model(best_gbt_mdl, "cv_model") + +# COMMAND ---------- + + From 11f51cdf140edbaa3d550dd6ea2f04cae231aaeb Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 31 Oct 2024 14:53:45 -0700 Subject: [PATCH 06/12] applying black formatting --- python/tempo/ml.py | 56 +++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/python/tempo/ml.py b/python/tempo/ml.py index 1c9f14ac..3db858af 100644 --- a/python/tempo/ml.py +++ b/python/tempo/ml.py @@ -12,32 +12,35 @@ TMP_SPLIT_COL = "__tmp_split_col" TMP_GAP_COL = "__tmp_gap_row" + class TimeSeriesCrossValidator(CrossValidator): # some additional parameters timeSeriesCol: Param[str] = Param( Params._dummy(), "timeSeriesCol", "The name of the time series column", - typeConverter=TypeConverters.toString + typeConverter=TypeConverters.toString, ) seriesIdCols: Param[List[str]] = Param( Params._dummy(), "seriesIdCols", "The name of the series id columns", - typeConverter=TypeConverters.toListString + typeConverter=TypeConverters.toListString, ) gap: Param[int] = Param( Params._dummy(), "gap", "The gap between training and test set", - typeConverter=TypeConverters.toInt + typeConverter=TypeConverters.toInt, ) - def __init__(self, - timeSeriesCol: str = "event_ts", - seriesIdCols: List[str] = [], - gap: int = 0, - **other_kwargs) -> None: + def __init__( + self, + timeSeriesCol: str = "event_ts", + seriesIdCols: List[str] = [], + gap: int = 0, + **other_kwargs + ) -> 
None: super(TimeSeriesCrossValidator, self).__init__(**other_kwargs) self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0) self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap) @@ -72,19 +75,24 @@ def _get_split_win(self, desc: bool = False) -> WindowSpec: def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: nFolds = self.getOrDefault(self.numFolds) - nSplits = nFolds+1 + nSplits = nFolds + 1 # split the data into nSplits subsets by timeseries order - split_df = dataset.withColumn(TMP_SPLIT_COL, - sfn.ntile(nSplits).over(self._get_split_win())) - all_splits = [split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL) - for i in range(1, nSplits+1)] + split_df = dataset.withColumn( + TMP_SPLIT_COL, sfn.ntile(nSplits).over(self._get_split_win()) + ) + all_splits = [ + split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL) + for i in range(1, nSplits + 1) + ] assert len(all_splits) == nSplits # compose the k folds by including all previous splits in the training set, # and the next split in the test set - kFolds = [(reduce(lambda a, b: a.union(b), all_splits[:i+1]), all_splits[i+1]) - for i in range(nFolds)] + kFolds = [ + (reduce(lambda a, b: a.union(b), all_splits[: i + 1]), all_splits[i + 1]) + for i in range(nFolds) + ] assert len(kFolds) == nFolds for tv in kFolds: assert len(tv) == 2 @@ -94,13 +102,21 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]: if gap > 0: order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()] # trim each training dataset by the specified gap - kFolds = [((train_df.withColumn(TMP_GAP_COL, - sfn.row_number().over(self._get_split_win(desc=True))) + kFolds = [ + ( + ( + train_df.withColumn( + TMP_GAP_COL, + sfn.row_number().over(self._get_split_win(desc=True)), + ) .where(sfn.col(TMP_GAP_COL) > gap) .drop(TMP_GAP_COL) - .orderBy(*order_cols)), - test_df) - for (train_df, test_df) in kFolds] + .orderBy(*order_cols) + ), + test_df, + ) + for (train_df, test_df) in kFolds + ] # return the k folds (training, test) datasets return kFolds From 1f168592d3cf1babcf2dc4dadc8a649733859b85 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 31 Oct 2024 15:14:18 -0700 Subject: [PATCH 07/12] fixing type-check issue --- python/tempo/ml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tempo/ml.py b/python/tempo/ml.py index 3db858af..67c73bf2 100644 --- a/python/tempo/ml.py +++ b/python/tempo/ml.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Any, List, Tuple from functools import reduce from pyspark.sql import DataFrame @@ -39,7 +39,7 @@ def __init__( timeSeriesCol: str = "event_ts", seriesIdCols: List[str] = [], gap: int = 0, - **other_kwargs + **other_kwargs: Any ) -> None: super(TimeSeriesCrossValidator, self).__init__(**other_kwargs) self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0) From 0782126b374e6d8c1c3c54e0374539448553a802 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 31 Oct 2024 15:32:49 -0700 Subject: [PATCH 08/12] mypy should run in non-interactive mode --- .github/workflows/test.yml | 4 ++-- python/tox.ini | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 588b142d..5a33c876 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,10 +23,10 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install tox tox-gh-actions - - name: Execute tox envs + - name: Lint 
check working-directory: ./python run: tox -e lint -- --check --diff - - name: Execute tox envs + - name: Type check working-directory: ./python run: tox -e type-check diff --git a/python/tox.ini b/python/tox.ini index 1ba8893f..de529c4c 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -50,7 +50,7 @@ deps = numpy types-openpyxl commands = - mypy --install-types {toxinidir}/tempo + mypy --install-types --non-interactive {toxinidir}/tempo [testenv:build-dist] description = build distribution From 621987911ad793a9707d98772ebc743128a6d6d5 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 4 Nov 2024 21:00:29 -0800 Subject: [PATCH 09/12] basic test code for cross-validator: constructor and params --- python/tests/ml_tests.py | 136 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 python/tests/ml_tests.py diff --git a/python/tests/ml_tests.py b/python/tests/ml_tests.py new file mode 100644 index 00000000..f638eac0 --- /dev/null +++ b/python/tests/ml_tests.py @@ -0,0 +1,136 @@ +import unittest + +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.evaluation import RegressionEvaluator + +from tempo.ml import TimeSeriesCrossValidator + +from tests.base import SparkTest + +class TimeSeriesCrossValidatorTests(SparkTest): + def test_empty_constructor(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # check the object + self.assertIsNotNone(tscv) + self.assertIsInstance(tscv, TimeSeriesCrossValidator) + self.assertIsInstance(tscv, CrossValidator) + # check the parameters + # CrossValidator parameters + self.assertEqual(tscv.getNumFolds(), 3) + self.assertEqual(tscv.getFoldCol(), "") + self.assertEqual(tscv.getParallelism(), 1) + self.assertEqual(tscv.getCollectSubModels(), False) + # TimeSeriesCrossValidator parameters + self.assertEqual(tscv.getTimeSeriesCol(), "event_ts") + self.assertEqual(tscv.getSeriesIdCols(), []) + self.assertEqual(tscv.getGap(), 0) + + def test_estim_eval_constructor(self): + # set up estimator and evaluator + estimator = GBTRegressor(labelCol="close", featuresCol="features") + evaluator = RegressionEvaluator(labelCol="close", + predictionCol="prediction", + metricName="rmse") + parm_grid = ParamGridBuilder().build() + # construct with default parameters + tscv = TimeSeriesCrossValidator(estimator=estimator, + evaluator=evaluator, + estimatorParamMaps=parm_grid) + # test the parameters + self.assertEqual(tscv.getEstimator(), estimator) + self.assertEqual(tscv.getEvaluator(), evaluator) + self.assertEqual(tscv.getEstimatorParamMaps(), parm_grid) + + def test_num_folds_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the number of folds + tscv.setNumFolds(5) + # check the number of folds + self.assertEqual(tscv.getNumFolds(), 5) + + def test_fold_col_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the fold column + tscv.setFoldCol("fold") + # check the fold column + self.assertEqual(tscv.getFoldCol(), "fold") + + def test_parallelism_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the parallelism + tscv.setParallelism(4) + # check the parallelism + self.assertEqual(tscv.getParallelism(), 4) + + def test_collect_sub_models_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the collect sub models + tscv.setCollectSubModels(True) + # check the 
collect sub models + self.assertEqual(tscv.getCollectSubModels(), True) + + def test_estimator_param(self): + # set up estimator and evaluator + estimator = GBTRegressor(labelCol="close", featuresCol="features") + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the estimator + tscv.setEstimator(estimator) + # check the estimator + self.assertEqual(tscv.getEstimator(), estimator) + + def test_evaluator_param(self): + # set up estimator and evaluator + evaluator = RegressionEvaluator(labelCol="close", + predictionCol="prediction", + metricName="rmse") + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the evaluator + tscv.setEvaluator(evaluator) + # check the evaluator + self.assertEqual(tscv.getEvaluator(), evaluator) + + def test_estimator_param_maps_param(self): + # set up estimator and evaluator + parm_grid = ParamGridBuilder().build() + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the estimator parameter maps + tscv.setEstimatorParamMaps(parm_grid) + # check the estimator parameter maps + self.assertEqual(tscv.getEstimatorParamMaps(), parm_grid) + + def test_time_series_col_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the time series column + tscv.setTimeSeriesCol("ts") + # check the time series column + self.assertEqual(tscv.getTimeSeriesCol(), "ts") + + def test_series_id_cols_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the series id columns + tscv.setSeriesIdCols(["id1", "id2"]) + # check the series id columns + self.assertEqual(tscv.getSeriesIdCols(), ["id1", "id2"]) + + def test_gap_param(self): + # construct with default parameters + tscv = TimeSeriesCrossValidator() + # set the gap + tscv.setGap(2) + # check the gap + self.assertEqual(tscv.getGap(), 2) + +# MAIN +if __name__ == "__main__": + unittest.main() From 2f00ced5a4848543738b292d1ff5b25c8c83c225 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 4 Nov 2024 21:40:14 -0800 Subject: [PATCH 10/12] cache contents of test data file --- python/tests/base.py | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/python/tests/base.py b/python/tests/base.py index 8538a1ce..3d08b8f9 100644 --- a/python/tests/base.py +++ b/python/tests/base.py @@ -2,6 +2,7 @@ import unittest import warnings from typing import Union, Optional +from functools import cached_property import jsonref import pyspark.sql.functions as sfn @@ -170,7 +171,10 @@ class SparkTest(unittest.TestCase): # Spark Session object spark = None - test_data = None + + # test data + test_data_file = None + test_case_data = None @classmethod def setUpClass(cls) -> None: @@ -208,10 +212,10 @@ def tearDownClass(cls) -> None: cls.spark.stop() def setUp(self) -> None: - self.test_data = self.__loadTestData(self.id()) + self.test_case_data = self.__loadTestData(self.id()) def tearDown(self) -> None: - del self.test_data + del self.test_case_data # # Utility Functions @@ -219,7 +223,7 @@ def tearDown(self) -> None: def get_data_as_idf(self, name: str, convert_ts_col=True): df = self.get_data_as_sdf(name, convert_ts_col) - td = self.test_data[name] + td = self.test_case_data[name] idf = IntervalsDF( df, start_ts=td["start_ts"], @@ -258,20 +262,28 @@ def __loadTestData(self, test_case_path: str) -> dict: """ file_name, class_name, func_name = test_case_path.split(".")[-3:] - # find our test data file - test_data_file = 
self.__getTestDataFilePath(file_name) - if not os.path.isfile(test_data_file): - warnings.warn(f"Could not load test data file {test_data_file}") - return {} + # load the test data file if it hasn't been loaded yet + if self.test_data_file is None: + # find our test data file + test_data_filename = self.__getTestDataFilePath(file_name) + if not os.path.isfile(test_data_filename): + warnings.warn(f"Could not load test data file {test_data_filename}") + self.test_data_file = {} + + # proces the data file + with open(test_data_filename, "r") as f: + self.test_data_file = jsonref.load(f) + + # return the data if it exists + if class_name in self.test_data_file: + if func_name in self.test_data_file[class_name]: + return self.test_data_file[class_name][func_name] - # proces the data file - with open(test_data_file, "r") as f: - data_metadata_from_json = jsonref.load(f) - # return the data - return data_metadata_from_json[class_name][func_name] + # return empty dictionary if no data found + return {} def get_test_df_builder(self, name: str) -> TestDataFrameBuilder: - return TestDataFrameBuilder(self.spark, self.test_data[name]) + return TestDataFrameBuilder(self.spark, self.test_case_data[name]) # # Assertion Functions From 4f85463a78bcbc32c953eae0e75a77d04a84d34e Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 4 Nov 2024 21:58:19 -0800 Subject: [PATCH 11/12] option to load test data from csv file --- python/tests/base.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/python/tests/base.py b/python/tests/base.py index 3d08b8f9..cc1487c8 100644 --- a/python/tests/base.py +++ b/python/tests/base.py @@ -5,11 +5,15 @@ from functools import cached_property import jsonref +import pandas as pd + import pyspark.sql.functions as sfn -from chispa import assert_df_equality -from delta.pip_utils import configure_spark_with_delta_pip from pyspark.sql import SparkSession from pyspark.sql.dataframe import DataFrame + +from chispa import assert_df_equality +from delta.pip_utils import configure_spark_with_delta_pip + from tempo.intervals import IntervalsDF from tempo.tsdf import TSDF @@ -43,11 +47,20 @@ def df_schema(self) -> str: """ return self.df["schema"] - def df_data(self) -> list: + def df_data(self) -> Union[list, pd.DataFrame]: """ :return: the data component of the test data """ - return self.df["data"] + data = self.df["data"] + # return data literals (list of rows) + if isinstance(data, list): + return data + # load data from a csv file + elif isinstance(data, str): + csv_path = SparkTest.getTestDataFilePath(data, extension='') + return pd.read_csv(csv_path) + else: + raise ValueError(f"Invalid data type {type(data)}") # TSDF metadata @@ -234,7 +247,8 @@ def get_data_as_idf(self, name: str, convert_ts_col=True): TEST_DATA_FOLDER = "unit_test_data" - def __getTestDataFilePath(self, test_file_name: str) -> str: + @classmethod + def getTestDataFilePath(cls, test_file_name: str, extension: str = '.json') -> str: # what folder are we running from? 
cwd = os.path.basename(os.getcwd()) @@ -251,7 +265,7 @@ def __getTestDataFilePath(self, test_file_name: str) -> str: ) # return appropriate path - return f"{dir_path}/{self.TEST_DATA_FOLDER}/{test_file_name}.json" + return f"{dir_path}/{cls.TEST_DATA_FOLDER}/{test_file_name}{extension}" def __loadTestData(self, test_case_path: str) -> dict: """ @@ -265,7 +279,7 @@ def __loadTestData(self, test_case_path: str) -> dict: # load the test data file if it hasn't been loaded yet if self.test_data_file is None: # find our test data file - test_data_filename = self.__getTestDataFilePath(file_name) + test_data_filename = self.getTestDataFilePath(file_name) if not os.path.isfile(test_data_filename): warnings.warn(f"Could not load test data file {test_data_filename}") self.test_data_file = {} From 30680d62fcecfe961993a4c1b710ac6029596b5a Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 4 Nov 2024 22:03:58 -0800 Subject: [PATCH 12/12] test the k-folds operation of our TimeSeriesCrossValidator --- python/tests/ml_tests.py | 20 +++++ python/tests/unit_test_data/ml_tests.json | 13 +++ python/tests/unit_test_data/trades.csv | 101 ++++++++++++++++++++++ 3 files changed, 134 insertions(+) create mode 100644 python/tests/unit_test_data/ml_tests.json create mode 100644 python/tests/unit_test_data/trades.csv diff --git a/python/tests/ml_tests.py b/python/tests/ml_tests.py index f638eac0..b435bf98 100644 --- a/python/tests/ml_tests.py +++ b/python/tests/ml_tests.py @@ -3,6 +3,7 @@ from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from pyspark.ml.regression import GBTRegressor from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.sql import DataFrame from tempo.ml import TimeSeriesCrossValidator @@ -131,6 +132,25 @@ def test_gap_param(self): # check the gap self.assertEqual(tscv.getGap(), 2) + def test_kfolds(self): + # load test data + trades_df = self.get_test_df_builder("trades").as_sdf() + # construct with default parameters + tscv = TimeSeriesCrossValidator(timeSeriesCol='event_ts', + seriesIdCols=['symbol'], + gap=0) + # test the k-folds + k_folds = tscv._kFold(trades_df) + # check the number of folds + self.assertEqual(len(k_folds), tscv.getNumFolds()) + # check each fold + for fold in k_folds: + self.assertIsInstance(fold, tuple) + self.assertEqual(len(fold), 2) + self.assertIsInstance(fold[0], DataFrame) + self.assertIsInstance(fold[1], DataFrame) + + # MAIN if __name__ == "__main__": unittest.main() diff --git a/python/tests/unit_test_data/ml_tests.json b/python/tests/unit_test_data/ml_tests.json new file mode 100644 index 00000000..7ff634b3 --- /dev/null +++ b/python/tests/unit_test_data/ml_tests.json @@ -0,0 +1,13 @@ +{ + "TimeSeriesCrossValidatorTests": { + "test_kfolds": { + "trades": { + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": "trades.csv" + } + } + } + } +} \ No newline at end of file diff --git a/python/tests/unit_test_data/trades.csv b/python/tests/unit_test_data/trades.csv new file mode 100644 index 00000000..a7e3600e --- /dev/null +++ b/python/tests/unit_test_data/trades.csv @@ -0,0 +1,101 @@ +symbol,event_ts,trade_pr +IBM,2017-08-31 00:57:25,347.9766055434685 +IBM,2017-08-31 05:02:55,347.603478891568 +IBM,2017-08-31 05:26:44,348.2851225377187 +IBM,2017-08-31 05:38:08,347.8817054037267 +IBM,2017-08-31 05:53:32,348.3718457507241 +IBM,2017-08-31 06:56:22,349.40868952165323 +IBM,2017-08-31 08:11:03,350.4640358206109 +IBM,2017-08-31 10:49:09,347.716019602253 +IBM,2017-08-31 
11:01:38,347.2030920487126 +IBM,2017-08-31 11:11:25,347.92907707949666 +IBM,2017-08-31 11:49:55,346.1066922566784 +IBM,2017-08-31 12:10:41,346.1236987198399 +IBM,2017-08-31 13:02:47,349.20960037131124 +IBM,2017-08-31 13:07:58,347.09158893690676 +IBM,2017-08-31 14:15:49,347.45775383566854 +IBM,2017-08-31 15:50:02,347.1668702661576 +IBM,2017-08-31 17:27:50,348.56522298908044 +IBM,2017-08-31 18:07:56,349.26325456538416 +IBM,2017-08-31 19:09:47,349.34601689149946 +IBM,2017-08-31 19:55:55,348.09936204319274 +IBM,2017-08-31 20:17:15,347.1308847917395 +IBM,2017-08-31 20:51:37,348.83766041227994 +IBM,2017-08-31 21:37:17,348.4003780895007 +K,2017-08-31 00:06:27,347.27138459233106 +K,2017-08-31 00:18:46,347.9898553182071 +K,2017-08-31 00:31:12,346.85852918073624 +K,2017-08-31 00:51:16,346.91520445001134 +K,2017-08-31 01:08:30,347.8078868655896 +K,2017-08-31 01:34:54,347.2374835843108 +K,2017-08-31 02:47:49,349.00659452619976 +K,2017-08-31 02:49:22,347.4814105439092 +K,2017-08-31 02:56:43,350.3539039043633 +K,2017-08-31 03:01:33,349.5941805224711 +K,2017-08-31 03:50:20,348.6119516556592 +K,2017-08-31 03:52:18,348.18731148311406 +K,2017-08-31 04:36:19,345.95795045531105 +K,2017-08-31 05:27:12,346.6341114389929 +K,2017-08-31 06:29:58,347.4121586706382 +K,2017-08-31 06:32:30,346.7582132240916 +K,2017-08-31 06:37:31,348.919146315238 +K,2017-08-31 06:56:24,349.45235333868743 +K,2017-08-31 08:38:22,347.6687817715506 +K,2017-08-31 08:52:59,349.11648025163987 +K,2017-08-31 09:22:55,347.16036576622395 +K,2017-08-31 10:00:54,348.4869310969907 +K,2017-08-31 10:52:36,348.44707325529976 +K,2017-08-31 12:47:15,349.2617047407556 +K,2017-08-31 13:17:24,349.16422862658777 +K,2017-08-31 13:17:36,347.2034739832661 +K,2017-08-31 13:42:17,350.3594725526159 +K,2017-08-31 14:53:24,345.9384837375688 +K,2017-08-31 15:14:08,346.3947630851533 +K,2017-08-31 16:41:45,348.99202720361484 +K,2017-08-31 18:41:52,348.7838699834772 +K,2017-08-31 19:05:41,347.95173326760005 +K,2017-08-31 19:25:27,348.16797905143034 +K,2017-08-31 19:33:37,350.6567627351192 +K,2017-08-31 20:21:47,347.9468144834939 +K,2017-08-31 21:20:48,349.0419269428769 +K,2017-08-31 21:36:07,347.38074751913484 +K,2017-08-31 21:46:14,348.02539935462477 +K,2017-08-31 21:58:11,346.98271245245644 +K,2017-08-31 23:16:57,349.77827310811676 +K,2017-08-31 23:29:40,348.9429200005411 +KFS,2017-08-31 01:57:44,347.77347472191366 +KFS,2017-08-31 01:58:19,347.3575869386784 +KFS,2017-08-31 03:42:15,349.12235630639043 +KFS,2017-08-31 10:26:57,347.9734526183446 +KFS,2017-08-31 11:12:29,345.7111774398965 +KFS,2017-08-31 12:30:27,347.9446791058658 +KFS,2017-08-31 12:56:56,348.40914502757425 +KFS,2017-08-31 20:18:30,348.5555420623246 +KFS,2017-08-31 21:34:01,346.7731734554559 +KFS,2017-08-31 22:32:59,348.6877379266723 +KFS,2017-08-31 23:09:35,349.41137210604654 +KFS,2017-08-31 23:11:17,349.0671659876273 +KFS,2017-08-31 23:31:03,350.44123904624985 +TBB,2017-08-31 00:54:21,347.0268200605267 +TBB,2017-08-31 01:27:59,347.81625383701953 +TBB,2017-08-31 01:29:59,346.7819013463641 +TBB,2017-08-31 01:42:25,347.20721120029015 +TBB,2017-08-31 02:28:56,347.4150394760788 +TBB,2017-08-31 03:16:10,348.7008001367906 +TBB,2017-08-31 04:38:04,348.0449984445236 +TBB,2017-08-31 05:48:03,348.6731290332764 +TBB,2017-08-31 08:32:07,350.7247367234809 +TBB,2017-08-31 08:42:47,346.5096608964251 +TBB,2017-08-31 10:58:42,348.4464129070117 +TBB,2017-08-31 11:37:39,347.9739503215442 +TBB,2017-08-31 12:25:31,349.7654451975011 +TBB,2017-08-31 13:00:17,347.77438852748907 +TBB,2017-08-31 14:46:22,348.6523007656035 
+TBB,2017-08-31 16:11:57,348.1998564265572 +TBB,2017-08-31 16:54:51,347.86227977925466 +TBB,2017-08-31 17:44:52,346.8702925232193 +TBB,2017-08-31 18:26:52,347.85539454921854 +TBB,2017-08-31 18:50:39,349.22132130112925 +TBB,2017-08-31 19:03:36,346.8821233653525 +TBB,2017-08-31 20:34:19,348.2391472198875 +TBB,2017-08-31 20:36:40,347.0180283437618
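
Taken together, the patches above give TimeSeriesCrossValidator._kFold the following semantics: the dataset is ordered by timeSeriesCol (per series, when seriesIdCols is set), cut into numFolds + 1 contiguous chunks with ntile, and fold i (zero-based) trains on chunks 1 through i+1 while testing on chunk i+2; when gap > 0, the gap most recent rows are trimmed from the end of each training set, so the model never trains on observations immediately adjacent to the test window. Below is a minimal sketch of the fold mechanics on a toy single-symbol series, assuming a local SparkSession and the tempo package from this branch on the import path; the data and the resulting counts are illustrative only.

    from pyspark.sql import SparkSession

    from tempo.ml import TimeSeriesCrossValidator

    spark = SparkSession.builder.getOrCreate()

    # a toy series of 12 chronologically ordered observations for one symbol
    # (ISO-format date strings sort correctly even without a timestamp cast)
    df = spark.createDataFrame(
        [("A", f"2024-01-{d:02d}", float(d)) for d in range(1, 13)],
        "symbol string, event_ts string, trade_pr double",
    )

    tscv = TimeSeriesCrossValidator(
        timeSeriesCol="event_ts", seriesIdCols=["symbol"], gap=2
    )
    tscv.setNumFolds(3)  # ntile(4): four chunks of 3 rows each

    for i, (train, test) in enumerate(tscv._kFold(df)):
        # fold i trains on chunks 1..i+1 (minus the trailing gap rows)
        # and tests on chunk i+2
        print(f"fold {i}: train={train.count()} rows, test={test.count()} rows")

    # expected counts with 12 rows, numFolds=3, gap=2:
    #   fold 0: train=1, test=3   (chunk 1, minus the 2 newest training rows)
    #   fold 1: train=4, test=3   (chunks 1-2, minus 2)
    #   fold 2: train=7, test=3   (chunks 1-3, minus 2)

Calling the protected _kFold directly mirrors what the test added in PATCH 12/12 does; in ordinary use, fit() invokes it internally, as in the notebook from PATCH 05/12.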