Commit message:

* skeleton of TimeSeriesCrossValidator
* basic implementation
* added gap param to tscv class
* updating init method
* example notebook showing how to make use of timeseries split cross-validator
* applying black formatting
* fixing type-check issue
* mypy should run in non-interactive mode
* basic test code for cross-validator: constructor and params
* cache contents of test data file
* option to load test data from csv file
* test the k-folds operation of our TimeSeriesCrossValidator

Co-authored-by: tnixon <tnixon>
Showing 8 changed files with 596 additions and 24 deletions.
@@ -0,0 +1,154 @@
# Databricks notebook source
# MAGIC %md
# MAGIC # Set up dataset

# COMMAND ----------

# MAGIC %sh
# MAGIC
# MAGIC wget -P /tmp/finserv/ https://pages.databricks.com/rs/094-YMS-629/images/ASOF_Trades.csv;

# COMMAND ----------

# MAGIC %sh
# MAGIC
# MAGIC ls /tmp/finserv/

# COMMAND ----------

# MAGIC %sh head -n 30 /tmp/finserv/ASOF_Trades.csv

# COMMAND ----------

data_dir = "/tmp/finserv"
trade_schema = """
    symbol string,
    event_ts timestamp,
    mod_dt date,
    trade_pr double
"""

trades_df = (spark.read.format("csv")
             .schema(trade_schema)
             .option("header", "true")
             .option("delimiter", ",")
             .load(f"{data_dir}/ASOF_Trades.csv"))

# COMMAND ----------

# MAGIC %md
# MAGIC # Prepare Data
# MAGIC
# MAGIC Aggregate the raw trades into 15-minute OHLC bars

# COMMAND ----------

import pyspark.sql.functions as sfn

bars_df = (trades_df
           .where(sfn.col("symbol").isNotNull())
           .groupBy(sfn.col("symbol"),
                    sfn.window("event_ts", "15 minutes"))
           .agg(
               sfn.first_value("trade_pr").alias("open"),
               sfn.min("trade_pr").alias("low"),
               sfn.max("trade_pr").alias("high"),
               sfn.last_value("trade_pr").alias("close"),
               sfn.count("trade_pr").alias("num_trades"))
           .select("symbol",
                   sfn.col("window.start").alias("event_ts"),
                   "open", "high", "low", "close", "num_trades")
           .orderBy("symbol", "event_ts"))

# COMMAND ----------

display(bars_df)

# COMMAND ----------

# MAGIC %md
# MAGIC ### Rolling 1.5 hour window
# MAGIC
# MAGIC Roll the 6 previous bars (1.5 hours) into a feature vector used to predict the current price.

# COMMAND ----------

from pyspark.sql import Window

six_step_win = Window.partitionBy("symbol").orderBy("event_ts").rowsBetween(-6, -1)

six_step_rolling = (bars_df
                    .withColumn("prev_open", sfn.collect_list("open").over(six_step_win))
                    .withColumn("prev_high", sfn.collect_list("high").over(six_step_win))
                    .withColumn("prev_low", sfn.collect_list("low").over(six_step_win))
                    .withColumn("prev_close", sfn.collect_list("close").over(six_step_win))
                    .withColumn("prev_n", sfn.collect_list("num_trades").over(six_step_win))
                    .where(sfn.array_size("prev_n") >= sfn.lit(6)))

# COMMAND ----------

display(six_step_rolling)

# COMMAND ----------

import pyspark.ml.functions as mlfn

# concatenate the five rolling-history arrays into a single ML feature vector
features_df = (six_step_rolling
               .withColumn("features",
                           mlfn.array_to_vector(sfn.concat("prev_open", "prev_high",
                                                           "prev_low", "prev_close",
                                                           "prev_n")))
               .drop("prev_open", "prev_high", "prev_low", "prev_close", "prev_n"))

# COMMAND ----------

display(features_df)

# COMMAND ----------

# MAGIC %md
# MAGIC ### Cross-validate a model
# MAGIC But we need to split based on time...

# COMMAND ----------

from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator

from tempo.ml import TimeSeriesCrossValidator

# set up model
target_col = "close"
gbt = GBTRegressor(labelCol=target_col, featuresCol="features")

# set up evaluator
evaluator = RegressionEvaluator(labelCol=target_col, predictionCol="prediction", metricName="rmse")

# set up cross-validator
param_grid = ParamGridBuilder().build()
tscv = TimeSeriesCrossValidator(estimator=gbt,
                                evaluator=evaluator,
                                estimatorParamMaps=param_grid,
                                collectSubModels=True,
                                timeSeriesCol="event_ts",
                                seriesIdCols=["symbol"])

# COMMAND ----------

import mlflow
import mlflow.spark

mlflow.spark.autolog()

with mlflow.start_run() as run:
    cvModel = tscv.fit(features_df)
    best_gbt_mdl = cvModel.bestModel
    mlflow.spark.log_model(best_gbt_mdl, "cv_model")

# COMMAND ----------
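
# MAGIC %md
# MAGIC ### Inspecting the cross-validation results
# MAGIC A minimal follow-up sketch, assuming the `cvModel` fitted above: the standard `CrossValidatorModel` accessors give the average metric per parameter map, and because `collectSubModels=True` the per-fold sub-models are retained as well.

# COMMAND ----------

# average RMSE across the time-ordered folds, one entry per parameter map
print(cvModel.avgMetrics)

# per-fold sub-models, collected because collectSubModels=True
for fold_idx, fold_models in enumerate(cvModel.subModels):
    print(f"fold {fold_idx}: {len(fold_models)} sub-model(s)")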
@@ -0,0 +1,122 @@
from typing import Any, List, Tuple
from functools import reduce

from pyspark.sql import DataFrame
from pyspark.sql.window import Window, WindowSpec
from pyspark.sql import functions as sfn

from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.tuning import CrossValidator


TMP_SPLIT_COL = "__tmp_split_col"
TMP_GAP_COL = "__tmp_gap_row"


class TimeSeriesCrossValidator(CrossValidator):
    """A CrossValidator that splits chronologically rather than randomly:
    the data is cut into time-ordered splits, and each fold trains on all
    earlier splits and tests on the next one, optionally dropping the most
    recent `gap` rows from each training set."""

    # some additional parameters
    timeSeriesCol: Param[str] = Param(
        Params._dummy(),
        "timeSeriesCol",
        "The name of the time series column",
        typeConverter=TypeConverters.toString,
    )
    seriesIdCols: Param[List[str]] = Param(
        Params._dummy(),
        "seriesIdCols",
        "The name of the series id columns",
        typeConverter=TypeConverters.toListString,
    )
    gap: Param[int] = Param(
        Params._dummy(),
        "gap",
        "The gap between training and test set",
        typeConverter=TypeConverters.toInt,
    )

    def __init__(
        self,
        timeSeriesCol: str = "event_ts",
        seriesIdCols: List[str] = [],
        gap: int = 0,
        **other_kwargs: Any
    ) -> None:
        super(TimeSeriesCrossValidator, self).__init__(**other_kwargs)
        self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0)
        self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap)

    def getTimeSeriesCol(self) -> str:
        return self.getOrDefault(self.timeSeriesCol)

    def getSeriesIdCols(self) -> List[str]:
        return self.getOrDefault(self.seriesIdCols)

    def getGap(self) -> int:
        return self.getOrDefault(self.gap)

    def setTimeSeriesCol(self, value: str) -> "TimeSeriesCrossValidator":
        return self._set(timeSeriesCol=value)

    def setSeriesIdCols(self, value: List[str]) -> "TimeSeriesCrossValidator":
        return self._set(seriesIdCols=value)

    def setGap(self, value: int) -> "TimeSeriesCrossValidator":
        return self._set(gap=value)

    def _get_split_win(self, desc: bool = False) -> WindowSpec:
        ts_col_expr = sfn.col(self.getTimeSeriesCol())
        if desc:
            ts_col_expr = ts_col_expr.desc()
        win = Window.orderBy(ts_col_expr)
        series_id_cols = self.getSeriesIdCols()
        if series_id_cols and len(series_id_cols) > 0:
            win = win.partitionBy(*series_id_cols)
        return win

    def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
        nFolds = self.getOrDefault(self.numFolds)
        nSplits = nFolds + 1

        # split the data into nSplits subsets by timeseries order
        split_df = dataset.withColumn(
            TMP_SPLIT_COL, sfn.ntile(nSplits).over(self._get_split_win())
        )
        all_splits = [
            split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
            for i in range(1, nSplits + 1)
        ]
        assert len(all_splits) == nSplits

        # compose the k folds by including all previous splits in the training set,
        # and the next split in the test set
        kFolds = [
            (reduce(lambda a, b: a.union(b), all_splits[: i + 1]), all_splits[i + 1])
            for i in range(nFolds)
        ]
        assert len(kFolds) == nFolds
        for tv in kFolds:
            assert len(tv) == 2

        # trim out a gap from the training datasets, if specified
        gap = self.getOrDefault(self.gap)
        if gap > 0:
            order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()]
            # trim each training dataset by the specified gap
            kFolds = [
                (
                    (
                        train_df.withColumn(
                            TMP_GAP_COL,
                            sfn.row_number().over(self._get_split_win(desc=True)),
                        )
                        .where(sfn.col(TMP_GAP_COL) > gap)
                        .drop(TMP_GAP_COL)
                        .orderBy(*order_cols)
                    ),
                    test_df,
                )
                for (train_df, test_df) in kFolds
            ]

        # return the k folds (training, test) datasets
        return kFolds
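
A rough usage sketch of how the folds come out (a hypothetical example, assuming a local SparkSession and that the class above is importable as `tempo.ml.TimeSeriesCrossValidator`, as in the example notebook; it calls the internal `_kFold` method directly just to show the split layout):

import pyspark.sql.functions as sfn
from pyspark.sql import SparkSession
from tempo.ml import TimeSeriesCrossValidator

spark = SparkSession.builder.master("local[2]").getOrCreate()

# 12 daily observations for a single symbol
toy_df = spark.createDataFrame(
    [("A", f"2024-01-{d:02d}", float(d)) for d in range(1, 13)],
    "symbol string, event_ts string, close double",
).withColumn("event_ts", sfn.to_timestamp("event_ts"))

# numFolds=3 cuts the data into 4 chronological slices via ntile();
# fold i trains on slices 1..i and tests on slice i+1, so training sets grow over time
tscv = TimeSeriesCrossValidator(
    timeSeriesCol="event_ts", seriesIdCols=["symbol"], numFolds=3, gap=1
)
for i, (train_df, test_df) in enumerate(tscv._kFold(toy_df), start=1):
    # with gap=1, the most recent training row in each series is trimmed off
    print(f"fold {i}: train={train_df.count()} rows, test={test_df.count()} rows")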