diff --git a/python/tempo/interpol.py b/python/tempo/interpol.py
index c4c9d311..e98204bb 100644
--- a/python/tempo/interpol.py
+++ b/python/tempo/interpol.py
@@ -161,8 +161,8 @@ def __generate_time_series_fill(self, tsdf: t_tsdf.TSDF) -> t_tsdf.TSDF:
             "previous_timestamp",
             sfn.col(tsdf.ts_col),
         ).withColumn(
-            "next_timestamp",
-            sfn.lead(sfn.col(tsdf.ts_col)).over(tsdf.baseWindow()))
+            "next_timestamp", sfn.lead(sfn.col(tsdf.ts_col)).over(tsdf.baseWindow())
+        )
 
     def __generate_column_time_fill(
         self,
@@ -184,12 +184,14 @@ def __generate_column_time_fill(
 
         return tsdf.withColumn(
             f"previous_timestamp_{target_col}",
-            sfn.last(sfn.col(f"{tsdf.ts_col}_{target_col}"),
-                     ignorenulls=True).over(fwd_win),
+            sfn.last(sfn.col(f"{tsdf.ts_col}_{target_col}"), ignorenulls=True).over(
+                fwd_win
+            ),
         ).withColumn(
             f"next_timestamp_{target_col}",
-            sfn.last(sfn.col(f"{tsdf.ts_col}_{target_col}"),
-                     ignorenulls=True).over(bkwd_win)
+            sfn.last(sfn.col(f"{tsdf.ts_col}_{target_col}"), ignorenulls=True).over(
+                bkwd_win
+            ),
         )
 
     def __generate_target_fill(
@@ -213,15 +215,14 @@ def __generate_target_fill(
         return (
             tsdf.withColumn(
                 f"previous_{target_col}",
-                sfn.last(sfn.col(target_col), ignorenulls=True).over(fwd_win)
+                sfn.last(sfn.col(target_col), ignorenulls=True).over(fwd_win),
             )
             # Handle if subsequent value is null
             .withColumn(
                 f"next_null_{target_col}",
-                sfn.last(sfn.col(target_col), ignorenulls=True).over(bkwd_win)
+                sfn.last(sfn.col(target_col), ignorenulls=True).over(bkwd_win),
             ).withColumn(
-                f"next_{target_col}",
-                sfn.lead(sfn.col(target_col)).over(fwd_win)
+                f"next_{target_col}", sfn.lead(sfn.col(target_col)).over(fwd_win)
             )
         )
 
@@ -277,9 +278,7 @@ def interpolate(
 
         if self.is_resampled is False:
             # Resample and Normalize Input
-            sampled_input = tsdf.resample(freq=freq,
-                                          func=func,
-                                          metricCols=target_cols)
+            sampled_input = tsdf.resample(freq=freq, func=func, metricCols=target_cols)
 
         # Fill timeseries for nearest values
         time_series_filled = self.__generate_time_series_fill(sampled_input)
@@ -290,11 +289,11 @@ def interpolate(
         for column in target_cols:
             add_column_time = add_column_time.withColumn(
                 f"{tsdf.ts_col}_{column}",
-                sfn.when(sfn.col(column).isNull(), None).otherwise(sfn.col(tsdf.ts_col)),
-            )
-            add_column_time = self.__generate_column_time_fill(
-                add_column_time, column
+                sfn.when(sfn.col(column).isNull(), None).otherwise(
+                    sfn.col(tsdf.ts_col)
+                ),
             )
+            add_column_time = self.__generate_column_time_fill(add_column_time, column)
 
         # Handle edge case if last value (latest) is null
         edge_filled = add_column_time.withColumn(
@@ -325,9 +324,9 @@ def interpolate(
         flagged_series = (
             exploded_series.withColumn(
                 "is_ts_interpolated",
-                sfn.when(sfn.col(f"new_{tsdf.ts_col}") != sfn.col(tsdf.ts_col), True).otherwise(
-                    False
-                ),
+                sfn.when(
+                    sfn.col(f"new_{tsdf.ts_col}") != sfn.col(tsdf.ts_col), True
+                ).otherwise(False),
             )
             .withColumn(tsdf.ts_col, sfn.col(f"new_{tsdf.ts_col}"))
             .drop(sfn.col(f"new_{tsdf.ts_col}"))
diff --git a/python/tempo/intervals.py b/python/tempo/intervals.py
index bd029858..178a90e8 100644
--- a/python/tempo/intervals.py
+++ b/python/tempo/intervals.py
@@ -192,12 +192,14 @@ def fromStackedMetrics(
         return cls(df, start_ts, end_ts, series)
 
     @classmethod
-    def fromNestedBoundariesDF(cls,
-                               df: DataFrame,
-                               nested_boundaries_col: str,
-                               series_ids: Optional[list[str]] = None,
-                               start_element_name: str = "start",
-                               end_element_name: str = "end") -> IntervalsDF:
+    def fromNestedBoundariesDF(
+        cls,
+        df: DataFrame,
+        nested_boundaries_col: str,
+        series_ids: Optional[list[str]] = None,
+        start_element_name: str = "start",
+        end_element_name: str = "end",
+    ) -> IntervalsDF:
         """
 
         :param df:
@@ -211,14 +213,15 @@ def fromNestedBoundariesDF(cls,
         # unpack the start & end elements
         start_path = f"{nested_boundaries_col}.{start_element_name}"
         end_path = f"{nested_boundaries_col}.{end_element_name}"
-        unpacked_boundaries_df = (df.withColumn(start_element_name, sfn.col(start_path))
-                                  .withColumn(end_element_name, sfn.col(end_path))
-                                  .drop(nested_boundaries_col))
+        unpacked_boundaries_df = (
+            df.withColumn(start_element_name, sfn.col(start_path))
+            .withColumn(end_element_name, sfn.col(end_path))
+            .drop(nested_boundaries_col)
+        )
         # return the results as an IntervalsDF
-        return IntervalsDF(unpacked_boundaries_df,
-                           start_element_name,
-                           end_element_name,
-                           series_ids)
+        return IntervalsDF(
+            unpacked_boundaries_df, start_element_name, end_element_name, series_ids
+        )
 
     def __get_adjacent_rows(self, df: DataFrame) -> DataFrame:
         """
diff --git a/python/tempo/resample.py b/python/tempo/resample.py
index 6f35ded1..adb55277 100644
--- a/python/tempo/resample.py
+++ b/python/tempo/resample.py
@@ -157,9 +157,7 @@ def aggregate(
         exprs = {x: "avg" for x in metricCols}
         res = df.groupBy(groupingCols).agg(exprs)
         agg_metric_cls = list(
-            set(res.columns).difference(
-                set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
-            )
+            set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
         )
         new_cols = [
             sfn.col(c).alias(
@@ -172,9 +170,7 @@ def aggregate(
         exprs = {x: "min" for x in metricCols}
         res = df.groupBy(groupingCols).agg(exprs)
         agg_metric_cls = list(
-            set(res.columns).difference(
-                set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
-            )
+            set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
         )
         new_cols = [
             sfn.col(c).alias(
@@ -187,9 +183,7 @@ def aggregate(
         exprs = {x: "max" for x in metricCols}
         res = df.groupBy(groupingCols).agg(exprs)
         agg_metric_cls = list(
-            set(res.columns).difference(
-                set(tsdf.series_ids + [tsdf.ts_col, "agg_key"])
-            )
+            set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
         )
         new_cols = [
             sfn.col(c).alias(
@@ -245,9 +239,9 @@ def aggregate(
         metrics.append(col[0])
 
     if fill:
-        res = imputes.join(
-            res, tsdf.series_ids + [tsdf.ts_col], "leftouter"
-        ).na.fill(0, metrics)
+        res = imputes.join(res, tsdf.series_ids + [tsdf.ts_col], "leftouter").na.fill(
+            0, metrics
+        )
 
     return res
 
diff --git a/python/tempo/stats.py b/python/tempo/stats.py
index a6ad15f2..bc615c14 100644
--- a/python/tempo/stats.py
+++ b/python/tempo/stats.py
@@ -13,10 +13,10 @@
 
 
 def vwap(
-        tsdf: TSDF,
-        frequency: str = "m",
-        volume_col: str = "volume",
-        price_col: str = "price",
+    tsdf: TSDF,
+    frequency: str = "m",
+    volume_col: str = "volume",
+    price_col: str = "price",
 ) -> TSDF:
     # set pre_vwap as self or enrich with the frequency
     pre_vwap = tsdf.df
@@ -72,9 +72,7 @@ def EMA(tsdf: TSDF, colName: str, window: int = 30, exp_factor: float = 0.2) ->
     for i in range(window):
         lagColName = "_".join(["lag", colName, str(i)])
         weight = exp_factor * (1 - exp_factor) ** i
-        df = df.withColumn(
-            lagColName, weight * sfn.lag(sfn.col(colName), i).over(w)
-        )
+        df = df.withColumn(lagColName, weight * sfn.lag(sfn.col(colName), i).over(w))
         df = df.withColumn(
             emaColName,
             sfn.col(emaColName)
@@ -88,11 +86,11 @@ def EMA(tsdf: TSDF, colName: str, window: int = 30, exp_factor: float = 0.2) ->
 
 
 def withLookbackFeatures(
-        tsdf: TSDF,
-        feature_cols: List[str],
-        lookback_window_size: int,
-        exact_size: bool = True,
-        feature_col_name: str = "features",
+    tsdf: TSDF,
+    feature_cols: List[str],
+    lookback_window_size: int,
+    exact_size: bool = True,
+    feature_col_name: str = "features",
 ) -> TSDF:
     """
     Creates a 2-D feature tensor suitable for training an ML model to predict current values from the history of
@@ -130,10 +128,10 @@ def withLookbackFeatures(
 
 
 def withRangeStats(
-        tsdf: TSDF,
-        type: str = "range",
-        cols_to_summarize: Optional[List[Column]] = None,
-        range_back_window_secs: int = 1000,
+    tsdf: TSDF,
+    type: str = "range",
+    cols_to_summarize: Optional[List[Column]] = None,
+    range_back_window_secs: int = 1000,
 ) -> TSDF:
     """
     Create a wider set of stats based on all numeric columns by default
@@ -172,8 +170,8 @@ def withRangeStats(
         selected_cols.append(sfn.stddev(metric).over(w).alias("stddev_" + metric))
         derived_cols.append(
             (
-                    (sfn.col(metric) - sfn.col("mean_" + metric))
-                    / sfn.col("stddev_" + metric)
+                (sfn.col(metric) - sfn.col("mean_" + metric))
+                / sfn.col("stddev_" + metric)
             ).alias("zscore_" + metric)
         )
     selected_df = tsdf.df.select(*selected_cols)
@@ -185,9 +183,9 @@ def withRangeStats(
 
 
 def withGroupedStats(
-        tsdf: TSDF,
-        metric_cols: Optional[List[str]] = None,
-        freq: Optional[str] = None,
+    tsdf: TSDF,
+    metric_cols: Optional[List[str]] = None,
+    freq: Optional[str] = None,
 ) -> TSDF:
     """
     Create a wider set of stats based on all numeric columns by default
@@ -214,8 +212,8 @@ def withGroupedStats(
         datatype[0]
         for datatype in tsdf.df.dtypes
         if (
-                (datatype[1] in summarizable_types)
-                and (datatype[0].lower() not in prohibited_cols)
+            (datatype[1] in summarizable_types)
+            and (datatype[0].lower() not in prohibited_cols)
         )
     ]
 
@@ -255,10 +253,10 @@ def withGroupedStats(
 
 
 def calc_bars(
-        tsdf: TSDF,
-        freq: str,
-        metric_cols: Optional[List[str]] = None,
-        fill: Optional[bool] = None,
+    tsdf: TSDF,
+    freq: str,
+    metric_cols: Optional[List[str]] = None,
+    fill: Optional[bool] = None,
 ) -> TSDF:
     resample_open = tsdf.resample(
         freq=freq, func="floor", metricCols=metric_cols, prefix="open", fill=fill
@@ -279,9 +277,11 @@ def calc_bars(
         .join(resample_low.df, join_cols)
         .join(resample_close.df, join_cols)
     )
-    non_part_cols = set(bars.columns) - set(resample_open.series_ids) - {resample_open.ts_col}
+    non_part_cols = (
+        set(bars.columns) - set(resample_open.series_ids) - {resample_open.ts_col}
+    )
     sel_and_sort = (
-            resample_open.series_ids + [resample_open.ts_col] + sorted(non_part_cols)
+        resample_open.series_ids + [resample_open.ts_col] + sorted(non_part_cols)
     )
     bars = bars.select(sel_and_sort)
 
@@ -289,9 +289,7 @@ def calc_bars(
 
 
 def fourier_transform(
-        tsdf: TSDF,
-        timestep: Union[int, float, complex],
-        value_col: str
+    tsdf: TSDF, timestep: Union[int, float, complex], value_col: str
 ) -> TSDF:
     """
     Function to fourier transform the time series to its frequency domain representation.
@@ -304,7 +302,7 @@ def fourier_transform(
     """
 
     def tempo_fourier_util(
-            pdf: pd.DataFrame,
+        pdf: pd.DataFrame,
     ) -> pd.DataFrame:
         """
         This method is a vanilla python logic implementing fourier transform on a numpy array using the scipy module.
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 2f4e1c87..68e4a3ab 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -35,11 +35,13 @@ class TSDF(WindowBuilder):
     This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
     """
 
-    def __init__(self,
-                 df: DataFrame,
-                 ts_schema: Optional[TSSchema] = None,
-                 ts_col: Optional[str] = None,
-                 series_ids: Optional[Collection[str]] = None) -> None:
+    def __init__(
+        self,
+        df: DataFrame,
+        ts_schema: Optional[TSSchema] = None,
+        ts_col: Optional[str] = None,
+        series_ids: Optional[Collection[str]] = None,
+    ) -> None:
         self.df = df
         # construct schema if we don't already have one
         if ts_schema:
@@ -79,7 +81,9 @@ def __withStandardizedColOrder(self) -> TSDF:
         :return: a :class:`TSDF` with the columns reordered into "standard order" (as described above)
         """
         std_ordered_cols = (
-            list(self.series_ids) + [self.ts_index.colname] + list(self.observational_cols)
+            list(self.series_ids)
+            + [self.ts_index.colname]
+            + list(self.observational_cols)
         )
         return self.__withTransformedDF(self.df.select(std_ordered_cols))
 
@@ -97,8 +101,9 @@ def __makeStructFromCols(
 
         :return: the transformed :class:`DataFrame`
         """
-        return (df.withColumn(struct_col_name, sfn.struct(*cols_to_move))
-                .drop(*cols_to_move))
+        return df.withColumn(struct_col_name, sfn.struct(*cols_to_move)).drop(
+            *cols_to_move
+        )
 
     # default column name for constructed timeseries index struct columns
     __DEFAULT_TS_IDX_COL = "ts_idx"
@@ -362,7 +367,9 @@ def __getTimePartitions(self, tsPartitionVal: int, fraction: float = 0.1) -> "TS
         df = partition_df.union(remainder_df).drop(
             "partition_remainder", "ts_col_double"
         )
-        return TSDF(df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"])
+        return TSDF(
+            df, ts_col=self.ts_col, series_ids=self.series_ids + ["ts_partition"]
+        )
 
     #
     # Slicing & Selection
@@ -828,12 +835,8 @@ def asofJoin(
         # validate timestamp datatypes match
         self.__validateTsColMatch(right_tsdf)
 
-        orig_left_col_diff = list(
-            set(left_df.columns) - set(self.series_ids)
-        )
-        orig_right_col_diff = list(
-            set(right_df.columns) - set(self.series_ids)
-        )
+        orig_left_col_diff = list(set(left_df.columns) - set(self.series_ids))
+        orig_right_col_diff = list(set(right_df.columns) - set(self.series_ids))
 
         left_tsdf = (
             (self.__addPrefixToColumns([self.ts_col] + orig_left_col_diff, left_prefix))
@@ -930,16 +933,14 @@ def asofJoin(
     def baseWindow(self, reverse: bool = False) -> WindowSpec:
         return self.ts_schema.baseWindow(reverse=reverse)
 
-    def rowsBetweenWindow(self,
-                          start: int,
-                          end: int,
-                          reverse: bool = False) -> WindowSpec:
+    def rowsBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         return self.ts_schema.rowsBetweenWindow(start, end, reverse=reverse)
 
-    def rangeBetweenWindow(self,
-                           start: int,
-                           end: int,
-                           reverse: bool = False) -> WindowSpec:
+    def rangeBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         return self.ts_schema.rangeBetweenWindow(start, end, reverse=reverse)
 
     #
@@ -1013,7 +1014,6 @@ def drop(self, *cols: str) -> TSDF:
         ...
 
     def drop(self, *cols: ColumnOrName) -> TSDF:
-
         """
         Returns a new :class:`TSDF` that drops the specified column.
 
@@ -1025,9 +1025,9 @@ def drop(self, *cols: ColumnOrName) -> TSDF:
         dropped_df = self.df.drop(*cols)
         return self.__withTransformedDF(dropped_df)
 
-    def mapInPandas(self,
-                    func: PandasMapIterFunction,
-                    schema: Union[StructType, str]) -> TSDF:
+    def mapInPandas(
+        self, func: PandasMapIterFunction, schema: Union[StructType, str]
+    ) -> TSDF:
         """
 
         :param func:
@@ -1053,9 +1053,9 @@ def unionByName(self, other: TSDF, allowMissingColumns: bool = False) -> TSDF:
     # Rolling (Windowed) Transformations
     #
 
-    def rollingAgg(self,
-                   window: WindowSpec,
-                   *exprs: Union[Column, Dict[str, str]]) -> TSDF:
+    def rollingAgg(
+        self, window: WindowSpec, *exprs: Union[Column, Dict[str, str]]
+    ) -> TSDF:
         """
 
         :param window:
@@ -1068,25 +1068,30 @@ def rollingAgg(self,
             for input_col in exprs.keys():
                 expr_str = exprs[input_col]
                 new_col_name = f"{expr_str}({input_col})"
-                roll_agg_tsdf = roll_agg_tsdf.withColumn(new_col_name,
-                                                         sfn.expr(expr_str).over(window))
+                roll_agg_tsdf = roll_agg_tsdf.withColumn(
+                    new_col_name, sfn.expr(expr_str).over(window)
+                )
         else:  # Columns
-            assert all(isinstance(c, Column) for c in exprs), \
-                "all exprs should be Column"
+            assert all(
+                isinstance(c, Column) for c in exprs
+            ), "all exprs should be Column"
 
             for expr in exprs:
                 new_col_name = f"{expr}"
-                roll_agg_tsdf = roll_agg_tsdf.withColumn(new_col_name,
-                                                         expr.over(window))
+                roll_agg_tsdf = roll_agg_tsdf.withColumn(
+                    new_col_name, expr.over(window)
+                )
 
         return roll_agg_tsdf
 
-    def rollingApply(self,
-                     outputCol: str,
-                     window: WindowSpec,
-                     func: PandasGroupedMapFunction,
-                     schema: Union[StructType, str],
-                     *inputCols: Union[str, Column]) -> TSDF:
+    def rollingApply(
+        self,
+        outputCol: str,
+        window: WindowSpec,
+        func: PandasGroupedMapFunction,
+        schema: Union[StructType, str],
+        *inputCols: Union[str, Column],
+    ) -> TSDF:
         """
 
         :param outputCol:
@@ -1172,9 +1177,9 @@ def aggBySeries(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
         """
         return self.groupBySeries().agg(exprs)
 
-    def applyToSeries(self,
-                      func: PandasGroupedMapFunction,
-                      schema: Union[StructType, str]) -> DataFrame:
+    def applyToSeries(
+        self, func: PandasGroupedMapFunction, schema: Union[StructType, str]
+    ) -> DataFrame:
         """
         Maps each series using a pandas udf and returns the result as a `DataFrame`.
 
@@ -1206,11 +1211,13 @@ def applyToSeries(
 
     # Cyclical Aggregtion
 
-    def groupByCycles(self,
-                      length: str,
-                      period: Optional[str] = None,
-                      offset: Optional[str] = None,
-                      bySeries: bool = True) -> GroupedData:
+    def groupByCycles(
+        self,
+        length: str,
+        period: Optional[str] = None,
+        offset: Optional[str] = None,
+        bySeries: bool = True,
+    ) -> GroupedData:
         """
 
         :param length:
@@ -1224,20 +1231,26 @@ def groupByCycles(self,
             grouping_cols = [sfn.col(series_col) for series_col in self.series_ids]
         else:
             grouping_cols = []
-        grouping_cols.append(sfn.window(timeColumn=self.ts_col,
-                                        windowDuration=length,
-                                        slideDuration=period,
-                                        startTime=offset))
+        grouping_cols.append(
+            sfn.window(
+                timeColumn=self.ts_col,
+                windowDuration=length,
+                slideDuration=period,
+                startTime=offset,
+            )
+        )
 
         # return the DataFrame grouped accordingly
         return self.df.groupBy(grouping_cols)
 
-    def aggByCycles(self,
-                    length: str,
-                    *exprs: Union[Column, Dict[str, str]],
-                    period: Optional[str] = None,
-                    offset: Optional[str] = None,
-                    bySeries: bool = True) -> IntervalsDF:
+    def aggByCycles(
+        self,
+        length: str,
+        *exprs: Union[Column, Dict[str, str]],
+        period: Optional[str] = None,
+        offset: Optional[str] = None,
+        bySeries: bool = True,
+    ) -> IntervalsDF:
         """
 
         :param length:
@@ -1252,19 +1265,21 @@ def aggByCycles(self,
 
         # if we have aggregated over series, we return a TSDF without series
         if bySeries:
-            return IntervalsDF.fromNestedBoundariesDF(agged_df,
-                                                      "window",
-                                                      self.series_ids)
+            return IntervalsDF.fromNestedBoundariesDF(
+                agged_df, "window", self.series_ids
+            )
         else:
             return IntervalsDF.fromNestedBoundariesDF(agged_df, "window")
 
-    def applyToCycles(self,
-                      length: str,
-                      func: PandasGroupedMapFunction,
-                      schema: Union[StructType, str],
-                      period: Optional[str] = None,
-                      offset: Optional[str] = None,
-                      bySeries: bool = True) -> IntervalsDF:
+    def applyToCycles(
+        self,
+        length: str,
+        func: PandasGroupedMapFunction,
+        schema: Union[StructType, str],
+        period: Optional[str] = None,
+        offset: Optional[str] = None,
+        bySeries: bool = True,
+    ) -> IntervalsDF:
         """
 
         :param length:
@@ -1276,14 +1291,15 @@ def applyToCycles(self,
         :return:
         """
         # apply function to get DataFrame of results
-        applied_df = self.groupByCycles(length, period, offset, bySeries)\
-            .applyInPandas(func, schema)
+        applied_df = self.groupByCycles(length, period, offset, bySeries).applyInPandas(
+            func, schema
+        )
 
         # if we have applied over series, we return a TSDF without series
         if bySeries:
-            return IntervalsDF.fromNestedBoundariesDF(applied_df,
-                                                      "window",
-                                                      self.series_ids)
+            return IntervalsDF.fromNestedBoundariesDF(
+                applied_df, "window", self.series_ids
+            )
         else:
             return IntervalsDF.fromNestedBoundariesDF(applied_df, "window")
 
@@ -1528,7 +1544,7 @@ def __init__(
         freq: str,
         func: Union[Callable | str],
         ts_col: str = "event_ts",
-        series_ids: Optional[List[str]] = None
+        series_ids: Optional[List[str]] = None,
     ):
         super(_ResampledTSDF, self).__init__(df, ts_col, series_ids)
         self.__freq = freq
diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 757383c3..d04ca0d0 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -4,14 +4,22 @@
 
 import pyspark.sql.functions as sfn
 from pyspark.sql import Column, Window, WindowSpec
-from pyspark.sql.types import BooleanType, DateType, NumericType, StringType, \
-    StructField, StructType, TimestampType
+from pyspark.sql.types import (
+    BooleanType,
+    DateType,
+    NumericType,
+    StringType,
+    StructField,
+    StructType,
+    TimestampType,
+)
 
 
 #
 # Time Units
 #
 
+
 class TimeUnits(Enum):
     YEARS = auto()
     MONTHS = auto()
@@ -100,7 +108,9 @@ def _reverseOrNot(
         elif isinstance(expr, list):
             return [col.desc() for col in expr]  # reverse all columns in the expression
         else:
-            raise TypeError(f"Type for expr argument must be either Column or List[Column], instead received: {type(expr)}")
+            raise TypeError(
+                f"Type for expr argument must be either Column or List[Column], instead received: {type(expr)}"
+            )
 
     @abstractmethod
     def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
@@ -152,13 +162,15 @@ def ts_col(self) -> str:
 
     def validate(self, df_schema: StructType) -> None:
         # the ts column must exist
-        assert self.colname in df_schema.fieldNames(), \
-            f"The TSIndex column {self.colname} does not exist in the given DataFrame"
+        assert (
+            self.colname in df_schema.fieldNames()
+        ), f"The TSIndex column {self.colname} does not exist in the given DataFrame"
         schema_ts_col = df_schema[self.colname]
         # it must have the right type
         schema_ts_type = schema_ts_col.dataType
-        assert isinstance(schema_ts_type, type(self.dataType)), \
-            f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}"
+        assert isinstance(
+            schema_ts_type, type(self.dataType)
+        ), f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}"
 
     def renamed(self, new_name: str) -> "TSIndex":
         self.__name = new_name
@@ -267,8 +279,12 @@ def __init__(self, ts_idx: StructField, *ts_fields: str) -> None:
         self.struct: StructType = ts_idx.dataType
         # handle the timestamp fields
         if ts_fields is None or len(ts_fields) < 1:
-            raise ValueError("A CompoundTSIndex must have at least one ts_field specified!")
-        self.ts_components = [SimpleTSIndex.fromTSCol(self.struct[field]) for field in ts_fields]
+            raise ValueError(
+                "A CompoundTSIndex must have at least one ts_field specified!"
+            )
+        self.ts_components = [
+            SimpleTSIndex.fromTSCol(self.struct[field]) for field in ts_fields
+        ]
         self.primary_ts_idx = self.ts_components[0]
 
     @property
@@ -276,7 +292,7 @@ def _indexAttributes(self) -> dict[str, Any]:
         return {
             "name": self.colname,
             "struct": self.struct,
-            "ts_components": self.ts_components
+            "ts_components": self.ts_components,
         }
 
     @property
@@ -297,13 +313,15 @@ def unit(self) -> Optional[TimeUnits]:
 
     def validate(self, df_schema: StructType) -> None:
         # validate that the composite field exists
-        assert self.colname in df_schema.fieldNames(), \
-            f"The TSIndex column {self.colname} does not exist in the given DataFrame"
+        assert (
+            self.colname in df_schema.fieldNames()
+        ), f"The TSIndex column {self.colname} does not exist in the given DataFrame"
         schema_ts_col = df_schema[self.colname]
         # it must have the right type
         schema_ts_type = schema_ts_col.dataType
-        assert isinstance(schema_ts_type, StructType), \
-            f"The TSIndex column is of type {schema_ts_type}, but the expected type is {StructType}"
+        assert isinstance(
+            schema_ts_type, StructType
+        ), f"The TSIndex column is of type {schema_ts_type}, but the expected type is {StructType}"
         # validate all the TS components
         for comp in self.ts_components:
             comp.validate(schema_ts_type)
@@ -347,9 +365,7 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
     Retains the original string form as well as the parsed column.
     """
 
-    def __init__(
-        self, ts_idx: StructField, src_str_col: str, parsed_col: str
-    ) -> None:
+    def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
         super().__init__(ts_idx, primary_ts_col=parsed_col)
         src_str_field = self.struct[src_str_col]
         if not isinstance(src_str_field.dataType, StringType):
@@ -371,13 +387,17 @@ def src_str_col(self):
     def validate(self, df_schema: StructType) -> None:
         super().validate(df_schema)
         # make sure the parsed field exists
-        composite_idx_type: StructType = cast(StructType, df_schema[self.colname].dataType)
-        assert self.__src_str_col in composite_idx_type, \
-            f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
+        composite_idx_type: StructType = cast(
+            StructType, df_schema[self.colname].dataType
+        )
+        assert (
+            self.__src_str_col in composite_idx_type
+        ), f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
         # make sure it's StringType
         src_str_field_type = composite_idx_type[self.__src_str_col].dataType
-        assert isinstance(src_str_field_type, StringType), \
-            f"The src_str_col column {self.src_str_col} should be of StringType, but found {src_str_field_type} instead"
+        assert isinstance(
+            src_str_field_type, StringType
+        ), f"The src_str_col column {self.src_str_col} should be of StringType, but found {src_str_field_type} instead"
 
 
 
@@ -385,9 +405,7 @@ class ParsedTimestampIndex(ParsedTSIndex):
     """
     Timeseries index class for timestamps parsed from a string column
    """
-    def __init__(
-        self, ts_idx: StructField, src_str_col: str, parsed_col: str
-    ) -> None:
+    def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
        super().__init__(ts_idx, src_str_col, parsed_col)
         if not isinstance(self.primary_ts_idx.dataType, TimestampType):
             raise TypeError(
@@ -405,9 +423,7 @@ class ParsedDateIndex(ParsedTSIndex):
     """
     Timeseries index class for dates parsed from a string column
     """
-    def __init__(
-        self, ts_idx: StructField, src_str_col: str, parsed_col: str
-    ) -> None:
+    def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
         super().__init__(ts_idx, src_str_col, parsed_col)
         if not isinstance(self.primary_ts_idx.dataType, DateType):
             raise TypeError(
@@ -426,6 +442,7 @@ def rangeExpr(self, reverse: bool = False) -> Column:
 # Window Builder Interface
 #
 
+
 class WindowBuilder(ABC):
     """
     Abstract base class for window builders.
@@ -443,10 +460,9 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec:
         pass
 
     @abstractmethod
-    def rowsBetweenWindow(self,
-                          start: int,
-                          end: int,
-                          reverse: bool = False) -> WindowSpec:
+    def rowsBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         """
         build a row-based window with the given start and end offsets
 
@@ -467,8 +483,7 @@ def allBeforeWindow(self, inclusive: bool = True) -> WindowSpec:
 
         :return: a WindowSpec object
         """
-        return self.rowsBetweenWindow(Window.unboundedPreceding,
-                                      0 if inclusive else -1)
+        return self.rowsBetweenWindow(Window.unboundedPreceding, 0 if inclusive else -1)
 
     def allAfterWindow(self, inclusive: bool = True) -> WindowSpec:
         """
@@ -479,14 +494,12 @@ def allAfterWindow(self, inclusive: bool = True) -> WindowSpec:
 
         :return: a WindowSpec object
         """
-        return self.rowsBetweenWindow(0 if inclusive else 1,
-                                      Window.unboundedFollowing)
+        return self.rowsBetweenWindow(0 if inclusive else 1, Window.unboundedFollowing)
 
     @abstractmethod
-    def rangeBetweenWindow(self,
-                           start: int,
-                           end: int,
-                           reverse: bool = False) -> WindowSpec:
+    def rangeBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         """
         build a range-based window with the given start and end offsets
 
@@ -567,8 +580,9 @@ def validate(self, df_schema: StructType) -> None:
         self.ts_idx.validate(df_schema)
         # check series IDs
         for sid in self.series_ids:
-            assert sid in df_schema.fieldNames(), \
-                f"Series ID {sid} does not exist in the given DataFrame"
+            assert (
+                sid in df_schema.fieldNames()
+            ), f"Series ID {sid} does not exist in the given DataFrame"
 
     def find_observational_columns(self, df_schema: StructType) -> list[str]:
         return list(set(df_schema.fieldNames()) - set(self.structural_columns))
@@ -596,10 +610,14 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec:
         w = w.partitionBy([sfn.col(sid) for sid in self.series_ids])
         return w
 
-    def rowsBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec:
+    def rowsBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         return self.baseWindow(reverse=reverse).rowsBetween(start, end)
 
-    def rangeBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec:
+    def rangeBetweenWindow(
+        self, start: int, end: int, reverse: bool = False
+    ) -> WindowSpec:
         return (
             self.baseWindow(reverse=reverse)
             .orderBy(self.ts_idx.rangeExpr(reverse=reverse))