From cca500664e01fd99051940effce2796ca61bd0d6 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 3 Jan 2024 15:51:03 -0800 Subject: [PATCH] big simplification of TS index classes --- python/tempo/tsschema.py | 533 ++++++++++++++++++++------------------- 1 file changed, 272 insertions(+), 261 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 8c3f782c..2f2fb505 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -43,6 +43,7 @@ def sub_seconds_precision_digits(ts_fmt: str) -> int: else: return len(match.group(1)) + # # Abstract Timeseries Index Classes # @@ -79,13 +80,6 @@ def colname(self) -> str: :return: the column name of the timeseries index """ - @property - @abstractmethod - def ts_col(self) -> str: - """ - :return: the name of the primary timeseries column (may or may not be the same as the name) - """ - @property @abstractmethod def unit(self) -> Optional[TimeUnit]: @@ -127,26 +121,33 @@ def _reverseOrNot( elif isinstance(expr, list): return [col.desc() for col in expr] # reverse all columns in the expression else: - raise TypeError(f"Type for expr argument must be either Column or List[Column], instead received: {type(expr)}") + raise TypeError( + f"Type for expr argument must be either Column or " + f"List[Column], instead received: {type(expr)}" + ) @abstractmethod def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: """ - Gets an expression that will order the :class:`TSDF` according to the timeseries index. + Gets an expression that will order the :class:`TSDF` + according to the timeseries index. :param reverse: whether the ordering should be reversed (backwards in time) - :return: an expression appropriate for ordering the :class:`TSDF` according to this index + :return: an expression appropriate for ordering the :class:`TSDF` + according to this index """ @abstractmethod def rangeExpr(self, reverse: bool = False) -> Column: """ - Gets an expression appropriate for performing range operations on the :class:`TSDF` records. + Gets an expression appropriate for performing range operations + on the :class:`TSDF` records. :param reverse: whether the ordering should be reversed (backwards in time) - :return: an expression appropriate for operforming range operations on the :class:`TSDF` records + :return: an expression appropriate for performing range operations + on the :class:`TSDF` records """ @@ -162,25 +163,23 @@ def __init__(self, ts_col: StructField) -> None: @property def _indexAttributes(self) -> dict[str, Any]: - return {"name": self.colname, "dataType": self.dataType} + return {"name": self.colname, "dataType": self.dataType, "unit": self.unit} @property def colname(self): return self.__name - @property - def ts_col(self) -> str: - return self.colname - def validate(self, df_schema: StructType) -> None: # the ts column must exist - assert self.colname in df_schema.fieldNames(), \ - f"The TSIndex column {self.colname} does not exist in the given DataFrame" + assert ( + self.colname in df_schema.fieldNames() + ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert isinstance(schema_ts_type, type(self.dataType)), \ - f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" + assert isinstance( + schema_ts_type, type(self.dataType) + ), f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name @@ -209,9 +208,10 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": # Simple TS Index types # + class OrdinalTSIndex(SimpleTSIndex): """ - Timeseries index based on a single column of a numeric or temporal type. + Timeseries index based on a single column of a numeric type. This index is "unitless", meaning that it is not associated with any particular unit of time. It can provide ordering of records, but not range operations. @@ -280,42 +280,71 @@ def rangeExpr(self, reverse: bool = False) -> Column: # -# Multi-Part TS Index types +# Complex (Multi-Field) TS Index types # -class MultiPartTSIndex(TSIndex, ABC): + +class CompositeTSIndex(TSIndex, ABC): """ Abstract base class for Timeseries Index types that reference multiple columns. Such columns are organized as a StructType column with multiple fields. + Some subset of these columns (at least 1) is considered to be a "component field", + the others are called "accessory fields". """ - def __init__(self, ts_struct: StructField) -> None: + def __init__(self, ts_struct: StructField, *component_fields: str) -> None: if not isinstance(ts_struct.dataType, StructType): raise TypeError( f"CompoundTSIndex must be of type StructType, but given " f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" ) + # validate the index fields + assert len(component_fields) > 0, ( + f"A MultiFieldTSIndex must have at least 1 index component field, " + f"but {len(component_fields)} were given" + ) + for ind_f in component_fields: + assert ( + ind_f in ts_struct.dataType.fieldNames() + ), f"Index field {ind_f} does not exist in the given TSIndex schema" + # assign local attributes self.__name: str = ts_struct.name self.schema: StructType = ts_struct.dataType + self.component_fields: List[str] = list(component_fields) @property def _indexAttributes(self) -> dict[str, Any]: - return {"name": self.colname, "schema": self.schema} + return { + "name": self.colname, + "schema": self.schema, + "unit": self.unit, + "component_fields": self.component_fields, + } @property def colname(self) -> str: return self.__name + @property + def fieldNames(self) -> List[str]: + return self.schema.fieldNames() + + @property + def accessory_fields(self) -> List[str]: + return list(set(self.fieldNames) - set(self.component_fields)) + def validate(self, df_schema: StructType) -> None: # validate that the composite field exists - assert self.colname in df_schema.fieldNames(),\ - f"The TSIndex column {self.colname} does not exist in the given DataFrame" + assert ( + self.colname in df_schema.fieldNames() + ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert schema_ts_type == self.schema,\ - f"The TSIndex column is of type {schema_ts_type}, "\ + assert schema_ts_type == self.schema, ( + f"The TSIndex column is of type {schema_ts_type}, " f"but the expected type is {self.schema}" + ) def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name @@ -327,143 +356,204 @@ def fieldPath(self, field: str) -> str: :return: A dot-separated path to the given field within the TSIndex column """ - assert field in self.schema.fieldNames(),\ - f"Field {field} does not exist in the TSIndex schema {self.schema}" + assert ( + field in self.fieldNames + ), f"Field {field} does not exist in the TSIndex schema {self.schema}" return f"{self.colname}.{field}" + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + # build an expression for each TS component, in order + exprs = [sfn.col(self.fieldPath(comp)) for comp in self.component_fields] + return self._reverseOrNot(exprs, reverse) + # # Parsed TS Index types # -class ParsedTSIndex(MultiPartTSIndex, ABC): +# class ParsedTSIndex(CompositeTSIndex, ABC): +# """ +# Abstract base class for timeseries indices that are parsed from a string column. +# Retains the original string form as well as the parsed column. +# """ +# +# def __init__( +# self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str +# ) -> None: +# super().__init__(ts_struct, parsed_ts_col) +# # validate the source string column +# src_str_field = self.schema[src_str_col] +# if not isinstance(src_str_field.dataType, StringType): +# raise TypeError( +# f"Source string column must be of StringType, " +# f"but given column {src_str_field.name} " +# f"is of type {src_str_field.dataType}" +# ) +# self._src_str_col = src_str_col +# # validate the parsed column +# assert parsed_ts_col in self.schema.fieldNames(), ( +# f"The parsed timestamp index field {parsed_ts_col} does not exist in the " +# f"MultiPart TSIndex schema {self.schema}" +# ) +# self._parsed_ts_col = parsed_ts_col +# +# @property +# def src_str_col(self): +# return self.fieldPath(self._src_str_col) +# +# @property +# def parsed_ts_col(self): +# return self.fieldPath(self._parsed_ts_col) +# +# @property +# def ts_col(self) -> str: +# return self.parsed_ts_col +# +# @property +# def _indexAttributes(self) -> dict[str, Any]: +# attrs = super()._indexAttributes +# attrs["parsed_ts_col"] = self.parsed_ts_col +# attrs["src_str_col"] = self.src_str_col +# return attrs +# +# def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: +# expr = sfn.col(self.parsed_ts_col) +# return self._reverseOrNot(expr, reverse) +# +# @classmethod +# def fromParsedTimestamp( +# cls, +# ts_struct: StructField, +# parsed_ts_col: str, +# src_str_col: str, +# double_ts_col: Optional[str] = None, +# num_precision_digits: int = 6, +# ) -> "ParsedTSIndex": +# """ +# Create a ParsedTimestampIndex from a string column containing timestamps or dates +# +# :param ts_struct: The StructField for the TSIndex column +# :param parsed_ts_col: The name of the parsed timestamp column +# :param src_str_col: The name of the source string column +# :param double_ts_col: The name of the double-precision timestamp column +# :param num_precision_digits: The number of digits that make up the precision of +# +# :return: A ParsedTSIndex object +# """ +# +# # if a double timestamp column is given +# # then we are building a SubMicrosecondPrecisionTimestampIndex +# if double_ts_col is not None: +# return SubMicrosecondPrecisionTimestampIndex( +# ts_struct, +# double_ts_col, +# parsed_ts_col, +# src_str_col, +# num_precision_digits, +# ) +# # otherwise, we base it on the standard timestamp type +# # find the schema of the ts_struct column +# ts_schema = ts_struct.dataType +# if not isinstance(ts_schema, StructType): +# raise TypeError( +# f"A ParsedTSIndex must be of type StructType, but given " +# f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" +# ) +# # get the type of the parsed timestamp column +# parsed_ts_type = ts_schema[parsed_ts_col].dataType +# if isinstance(parsed_ts_type, TimestampType): +# return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) +# elif isinstance(parsed_ts_type, DateType): +# return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) +# else: +# raise TypeError( +# f"ParsedTimestampIndex must be of TimestampType or DateType, " +# f"but given ts_col {parsed_ts_col} " +# f"has type {parsed_ts_type}" +# ) + + +class ParsedTimestampIndex(CompositeTSIndex): """ - Abstract base class for timeseries indices that are parsed from a string column. - Retains the original string form as well as the parsed column. + Timeseries index class for timestamps parsed from a string column """ def __init__( self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: - super().__init__(ts_struct) - # validate the source string column + super().__init__(ts_struct, parsed_ts_col) + # validate the parsed column as a timestamp column + parsed_ts_field = self.schema[parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, TimestampType): + raise TypeError( + f"ParsedTimestampIndex requires an index field of TimestampType, " + f"but the given parsed_ts_col {parsed_ts_col} " + f"has type {parsed_ts_field.dataType}" + ) + self.parsed_ts_col = parsed_ts_col + # validate the source column as a string column src_str_field = self.schema[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( - f"Source string column must be of StringType, " - f"but given column {src_str_field.name} " - f"is of type {src_str_field.dataType}" + f"ParsedTimestampIndex requires an source field of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" ) - self._src_str_col = src_str_col - # validate the parsed column - assert parsed_ts_col in self.schema.fieldNames(),\ - f"The parsed timestamp index field {parsed_ts_col} does not exist in the " \ - f"MultiPart TSIndex schema {self.schema}" - self._parsed_ts_col = parsed_ts_col - - @property - def src_str_col(self): - return self.fieldPath(self._src_str_col) - - @property - def parsed_ts_col(self): - return self.fieldPath(self._parsed_ts_col) - - @property - def ts_col(self) -> str: - return self.parsed_ts_col + self.src_str_col = src_str_col @property - def _indexAttributes(self) -> dict[str, Any]: - attrs = super()._indexAttributes - attrs["parsed_ts_col"] = self.parsed_ts_col - attrs["src_str_col"] = self.src_str_col - return attrs + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.SECONDS - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.parsed_ts_col) + def rangeExpr(self, reverse: bool = False) -> Column: + # cast timestamp to double (fractional seconds since epoch) + expr = sfn.col(self.fieldPath(self.parsed_ts_col)).cast("double") return self._reverseOrNot(expr, reverse) - @classmethod - def fromParsedTimestamp(cls, - ts_struct: StructField, - parsed_ts_col: str, - src_str_col: str, - double_ts_col: Optional[str] = None, - num_precision_digits: int = 6) -> "ParsedTSIndex": - """ - Create a ParsedTimestampIndex from a string column containing timestamps or dates - :param ts_struct: The StructField for the TSIndex column - :param parsed_ts_col: The name of the parsed timestamp column - :param src_str_col: The name of the source string column - :param double_ts_col: The name of the double-precision timestamp column - :param num_precision_digits: The number of digits that make up the precision of - - :return: A ParsedTSIndex object - """ - - # if a double timestamp column is given - # then we are building a SubMicrosecondPrecisionTimestampIndex - if double_ts_col is not None: - return SubMicrosecondPrecisionTimestampIndex(ts_struct, - double_ts_col, - parsed_ts_col, - src_str_col, - num_precision_digits) - # otherwise, we base it on the standard timestamp type - # find the schema of the ts_struct column - ts_schema = ts_struct.dataType - if not isinstance(ts_schema, StructType): - raise TypeError( - f"A ParsedTSIndex must be of type StructType, but given " - f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" - ) - # get the type of the parsed timestamp column - parsed_ts_type = ts_schema[parsed_ts_col].dataType - if isinstance(parsed_ts_type, TimestampType): - return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) - elif isinstance(parsed_ts_type, DateType): - return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) - else: - raise TypeError( - f"ParsedTimestampIndex must be of TimestampType or DateType, " - f"but given ts_col {parsed_ts_col} " - f"has type {parsed_ts_type}" - ) - - - -class ParsedTimestampIndex(ParsedTSIndex): +class ParsedDateIndex(CompositeTSIndex): """ - Timeseries index class for timestamps parsed from a string column + Timeseries index class for dates parsed from a string column """ def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str + self, ts_struct: StructField, parsed_date_col: str, src_str_col: str ) -> None: - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[self._parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, TimestampType): + super().__init__(ts_struct, parsed_date_col) + # validate the parsed column as a date column + parsed_date_field = self.schema[parsed_date_col] + if not isinstance(parsed_date_field.dataType, DateType): raise TypeError( - f"ParsedTimestampIndex must be of TimestampType, " - f"but given ts_col {self.parsed_ts_col} " - f"has type {parsed_ts_field.dataType}" + f"ParsedDateIndex requires an index field of DateType, " + f"but the given parsed_ts_col {parsed_date_col} " + f"has type {parsed_date_field.dataType}" ) - - def rangeExpr(self, reverse: bool = False) -> Column: - # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.parsed_ts_col).cast("double") - return self._reverseOrNot(expr, reverse) + self.parsed_date_col = parsed_date_col + # validate the source column as a string column + src_str_field = self.schema[src_str_col] + if not isinstance(src_str_field.dataType, StringType): + raise TypeError( + f"ParsedDateIndex requires an source field of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" + ) + self.src_str_col = src_str_col @property def unit(self) -> Optional[TimeUnit]: - return StandardTimeUnits.SECONDS + return StandardTimeUnits.DAYS + + def rangeExpr(self, reverse: bool = False) -> Column: + # convert date to number of days since the epoch + expr = sfn.datediff( + sfn.col(self.fieldPath(self.parsed_date_col)), + sfn.lit("1970-01-01").cast("date"), + ) + return self._reverseOrNot(expr, reverse) -class SubMicrosecondPrecisionTimestampIndex(ParsedTimestampIndex): +class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): """ Timeseries index class for timestamps with sub-microsecond precision parsed from a string column. Internally, the timestamps are stored as @@ -471,12 +561,14 @@ class SubMicrosecondPrecisionTimestampIndex(ParsedTimestampIndex): and a micro-second precision (standard) timestamp field. """ - def __init__(self, - ts_struct: StructField, - double_ts_col: str, - parsed_ts_col: str, - src_str_col: str, - num_precision_digits: int = 9) -> None: + def __init__( + self, + ts_struct: StructField, + double_ts_col: str, + parsed_ts_col: str, + src_str_col: str, + num_precision_digits: int = 9, + ) -> None: """ :param ts_struct: The StructField for the TSIndex column :param double_ts_col: The name of the double-precision timestamp column @@ -487,17 +579,16 @@ def __init__(self, You will receive a warning if this value is 6 or less, as this is the precision of the standard timestamp type. """ - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # set & validate the double timestamp column - self.double_ts_col = double_ts_col + super().__init__(ts_struct, double_ts_col) # validate the double timestamp column - double_ts_field = self.schema[self.double_ts_col] + double_ts_field = self.schema[double_ts_col] if not isinstance(double_ts_field.dataType, DoubleType): raise TypeError( f"The double_ts_col must be of DoubleType, " - f"but the given double_ts_col {self.double_ts_col} " + f"but the given double_ts_col {double_ts_col} " f"has type {double_ts_field.dataType}" ) + self.double_ts_col = double_ts_col # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( @@ -511,121 +602,42 @@ def __init__(self, 10 ** (-num_precision_digits), num_precision_digits, ) - - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.double_ts_col) - return self._reverseOrNot(expr, reverse) - - def rangeExpr(self, reverse: bool = False) -> Column: - # just use the order by expression, since this is the same - return self.orderByExpr(reverse) - - @property - def unit(self) -> Optional[TimeUnit]: - return self.__unit - - -class ParsedDateIndex(ParsedTSIndex): - """ - Timeseries index class for dates parsed from a string column - """ - - def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str - ) -> None: - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # validate the parsed column as a date column - parsed_ts_field = self.schema[self._parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, DateType): + # validate the parsed column as a timestamp column + parsed_ts_field = self.schema[parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( - f"ParsedTimestampIndex must be of DateType, " - f"but given ts_col {self.parsed_ts_col} " + f"parsed_ts_col field must be of TimestampType, " + f"but the given parsed_ts_col {parsed_ts_col} " f"has type {parsed_ts_field.dataType}" ) + self.parsed_ts_col = parsed_ts_col + # validate the source column as a string column + src_str_field = self.schema[src_str_col] + if not isinstance(src_str_field.dataType, StringType): + raise TypeError( + f"src_str_col field must be of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" + ) + self.src_str_col = src_str_col @property def unit(self) -> Optional[TimeUnit]: - return StandardTimeUnits.DAYS + return self.__unit - def rangeExpr(self, reverse: bool = False) -> Column: - # convert date to number of days since the epoch - expr = sfn.datediff( - sfn.col(self.parsed_ts_col), sfn.lit("1970-01-01").cast("date") - ) + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + expr = sfn.col(self.fieldPath(self.double_ts_col)) return self._reverseOrNot(expr, reverse) - -# -# Complex (Multi-Field) TS Index Types -# - - -class CompositeTSIndex(MultiPartTSIndex, ABC): - """ - Abstract base class for complex Timeseries Index classes - that involve two or more columns organized into a StructType column - """ - - def __init__(self, ts_struct: StructField, *ts_fields: str) -> None: - super().__init__(ts_struct) - # handle the timestamp fields - assert len(ts_fields) > 1,\ - f"CompositeTSIndex must have at least two timestamp fields, " \ - f"but only {len(ts_fields)} were given" - self.ts_components = \ - [SimpleTSIndex.fromTSCol(self.schema[field]) for field in ts_fields] - - @property - def _indexAttributes(self) -> dict[str, Any]: - attrs = super()._indexAttributes - attrs["ts_components"] = [str(c) for c in self.ts_components] - return attrs - - def primary_ts_col(self) -> str: - return self.get_ts_component(0) - - @property - def primary_ts_idx(self) -> TSIndex: - return self.ts_components[0] - - @property - def unit(self) -> Optional[TimeUnit]: - return self.primary_ts_idx.unit - - def validate(self, df_schema: StructType) -> None: - super().validate(df_schema) - # validate all the TS components - schema_ts_type = df_schema[self.colname].dataType - assert isinstance(schema_ts_type, StructType),\ - f"CompositeTSIndex must be of StructType, " \ - f"but given ts_col {self.colname} " \ - f"has type {schema_ts_type}" - for comp in self.ts_components: - comp.validate(schema_ts_type) - - def get_ts_component(self, component_index: int) -> str: - """ - Returns the full path to a component field that is a functional part of the timeseries. - - :param component_index: the index giving the ordering of the component field within the timeseries - - :return: a column name that can be used to reference the component field in PySpark expressions - """ - return self.fieldPath(self.ts_components[component_index].colname) - - def orderByExpr(self, reverse: bool = False) -> Column: - # build an expression for each TS component, in order - exprs = [sfn.col(self.fieldPath(comp.colname)) for comp in self.ts_components] - return self._reverseOrNot(exprs, reverse) - - def rangeExpr(self, reverse: bool = False) -> Column: - return self.primary_ts_idx.rangeExpr(reverse) - + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + # just use the order by expression, since this is the same + return self.orderByExpr(reverse) # # Window Builder Interface # + class WindowBuilder(ABC): """ Abstract base class for window builders. @@ -643,10 +655,9 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec: pass @abstractmethod - def rowsBetweenWindow(self, - start: int, - end: int, - reverse: bool = False) -> WindowSpec: + def rowsBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: """ build a row-based window with the given start and end offsets @@ -667,8 +678,7 @@ def allBeforeWindow(self, inclusive: bool = True) -> WindowSpec: :return: a WindowSpec object """ - return self.rowsBetweenWindow(Window.unboundedPreceding, - 0 if inclusive else -1) + return self.rowsBetweenWindow(Window.unboundedPreceding, 0 if inclusive else -1) def allAfterWindow(self, inclusive: bool = True) -> WindowSpec: """ @@ -679,14 +689,12 @@ def allAfterWindow(self, inclusive: bool = True) -> WindowSpec: :return: a WindowSpec object """ - return self.rowsBetweenWindow(0 if inclusive else 1, - Window.unboundedFollowing) + return self.rowsBetweenWindow(0 if inclusive else 1, Window.unboundedFollowing) @abstractmethod - def rangeBetweenWindow(self, - start: int, - end: int, - reverse: bool = False) -> WindowSpec: + def rangeBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: """ build a range-based window with the given start and end offsets @@ -767,8 +775,9 @@ def validate(self, df_schema: StructType) -> None: self.ts_idx.validate(df_schema) # check series IDs for sid in self.series_ids: - assert sid in df_schema.fieldNames(), \ - f"Series ID {sid} does not exist in the given DataFrame" + assert ( + sid in df_schema.fieldNames() + ), f"Series ID {sid} does not exist in the given DataFrame" def find_observational_columns(self, df_schema: StructType) -> list[str]: return list(set(df_schema.fieldNames()) - set(self.structural_columns)) @@ -796,14 +805,16 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec: w = w.partitionBy([sfn.col(sid) for sid in self.series_ids]) return w - def rowsBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec: + def rowsBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: return self.baseWindow(reverse=reverse).rowsBetween(start, end) - def rangeBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec: + def rangeBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: return ( self.baseWindow(reverse=reverse) .orderBy(self.ts_idx.rangeExpr(reverse=reverse)) .rangeBetween(start, end) ) - -