From efdf0346e4fc57f1d57412fcdc3967ec746713df Mon Sep 17 00:00:00 2001
From: rportilla-databricks <38080604+rportilla-databricks@users.noreply.github.com>
Date: Fri, 25 Feb 2022 09:17:47 -0500
Subject: [PATCH] add grouped stats (#158)

* add grouped stats

* fix docs

* committing docs for grouping

* adding reduce to replace for loop for metrics calc

* update pypi
---
 python/setup.py            |  2 +-
 python/tempo/tsdf.py       | 38 +++++++++++++++++++++++++
 python/tests/tsdf_tests.py | 57 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/python/setup.py b/python/setup.py
index 44fc2f27..29d3951f 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -6,7 +6,7 @@
 
 setuptools.setup(
     name='dbl-tempo',
-    version='0.1.7',
+    version='0.1.8',
     author='Ricardo Portilla, Tristan Nixon, Max Thone, Sonali Guleria',
     author_email='labs@databricks.com',
     description='Spark Time Series Utility Package',
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index bd3b7348..4dab294a 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -593,6 +593,44 @@ def withRangeStats(self, type='range', colsToSummarize=[], rangeBackWindowSecs=1
 
         return TSDF(summary_df, self.ts_col, self.partitionCols)
 
+    def withGroupedStats(self, metricCols=[], freq = None):
+        """
+        Create a wider set of stats based on all numeric columns by default.
+        Users can also choose which columns they want to summarize. These stats are:
+        mean/count/min/max/sum/std deviation
+        :param metricCols - list of user-supplied columns to compute stats for. All numeric columns are used if no list is provided
+        :param freq - aggregation frequency, given as a string of the form '1 min' or '30 seconds'; this determines the window used to aggregate
+        """
+
+        # identify columns to summarize if not provided
+        # these should include all numeric columns that
+        # are not the timestamp column and not any of the partition columns
+        if not metricCols:
+            # columns we should never summarize
+            prohibited_cols = [self.ts_col.lower()]
+            if self.partitionCols:
+                prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
+            # types that can be summarized
+            summarizable_types = ['int', 'bigint', 'float', 'double']
+            # filter columns to find summarizable columns
+            metricCols = [datatype[0] for datatype in self.df.dtypes if
+                          ((datatype[1] in summarizable_types) and
+                           (datatype[0].lower() not in prohibited_cols))]
+
+        # build window
+        parsed_freq = rs.checkAllowableFreq(self, freq)
+        agg_window = f.window(f.col(self.ts_col), "{} {}".format(parsed_freq[0], rs.freq_dict[parsed_freq[1]]))
+
+        # compute column summaries (acc + [...] keeps the reduce accumulator a list; list.extend returns None)
+        selectedCols = reduce(lambda acc, metric:
+                              acc + [f.mean(f.col(metric)).alias('mean_' + metric), f.count(f.col(metric)).alias('count_' + metric), f.min(f.col(metric)).alias('min_' + metric), f.max(f.col(metric)).alias('max_' + metric), f.sum(f.col(metric)).alias('sum_' + metric), f.stddev(f.col(metric)).alias('stddev_' + metric)],
+                              metricCols, [])
+
+        selected_df = self.df.groupBy(self.partitionCols + [agg_window]).agg(*selectedCols)
+        summary_df = selected_df.select(*selected_df.columns).withColumn(self.ts_col, f.col('window').start).drop('window')
+
+        return TSDF(summary_df, self.ts_col, self.partitionCols)
+
     def write(self, spark, tabName, optimizationCols = None):
         tio.write(self, spark, tabName, optimizationCols)
 
diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py
index 36c510c2..233ef04a 100644
--- a/python/tests/tsdf_tests.py
+++ b/python/tests/tsdf_tests.py
@@ -501,6 +501,61 @@ def test_range_stats(self):
 
         # should be equal to the expected dataframe
         self.assertDataFramesEqual(featured_df, dfExpected)
 
+    def test_group_stats(self):
+        """Test of grouped stats over a 1 minute aggregation window"""
+        schema = StructType([StructField("symbol", StringType()),
+                             StructField("event_ts", StringType()),
+                             StructField("trade_pr", FloatType())])
+
+        expectedSchema = StructType([StructField("symbol", StringType()),
+                                     StructField("event_ts", StringType()),
+                                     StructField("mean_trade_pr", FloatType()),
+                                     StructField("count_trade_pr", LongType(), nullable=False),
+                                     StructField("min_trade_pr", FloatType()),
+                                     StructField("max_trade_pr", FloatType()),
+                                     StructField("sum_trade_pr", FloatType()),
+                                     StructField("stddev_trade_pr", FloatType())])
+
+        data = [["S1", "2020-08-01 00:00:10", 349.21],
+                ["S1", "2020-08-01 00:00:33", 351.32],
+                ["S1", "2020-09-01 00:02:10", 361.1],
+                ["S1", "2020-09-01 00:02:49", 362.1]]
+
+        expected_data = [
+            ["S1", "2020-08-01 00:00:00", 350.26, 2, 349.21, 351.32, 700.53, 1.49],
+            ["S1", "2020-09-01 00:02:00", 361.6, 2, 361.1, 362.1, 723.2, 0.71]]
+
+        # construct dataframes
+        df = self.buildTestDF(schema, data)
+        dfExpected = self.buildTestDF(expectedSchema, expected_data)
+
+        # convert to TSDF
+        tsdf_left = TSDF(df, partition_cols=["symbol"])
+
+        # aggregate over 1 minute windows
+        featured_df = tsdf_left.withGroupedStats(freq = '1 min').df
+
+        # cast to decimal with precision in cents for simplicity
+        featured_df = featured_df.select(F.col("symbol"), F.col("event_ts"),
+                                         F.col("mean_trade_pr").cast("decimal(5, 2)"),
+                                         F.col("count_trade_pr"),
+                                         F.col("min_trade_pr").cast("decimal(5,2)"),
+                                         F.col("max_trade_pr").cast("decimal(5,2)"),
+                                         F.col("sum_trade_pr").cast("decimal(5,2)"),
+                                         F.col("stddev_trade_pr").cast("decimal(5,2)"))
+
+        # cast the expected values the same way
+        dfExpected = dfExpected.select(F.col("symbol"), F.col("event_ts"),
+                                       F.col("mean_trade_pr").cast("decimal(5, 2)"),
+                                       F.col("count_trade_pr"),
+                                       F.col("min_trade_pr").cast("decimal(5,2)"),
+                                       F.col("max_trade_pr").cast("decimal(5,2)"),
+                                       F.col("sum_trade_pr").cast("decimal(5,2)"),
+                                       F.col("stddev_trade_pr").cast("decimal(5,2)"))
+
+        # should be equal to the expected dataframe
+        self.assertDataFramesEqual(featured_df, dfExpected)
+
 
 class UtilsTest(SparkTest):
@@ -649,7 +704,7 @@ def test_upsample(self):
             ["S1", "2020-09-01 00:15:00", 0.0, 362.1, 4.0]
         ]
 
-        expected_bars = [
+        expected_bars = [
         ['S1', '2020-08-01 00:00:00', 340.21, 9.0, 349.21, 10.0, 340.21, 9.0, 349.21, 10.0],
         ['S1', '2020-08-01 00:01:00', 350.32, 6.0, 353.32, 8.0, 350.32, 6.0, 353.32, 8.0],
        ['S1', '2020-09-01 00:01:00', 361.1, 5.0, 361.1, 5.0, 361.1, 5.0, 361.1, 5.0],
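
Note on the aggregation builder in withGroupedStats: a fold over metricCols only works if each step returns the accumulator, since list.extend returns None (which is why the patch builds selectedCols with acc + [...]). The same per-column summaries can be reproduced with stock PySpark outside of tempo. Below is a minimal sketch, not part of the patch; the DataFrame df, its symbol/event_ts/trade_pr columns, and the 1-minute window are illustrative assumptions matching the test data:

    from functools import reduce

    import pyspark.sql.functions as f

    # (name, function) pairs for the six stats withGroupedStats computes
    stats = [("mean", f.mean), ("count", f.count), ("min", f.min),
             ("max", f.max), ("sum", f.sum), ("stddev", f.stddev)]
    metricCols = ["trade_pr"]  # assumed metric column

    # fold each metric column into a flat list of aliased aggregate
    # expressions; returning acc + [...] keeps the accumulator a list
    selectedCols = reduce(
        lambda acc, metric: acc + [fn(f.col(metric)).alias(name + "_" + metric)
                                   for name, fn in stats],
        metricCols, [])

    # group by the partition column plus a 1-minute event-time window
    # (df is an assumed DataFrame whose event_ts column is a timestamp),
    # then replace the window struct with its start time
    grouped = (df.groupBy("symbol", f.window(f.col("event_ts"), "1 minute"))
                 .agg(*selectedCols)
                 .withColumn("event_ts", f.col("window").start)
                 .drop("window"))

This mirrors what withGroupedStats(freq = '1 min') produces in test_group_stats above: one row per symbol and window, with mean/count/min/max/sum/stddev columns for each metric column.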