From 68cfe8b5e8849b0f75129932e7971a04b1fd0985 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Tue, 20 Jun 2023 16:56:19 -0700 Subject: [PATCH 01/13] checkpoint save - refactoring TSIndex hierarchy --- python/tempo/tsdf.py | 4 +- python/tempo/tsschema.py | 297 ++++++++++++++++++++++++--------------- 2 files changed, 184 insertions(+), 117 deletions(-) diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 4a04ebc1..8eb9adb9 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -883,7 +883,7 @@ def asofJoin( if tsPartitionVal is None: seq_col = None if isinstance(combined_df.ts_index, CompositeTSIndex): - seq_col = cast(CompositeTSIndex, combined_df.ts_index).ts_component(1) + seq_col = cast(CompositeTSIndex, combined_df.ts_index).get_ts_component(1) asofDF = combined_df.__getLastRightRow( left_tsdf.ts_col, right_columns, @@ -898,7 +898,7 @@ def asofJoin( ) seq_col = None if isinstance(tsPartitionDF.ts_index, CompositeTSIndex): - seq_col = cast(CompositeTSIndex, tsPartitionDF.ts_index).ts_component(1) + seq_col = cast(CompositeTSIndex, tsPartitionDF.ts_index).get_ts_component(1) asofDF = tsPartitionDF.__getLastRightRow( left_tsdf.ts_col, right_columns, diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index b0abb3e3..f1d5f259 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -23,7 +23,7 @@ class TimeUnits(Enum): # -# Timeseries Index Classes +# Abstract Timeseries Index Classes # @@ -72,6 +72,13 @@ def unit(self) -> Optional[TimeUnits]: :return: the unit of this index, that is, the unit that a range value of 1 represents (Days, seconds, etc.) """ + @property + def has_unit(self) -> bool: + """ + :return: whether this index has a unit + """ + return self.unit is not None + @abstractmethod def validate(self, df_schema: StructType) -> None: """ @@ -122,20 +129,15 @@ def rangeExpr(self, reverse: bool = False) -> Column: """ -# -# Simple TS Index types -# - - class SimpleTSIndex(TSIndex, ABC): """ Abstract base class for simple Timeseries Index types that only reference a single column for maintaining the temporal structure """ - def __init__(self, ts_idx: StructField) -> None: - self.__name = ts_idx.name - self.dataType = ts_idx.dataType + def __init__(self, ts_col: StructField) -> None: + self.__name = ts_col.name + self.dataType = ts_col.dataType @property def _indexAttributes(self) -> dict[str, Any]: @@ -171,7 +173,7 @@ def orderByExpr(self, reverse: bool = False) -> Column: def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": # pick our implementation based on the column type if isinstance(ts_col.dataType, NumericType): - return NumericIndex(ts_col) + return OrdinalTSIndex(ts_col) elif isinstance(ts_col.dataType, TimestampType): return SimpleTimestampIndex(ts_col) elif isinstance(ts_col.dataType, DateType): @@ -182,24 +184,34 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": ) -class NumericIndex(SimpleTSIndex): + + +# +# Simple TS Index types +# + +class OrdinalTSIndex(SimpleTSIndex): """ Timeseries index based on a single column of a numeric or temporal type. + This index is "unitless", meaning that it is not associated with any + particular unit of time. It can provide ordering of records, but not + range operations. """ - def __init__(self, ts_idx: StructField) -> None: - if not isinstance(ts_idx.dataType, NumericType): + def __init__(self, ts_col: StructField) -> None: + if not isinstance(ts_col.dataType, NumericType): raise TypeError( - f"NumericIndex must be of a numeric type, but ts_col {ts_idx.name} has type {ts_idx.dataType}" + f"OrdinalTSIndex must be of a numeric type, but ts_col {ts_col.name} " + f"has type {ts_col.dataType}" ) - super().__init__(ts_idx) + super().__init__(ts_col) @property def unit(self) -> Optional[TimeUnits]: return None def rangeExpr(self, reverse: bool = False) -> Column: - return self.orderByExpr(reverse) + raise TypeError("Cannot perform range operations on an OrdinalTSIndex") class SimpleTimestampIndex(SimpleTSIndex): @@ -207,12 +219,13 @@ class SimpleTimestampIndex(SimpleTSIndex): Timeseries index based on a single Timestamp column """ - def __init__(self, ts_idx: StructField) -> None: - if not isinstance(ts_idx.dataType, TimestampType): + def __init__(self, ts_col: StructField) -> None: + if not isinstance(ts_col.dataType, TimestampType): raise TypeError( - f"SimpleTimestampIndex must be of TimestampType, but given ts_col {ts_idx.name} has type {ts_idx.dataType}" + f"SimpleTimestampIndex must be of TimestampType, " + f"but given ts_col {ts_col.name} has type {ts_col.dataType}" ) - super().__init__(ts_idx) + super().__init__(ts_col) @property def unit(self) -> Optional[TimeUnits]: @@ -229,12 +242,13 @@ class SimpleDateIndex(SimpleTSIndex): Timeseries index based on a single Date column """ - def __init__(self, ts_idx: StructField) -> None: - if not isinstance(ts_idx.dataType, DateType): + def __init__(self, ts_col: StructField) -> None: + if not isinstance(ts_col.dataType, DateType): raise TypeError( - f"DateIndex must be of DateType, but given ts_col {ts_idx.name} has type {ts_idx.dataType}" + f"DateIndex must be of DateType, " + f"but given ts_col {ts_col.name} has type {ts_col.dataType}" ) - super().__init__(ts_idx) + super().__init__(ts_col) @property def unit(self) -> Optional[TimeUnits]: @@ -247,137 +261,110 @@ def rangeExpr(self, reverse: bool = False) -> Column: # -# Complex (Multi-Field) TS Index Types +# Multi-Part TS Index types # - -class CompositeTSIndex(TSIndex): +class MultiPartTSIndex(TSIndex, ABC): """ - Abstract base class for complex Timeseries Index classes - that involve two or more columns organized into a StructType column + Abstract base class for Timeseries Index types that reference multiple columns. + Such columns are organized as a StructType column with multiple fields. """ - def __init__(self, ts_idx: StructField, *ts_fields: str) -> None: - if not isinstance(ts_idx.dataType, StructType): + def __init__(self, ts_struct: StructField) -> None: + if not isinstance(ts_struct.dataType, StructType): raise TypeError( - f"CompoundTSIndex must be of type StructType, but given compound_ts_idx {ts_idx.name} has type {ts_idx.dataType}" + f"CompoundTSIndex must be of type StructType, but given " + f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" ) - self.__name: str = ts_idx.name - self.struct: StructType = ts_idx.dataType - # handle the timestamp fields - if ts_fields is None or len(ts_fields) < 1: - raise ValueError("A CompoundTSIndex must have at least one ts_field specified!") - self.ts_components = [SimpleTSIndex.fromTSCol(self.struct[field]) for field in ts_fields] - self.primary_ts_idx = self.ts_components[0] - + self.__name: str = ts_struct.name + self.schema: StructType = ts_struct.dataType @property def _indexAttributes(self) -> dict[str, Any]: - return { - "name": self.colname, - "struct": self.struct, - "ts_components": self.ts_components - } + return {"name": self.colname, "schema": self.schema} @property def colname(self) -> str: return self.__name - @property - def ts_col(self) -> str: - return self.primary_ts_col - - @property - def primary_ts_col(self) -> str: - return self.ts_component(0) - - @property - def unit(self) -> Optional[TimeUnits]: - return self.primary_ts_idx.unit - def validate(self, df_schema: StructType) -> None: # validate that the composite field exists - assert(self.colname in df_schema.fieldNames(), - f"The TSIndex column {self.colname} does not exist in the given DataFrame") + assert self.colname in df_schema.fieldNames(),\ + f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert( isinstance(schema_ts_type, StructType), - f"The TSIndex column is of type {schema_ts_type}, but the expected type is {StructType}" ) - # validate all the TS components - for comp in self.ts_components: - comp.validate(schema_ts_type) + assert schema_ts_type == self.schema,\ + f"The TSIndex column is of type {schema_ts_type}, "\ + f"but the expected type is {self.schema}" def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name return self - def component(self, component_name: str) -> str: - """ - Returns the full path to a component field that is within the composite index - - :param component_name: the name of the component element within the composite index - - :return: a column name that can be used to reference the component field in PySpark expressions + def fieldPath(self, field: str) -> str: """ - return f"{self.colname}.{self.struct[component_name].name}" + :param field: The name of a field within the TSIndex column - def ts_component(self, component_index: int) -> str: + :return: A dot-separated path to the given field within the TSIndex column """ - Returns the full path to a component field that is a functional part of the timeseries. - - :param component_index: the index giving the ordering of the component field within the timeseries + assert field in self.schema.fieldNames(),\ + f"Field {field} does not exist in the TSIndex schema {self.schema}" + return f"{self.colname}.{field}" - :return: a column name that can be used to reference the component field in PySpark expressions - """ - return self.component(self.ts_components[component_index].colname) - def orderByExpr(self, reverse: bool = False) -> Column: - # build an expression for each TS component, in order - exprs = [sfn.col(self.component(comp.colname)) for comp in self.ts_components] - return self._reverseOrNot(exprs, reverse) - - def rangeExpr(self, reverse: bool = False) -> Column: - return self.primary_ts_idx.rangeExpr(reverse) +# +# Parsed TS Index types +# -class ParsedTSIndex(CompositeTSIndex, ABC): +class ParsedTSIndex(MultiPartTSIndex, ABC): """ Abstract base class for timeseries indices that are parsed from a string column. Retains the original string form as well as the parsed column. """ def __init__( - self, ts_idx: StructField, src_str_col: str, parsed_col: str + self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: - super().__init__(ts_idx, primary_ts_col=parsed_col) - src_str_field = self.struct[src_str_col] + super().__init__(ts_struct) + # validate the source string column + src_str_field = self.schema[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( - f"Source string column must be of StringType, but given column {src_str_field.name} is of type {src_str_field.dataType}" + f"Source string column must be of StringType, " + f"but given column {src_str_field.name} " + f"is of type {src_str_field.dataType}" ) self.__src_str_col = src_str_col + # validate the parsed column + assert parsed_ts_col in self.schema.fieldNames(),\ + f"The parsed timestamp index field {parsed_ts_col} does not exist in the " \ + f"MultiPart TSIndex schema {self.schema}" + self.__parsed_ts_col = parsed_ts_col + + @property + def src_str_col(self): + return self.fieldPath(self.__src_str_col) + + @property + def parsed_ts_col(self): + return self.fieldPath(self.__parsed_ts_col) + + @property + def ts_col(self) -> str: + return self.parsed_ts_col @property def _indexAttributes(self) -> dict[str, Any]: attrs = super()._indexAttributes + attrs["parsed_ts_col"] = self.parsed_ts_col attrs["src_str_col"] = self.src_str_col return attrs - @property - def src_str_col(self): - return self.component(self.__src_str_col) - - def validate(self, df_schema: StructType) -> None: - super().validate(df_schema) - # make sure the parsed field exists - composite_idx_type: StructType = cast(StructType, df_schema[self.colname].dataType) - assert( self.__src_str_col in composite_idx_type, - f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}") - # make sure it's StringType - src_str_field_type = composite_idx_type[self.__src_str_col].dataType - assert( isinstance(src_str_field_type, StringType), - f"The src_str_col column {self.src_str_col} should be of StringType, but found {src_str_field_type} instead" ) + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + expr = sfn.col(self.parsed_ts_col) + return self._reverseOrNot(expr, reverse) class ParsedTimestampIndex(ParsedTSIndex): @@ -386,19 +373,26 @@ class ParsedTimestampIndex(ParsedTSIndex): """ def __init__( - self, ts_idx: StructField, src_str_col: str, parsed_col: str + self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: - super().__init__(ts_idx, src_str_col, parsed_col) - if not isinstance(self.primary_ts_idx.dataType, TimestampType): + super().__init__(ts_struct, parsed_ts_col, src_str_col) + parsed_ts_field = self.schema[self.__parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( - f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}" + f"ParsedTimestampIndex must be of TimestampType, " + f"but given ts_col {self.parsed_ts_col} " + f"has type {parsed_ts_field.dataType}" ) def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.primary_ts_col).cast("double") + expr = sfn.col(self.parsed_ts_col).cast("double") return self._reverseOrNot(expr, reverse) + @property + def unit(self) -> Optional[TimeUnits]: + return TimeUnits.SECONDS + class ParsedDateIndex(ParsedTSIndex): """ @@ -406,22 +400,95 @@ class ParsedDateIndex(ParsedTSIndex): """ def __init__( - self, ts_idx: StructField, src_str_col: str, parsed_col: str + self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: - super().__init__(ts_idx, src_str_col, parsed_col) - if not isinstance(self.primary_ts_idx.dataType, DateType): + super().__init__(ts_struct, parsed_ts_col, src_str_col) + parsed_ts_field = self.schema[self.__parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, DateType): raise TypeError( - f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}" + f"ParsedTimestampIndex must be of DateType, " + f"but given ts_col {self.parsed_ts_col} " + f"has type {parsed_ts_field.dataType}" ) + @property + def unit(self) -> Optional[TimeUnits]: + return TimeUnits.DAYS + def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch expr = sfn.datediff( - sfn.col(self.primary_ts_col), sfn.lit("1970-01-01").cast("date") + sfn.col(self.parsed_ts_col), sfn.lit("1970-01-01").cast("date") ) return self._reverseOrNot(expr, reverse) +# +# Complex (Multi-Field) TS Index Types +# + + +class CompositeTSIndex(MultiPartTSIndex, ABC): + """ + Abstract base class for complex Timeseries Index classes + that involve two or more columns organized into a StructType column + """ + + def __init__(self, ts_struct: StructField, *ts_fields: str) -> None: + super().__init__(ts_struct) + # handle the timestamp fields + assert len(ts_fields) > 1,\ + f"CompositeTSIndex must have at least two timestamp fields, " \ + f"but only {len(ts_fields)} were given" + self.ts_components = [SimpleTSIndex.fromTSCol(self.schema[field]) for field in ts_fields] + + @property + def _indexAttributes(self) -> dict[str, Any]: + attrs = super()._indexAttributes + attrs["ts_components"] = [str(c) for c in self.ts_components] + return attrs + + def primary_ts_col(self) -> str: + return self.get_ts_component(0) + + @property + def primary_ts_idx(self) -> TSIndex: + return self.ts_components[0] + + @property + def unit(self) -> Optional[TimeUnits]: + return self.primary_ts_idx.unit + + def validate(self, df_schema: StructType) -> None: + super().validate(df_schema) + # validate all the TS components + schema_ts_type = df_schema[self.colname].dataType + assert isinstance(schema_ts_type, StructType),\ + f"CompositeTSIndex must be of StructType, " \ + f"but given ts_col {self.colname} " \ + f"has type {schema_ts_type}" + for comp in self.ts_components: + comp.validate(schema_ts_type) + + def get_ts_component(self, component_index: int) -> str: + """ + Returns the full path to a component field that is a functional part of the timeseries. + + :param component_index: the index giving the ordering of the component field within the timeseries + + :return: a column name that can be used to reference the component field in PySpark expressions + """ + return self.fieldPath(self.ts_components[component_index].colname) + + def orderByExpr(self, reverse: bool = False) -> Column: + # build an expression for each TS component, in order + exprs = [sfn.col(self.fieldPath(comp.colname)) for comp in self.ts_components] + return self._reverseOrNot(exprs, reverse) + + def rangeExpr(self, reverse: bool = False) -> Column: + return self.primary_ts_idx.rangeExpr(reverse) + + # # Window Builder Interface # From 7743b54ee2487dd16194211d3d2a38d0601016fb Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 21 Jun 2023 15:13:10 -0700 Subject: [PATCH 02/13] checkpoint save - refactoring TSIndex hierarchy --- python/tempo/tsschema.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index f1d5f259..d16ca25e 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -1,6 +1,7 @@ from enum import Enum, auto from abc import ABC, abstractmethod from typing import cast, Any, Union, Optional, Collection, List +import re import pyspark.sql.functions as sfn from pyspark.sql import Column, WindowSpec, Window @@ -22,6 +23,24 @@ class TimeUnits(Enum): NANOSECONDS = auto() +# +# Timestamp parsing helpers +# + +DEFAULT_TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss" +__time_pattern_components = "hHkKmsS" + +def is_time_format(format: str) -> bool: + """ + Checcks whether the given format string contains time elements, + or if it is just a date format + + :param format: the format string to check + + :return: whether the given format string contains time elements + """ + return any(c in format for c in __time_pattern_components) + # # Abstract Timeseries Index Classes # @@ -184,8 +203,6 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": ) - - # # Simple TS Index types # @@ -366,6 +383,21 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: expr = sfn.col(self.parsed_ts_col) return self._reverseOrNot(expr, reverse) + @classmethod + def fromParsedTimestamp(cls, + str_ts_col: str, + ts_fmt: str = DEFAULT_TIMESTAMP_FORMAT) -> "ParsedTimestampIndex": + """ + Create a ParsedTimestampIndex from a string column containing timestamps or dates + + :param str_ts_col: The name of the string column containing timestamps or dates + :param ts_fmt: The format of the timestamps or dates in the string column + + :return: A ParsedTimestampIndex + """ + # TODO fill this in + pass + class ParsedTimestampIndex(ParsedTSIndex): """ From de8a6bd4c8ab6a1573dd0f6de43ef95242e1dd2b Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 13 Jul 2023 15:35:16 -0700 Subject: [PATCH 03/13] timeslicing functions should have type Any for user-provided ts params --- python/tempo/tsdf.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 8eb9adb9..fe32e3e4 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -111,9 +111,9 @@ def fromSubsequenceCol( ) -> "TSDF": # construct a struct with the ts_col and subsequence_col struct_col_name = cls.__DEFAULT_TS_IDX_COL - with_subseq_struct_df = cls.__makeStructFromCols( - df, struct_col_name, [ts_col, subsequence_col] - ) + with_subseq_struct_df = cls.__makeStructFromCols(df, + struct_col_name, + [ts_col, subsequence_col]) # construct an appropriate TSIndex subseq_struct = with_subseq_struct_df.schema[struct_col_name] subseq_idx = CompositeTSIndex(subseq_struct, ts_col, subsequence_col) @@ -408,7 +408,7 @@ def where(self, condition: Union[Column, str]) -> "TSDF": where_df = self.df.where(condition) return self.__withTransformedDF(where_df) - def __slice(self, op: str, target_ts: Union[str, int]) -> "TSDF": + def __slice(self, op: str, target_ts: Any) -> "TSDF": """ Private method to slice TSDF by time @@ -424,7 +424,7 @@ def __slice(self, op: str, target_ts: Union[str, int]) -> "TSDF": sliced_df = self.df.where(slice_expr) return self.__withTransformedDF(sliced_df) - def at(self, ts: Union[str, int]) -> "TSDF": + def at(self, ts: Any) -> "TSDF": """ Select only records at a given time @@ -434,7 +434,7 @@ def at(self, ts: Union[str, int]) -> "TSDF": """ return self.__slice("==", ts) - def before(self, ts: Union[str, int]) -> "TSDF": + def before(self, ts: Any) -> "TSDF": """ Select only records before a given time @@ -444,7 +444,7 @@ def before(self, ts: Union[str, int]) -> "TSDF": """ return self.__slice("<", ts) - def atOrBefore(self, ts: Union[str, int]) -> "TSDF": + def atOrBefore(self, ts: Any) -> "TSDF": """ Select only records at or before a given time @@ -454,7 +454,7 @@ def atOrBefore(self, ts: Union[str, int]) -> "TSDF": """ return self.__slice("<=", ts) - def after(self, ts: Union[str, int]) -> "TSDF": + def after(self, ts: Any) -> "TSDF": """ Select only records after a given time @@ -464,7 +464,7 @@ def after(self, ts: Union[str, int]) -> "TSDF": """ return self.__slice(">", ts) - def atOrAfter(self, ts: Union[str, int]) -> "TSDF": + def atOrAfter(self, ts: Any) -> "TSDF": """ Select only records at or after a given time @@ -475,7 +475,7 @@ def atOrAfter(self, ts: Union[str, int]) -> "TSDF": return self.__slice(">=", ts) def between( - self, start_ts: Union[str, int], end_ts: Union[str, int], inclusive: bool = True + self, start_ts: Any, end_ts: Any, inclusive: bool = True ) -> "TSDF": """ Select only records in a given range @@ -530,7 +530,7 @@ def latest(self, n: int = 1) -> "TSDF": next_window = self.baseWindow(reverse=True) return self.__top_rows_per_series(next_window, n) - def priorTo(self, ts: Union[str, int], n: int = 1) -> "TSDF": + def priorTo(self, ts: Any, n: int = 1) -> "TSDF": """ Select the n most recent records prior to a given time You can think of this like an 'asOf' select - it selects the records as of a particular time @@ -542,7 +542,7 @@ def priorTo(self, ts: Union[str, int], n: int = 1) -> "TSDF": """ return self.atOrBefore(ts).latest(n) - def subsequentTo(self, ts: Union[str, int], n: int = 1) -> "TSDF": + def subsequentTo(self, ts: Any, n: int = 1) -> "TSDF": """ Select the n records subsequent to a give time From 5696ca92a74b980135d0e963b2ab0ace06135a5e Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Thu, 20 Jul 2023 14:30:30 -0700 Subject: [PATCH 04/13] New timeunit type to track different units of time --- python/tempo/timeunit.py | 46 ++++++++++++++ python/tempo/tsschema.py | 130 +++++++++++++++++++++++++++------------ 2 files changed, 137 insertions(+), 39 deletions(-) create mode 100644 python/tempo/timeunit.py diff --git a/python/tempo/timeunit.py b/python/tempo/timeunit.py new file mode 100644 index 00000000..c7282a17 --- /dev/null +++ b/python/tempo/timeunit.py @@ -0,0 +1,46 @@ +from typing import NamedTuple +from functools import total_ordering + + +@total_ordering +class TimeUnit(NamedTuple): + name: str + approx_seconds: float + sub_second_precision: int = 0 + """ + Represents a unit of time, with a name, + an approximate number of seconds, + and a sub-second precision. + """ + + def __eq__(self, other): + return self.approx_seconds == other.approx_seconds + + def __lt__(self, other): + return self.approx_seconds < other.approx_seconds + + +TimeUnitsType = NamedTuple("TimeUnitsType", + [("YEARS", TimeUnit), + ("MONTHS", TimeUnit), + ("WEEKS", TimeUnit), + ("DAYS", TimeUnit), + ("HOURS", TimeUnit), + ("MINUTES", TimeUnit), + ("SECONDS", TimeUnit), + ("MILLISECONDS", TimeUnit), + ("MICROSECONDS", TimeUnit), + ("NANOSECONDS", TimeUnit)]) + +StandardTimeUnits = TimeUnitsType( + TimeUnit("year", 365 * 24 * 60 * 60), + TimeUnit("month", 30 * 24 * 60 * 60), + TimeUnit("week", 7 * 24 * 60 * 60), + TimeUnit("day", 24 * 60 * 60), + TimeUnit("hour", 60 * 60), + TimeUnit("minute", 60), + TimeUnit("second", 1), + TimeUnit("millisecond", 1e-03, 3), + TimeUnit("microsecond", 1e-06, 6), + TimeUnit("nanosecond", 1e-09, 9) +) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index d16ca25e..cd0db5e1 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -1,27 +1,13 @@ -from enum import Enum, auto +import warnings from abc import ABC, abstractmethod -from typing import cast, Any, Union, Optional, Collection, List -import re +from typing import Any, Collection, List, Optional, Union import pyspark.sql.functions as sfn -from pyspark.sql import Column, WindowSpec, Window +from pyspark.sql import Column, Window, WindowSpec from pyspark.sql.types import * from pyspark.sql.types import NumericType -# -# Time Units -# - -class TimeUnits(Enum): - YEARS = auto() - MONTHS = auto() - DAYS = auto() - HOURS = auto() - MINUTES = auto() - SECONDS = auto() - MICROSECONDS = auto() - NANOSECONDS = auto() - +from tempo.timeunit import TimeUnit, StandardTimeUnits # # Timestamp parsing helpers @@ -30,16 +16,17 @@ class TimeUnits(Enum): DEFAULT_TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss" __time_pattern_components = "hHkKmsS" -def is_time_format(format: str) -> bool: + +def is_time_format(ts_fmt: str) -> bool: """ Checcks whether the given format string contains time elements, or if it is just a date format - :param format: the format string to check + :param ts_fmt: the format string to check :return: whether the given format string contains time elements """ - return any(c in format for c in __time_pattern_components) + return any(c in ts_fmt for c in __time_pattern_components) # # Abstract Timeseries Index Classes @@ -86,7 +73,7 @@ def ts_col(self) -> str: @property @abstractmethod - def unit(self) -> Optional[TimeUnits]: + def unit(self) -> Optional[TimeUnit]: """ :return: the unit of this index, that is, the unit that a range value of 1 represents (Days, seconds, etc.) """ @@ -172,13 +159,13 @@ def ts_col(self) -> str: def validate(self, df_schema: StructType) -> None: # the ts column must exist - assert(self.colname in df_schema.fieldNames(), - f"The TSIndex column {self.colname} does not exist in the given DataFrame") + assert self.colname in df_schema.fieldNames(), \ + f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert( isinstance(schema_ts_type, type(self.dataType)), - f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" ) + assert isinstance(schema_ts_type, type(self.dataType)), \ + f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name @@ -224,7 +211,7 @@ def __init__(self, ts_col: StructField) -> None: super().__init__(ts_col) @property - def unit(self) -> Optional[TimeUnits]: + def unit(self) -> Optional[TimeUnit]: return None def rangeExpr(self, reverse: bool = False) -> Column: @@ -245,8 +232,8 @@ def __init__(self, ts_col: StructField) -> None: super().__init__(ts_col) @property - def unit(self) -> Optional[TimeUnits]: - return TimeUnits.SECONDS + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.SECONDS def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) @@ -268,8 +255,8 @@ def __init__(self, ts_col: StructField) -> None: super().__init__(ts_col) @property - def unit(self) -> Optional[TimeUnits]: - return TimeUnits.DAYS + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.DAYS def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch @@ -408,6 +395,7 @@ def __init__( self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: super().__init__(ts_struct, parsed_ts_col, src_str_col) + # validate the parsed column as a timestamp column parsed_ts_field = self.schema[self.__parsed_ts_col] if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( @@ -422,8 +410,70 @@ def rangeExpr(self, reverse: bool = False) -> Column: return self._reverseOrNot(expr, reverse) @property - def unit(self) -> Optional[TimeUnits]: - return TimeUnits.SECONDS + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.SECONDS + + +class SubMicrosecondPrecisionTimestampIndex(ParsedTimestampIndex): + """ + Timeseries index class for timestamps with sub-microsecond precision + parsed from a string column. Internally, the timestamps are stored as + doubles (fractional seconds since epoch), as well as the original string + and a micro-second precision (standard) timestamp field. + """ + + def __init__(self, + ts_struct: StructField, + double_ts_col: str, + parsed_ts_col: str, + src_str_col: str, + num_precision_digits: int = 9) -> None: + """ + :param ts_struct: The StructField for the TSIndex column + :param double_ts_col: The name of the double-precision timestamp column + :param parsed_ts_col: The name of the parsed timestamp column + :param src_str_col: The name of the source string column + :param num_precision_digits: The number of digits that make up the precision of + the timestamp. Ie. 9 for nanoseconds (default), 12 for picoseconds, etc. + You will receive a warning if this value is 6 or less, as this is the precision + of the standard timestamp type. + """ + super().__init__(ts_struct, parsed_ts_col, src_str_col) + # set & validate the double timestamp column + self.double_ts_col = double_ts_col + # validate the double timestamp column + double_ts_field = self.schema[self.double_ts_col] + if not isinstance(double_ts_field.dataType, DoubleType): + raise TypeError( + f"The double_ts_col must be of DoubleType, " + f"but the given double_ts_col {self.double_ts_col} " + f"has type {double_ts_field.dataType}" + ) + # validate the number of precision digits + if num_precision_digits <= 6: + warnings.warn( + f"SubMicrosecondPrecisionTimestampIndex has a num_precision_digits " + f"of {num_precision_digits} which is within the range of the " + f"standard timestamp precision of 6 digits (microseconds). " + f"Consider using a ParsedTimestampIndex instead." + ) + self.__unit = TimeUnit( + f"custom_subsecond_unit (precision: {num_precision_digits})", + 10 ** (-num_precision_digits), + num_precision_digits, + ) + + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + expr = sfn.col(self.double_ts_col) + return self._reverseOrNot(expr, reverse) + + def rangeExpr(self, reverse: bool = False) -> Column: + # just use the order by expression, since this is the same + return self.orderByExpr(reverse) + + @property + def unit(self) -> Optional[TimeUnit]: + return self.__unit class ParsedDateIndex(ParsedTSIndex): @@ -435,6 +485,7 @@ def __init__( self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: super().__init__(ts_struct, parsed_ts_col, src_str_col) + # validate the parsed column as a date column parsed_ts_field = self.schema[self.__parsed_ts_col] if not isinstance(parsed_ts_field.dataType, DateType): raise TypeError( @@ -444,8 +495,8 @@ def __init__( ) @property - def unit(self) -> Optional[TimeUnits]: - return TimeUnits.DAYS + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.DAYS def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch @@ -472,7 +523,8 @@ def __init__(self, ts_struct: StructField, *ts_fields: str) -> None: assert len(ts_fields) > 1,\ f"CompositeTSIndex must have at least two timestamp fields, " \ f"but only {len(ts_fields)} were given" - self.ts_components = [SimpleTSIndex.fromTSCol(self.schema[field]) for field in ts_fields] + self.ts_components = \ + [SimpleTSIndex.fromTSCol(self.schema[field]) for field in ts_fields] @property def _indexAttributes(self) -> dict[str, Any]: @@ -488,7 +540,7 @@ def primary_ts_idx(self) -> TSIndex: return self.ts_components[0] @property - def unit(self) -> Optional[TimeUnits]: + def unit(self) -> Optional[TimeUnit]: return self.primary_ts_idx.unit def validate(self, df_schema: StructType) -> None: @@ -666,8 +718,8 @@ def validate(self, df_schema: StructType) -> None: self.ts_idx.validate(df_schema) # check series IDs for sid in self.series_ids: - assert( sid in df_schema.fieldNames(), - f"Series ID {sid} does not exist in the given DataFrame" ) + assert sid in df_schema.fieldNames(), \ + f"Series ID {sid} does not exist in the given DataFrame" def find_observational_columns(self, df_schema: StructType) -> list[str]: return list(set(df_schema.fieldNames()) - set(self.structural_columns)) From 881095bc4754589aaf0db03e47a4954a79a946f9 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Fri, 21 Jul 2023 20:22:12 -0700 Subject: [PATCH 05/13] Updates for parsed timestamps --- python/tempo/tsschema.py | 77 ++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index cd0db5e1..8c3f782c 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -1,4 +1,5 @@ import warnings +import re from abc import ABC, abstractmethod from typing import Any, Collection, List, Optional, Union @@ -28,6 +29,20 @@ def is_time_format(ts_fmt: str) -> bool: """ return any(c in ts_fmt for c in __time_pattern_components) + +def sub_seconds_precision_digits(ts_fmt: str) -> int: + """ + Returns the number of digits of precision for a timestamp format string + """ + # pattern for matching the sub-second precision digits + sub_seconds_ptrn = r"\.(\S+)" + # find the sub-second precision digits + match = re.search(sub_seconds_ptrn, ts_fmt) + if match is None: + return 0 + else: + return len(match.group(1)) + # # Abstract Timeseries Index Classes # @@ -340,20 +355,20 @@ def __init__( f"but given column {src_str_field.name} " f"is of type {src_str_field.dataType}" ) - self.__src_str_col = src_str_col + self._src_str_col = src_str_col # validate the parsed column assert parsed_ts_col in self.schema.fieldNames(),\ f"The parsed timestamp index field {parsed_ts_col} does not exist in the " \ f"MultiPart TSIndex schema {self.schema}" - self.__parsed_ts_col = parsed_ts_col + self._parsed_ts_col = parsed_ts_col @property def src_str_col(self): - return self.fieldPath(self.__src_str_col) + return self.fieldPath(self._src_str_col) @property def parsed_ts_col(self): - return self.fieldPath(self.__parsed_ts_col) + return self.fieldPath(self._parsed_ts_col) @property def ts_col(self) -> str: @@ -372,18 +387,52 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: @classmethod def fromParsedTimestamp(cls, - str_ts_col: str, - ts_fmt: str = DEFAULT_TIMESTAMP_FORMAT) -> "ParsedTimestampIndex": + ts_struct: StructField, + parsed_ts_col: str, + src_str_col: str, + double_ts_col: Optional[str] = None, + num_precision_digits: int = 6) -> "ParsedTSIndex": """ Create a ParsedTimestampIndex from a string column containing timestamps or dates - :param str_ts_col: The name of the string column containing timestamps or dates - :param ts_fmt: The format of the timestamps or dates in the string column + :param ts_struct: The StructField for the TSIndex column + :param parsed_ts_col: The name of the parsed timestamp column + :param src_str_col: The name of the source string column + :param double_ts_col: The name of the double-precision timestamp column + :param num_precision_digits: The number of digits that make up the precision of + + :return: A ParsedTSIndex object + """ + + # if a double timestamp column is given + # then we are building a SubMicrosecondPrecisionTimestampIndex + if double_ts_col is not None: + return SubMicrosecondPrecisionTimestampIndex(ts_struct, + double_ts_col, + parsed_ts_col, + src_str_col, + num_precision_digits) + # otherwise, we base it on the standard timestamp type + # find the schema of the ts_struct column + ts_schema = ts_struct.dataType + if not isinstance(ts_schema, StructType): + raise TypeError( + f"A ParsedTSIndex must be of type StructType, but given " + f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" + ) + # get the type of the parsed timestamp column + parsed_ts_type = ts_schema[parsed_ts_col].dataType + if isinstance(parsed_ts_type, TimestampType): + return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) + elif isinstance(parsed_ts_type, DateType): + return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) + else: + raise TypeError( + f"ParsedTimestampIndex must be of TimestampType or DateType, " + f"but given ts_col {parsed_ts_col} " + f"has type {parsed_ts_type}" + ) - :return: A ParsedTimestampIndex - """ - # TODO fill this in - pass class ParsedTimestampIndex(ParsedTSIndex): @@ -396,7 +445,7 @@ def __init__( ) -> None: super().__init__(ts_struct, parsed_ts_col, src_str_col) # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[self.__parsed_ts_col] + parsed_ts_field = self.schema[self._parsed_ts_col] if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( f"ParsedTimestampIndex must be of TimestampType, " @@ -486,7 +535,7 @@ def __init__( ) -> None: super().__init__(ts_struct, parsed_ts_col, src_str_col) # validate the parsed column as a date column - parsed_ts_field = self.schema[self.__parsed_ts_col] + parsed_ts_field = self.schema[self._parsed_ts_col] if not isinstance(parsed_ts_field.dataType, DateType): raise TypeError( f"ParsedTimestampIndex must be of DateType, " From a4ae322a157c0f98dba7a90b5a1fa1016214126a Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 3 Jan 2024 15:50:41 -0800 Subject: [PATCH 06/13] normalizing whitespace --- python/tempo/timeunit.py | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/tempo/timeunit.py b/python/tempo/timeunit.py index c7282a17..f472a07c 100644 --- a/python/tempo/timeunit.py +++ b/python/tempo/timeunit.py @@ -4,20 +4,20 @@ @total_ordering class TimeUnit(NamedTuple): - name: str - approx_seconds: float - sub_second_precision: int = 0 - """ - Represents a unit of time, with a name, - an approximate number of seconds, - and a sub-second precision. - """ + name: str + approx_seconds: float + sub_second_precision: int = 0 + """ + Represents a unit of time, with a name, + an approximate number of seconds, + and a sub-second precision. + """ - def __eq__(self, other): - return self.approx_seconds == other.approx_seconds + def __eq__(self, other): + return self.approx_seconds == other.approx_seconds - def __lt__(self, other): - return self.approx_seconds < other.approx_seconds + def __lt__(self, other): + return self.approx_seconds < other.approx_seconds TimeUnitsType = NamedTuple("TimeUnitsType", @@ -33,14 +33,14 @@ def __lt__(self, other): ("NANOSECONDS", TimeUnit)]) StandardTimeUnits = TimeUnitsType( - TimeUnit("year", 365 * 24 * 60 * 60), - TimeUnit("month", 30 * 24 * 60 * 60), - TimeUnit("week", 7 * 24 * 60 * 60), - TimeUnit("day", 24 * 60 * 60), - TimeUnit("hour", 60 * 60), - TimeUnit("minute", 60), - TimeUnit("second", 1), - TimeUnit("millisecond", 1e-03, 3), - TimeUnit("microsecond", 1e-06, 6), - TimeUnit("nanosecond", 1e-09, 9) + TimeUnit("year", 365 * 24 * 60 * 60), + TimeUnit("month", 30 * 24 * 60 * 60), + TimeUnit("week", 7 * 24 * 60 * 60), + TimeUnit("day", 24 * 60 * 60), + TimeUnit("hour", 60 * 60), + TimeUnit("minute", 60), + TimeUnit("second", 1), + TimeUnit("millisecond", 1e-03, 3), + TimeUnit("microsecond", 1e-06, 6), + TimeUnit("nanosecond", 1e-09, 9) ) From cca500664e01fd99051940effce2796ca61bd0d6 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 3 Jan 2024 15:51:03 -0800 Subject: [PATCH 07/13] big simplification of TS index classes --- python/tempo/tsschema.py | 533 ++++++++++++++++++++------------------- 1 file changed, 272 insertions(+), 261 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 8c3f782c..2f2fb505 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -43,6 +43,7 @@ def sub_seconds_precision_digits(ts_fmt: str) -> int: else: return len(match.group(1)) + # # Abstract Timeseries Index Classes # @@ -79,13 +80,6 @@ def colname(self) -> str: :return: the column name of the timeseries index """ - @property - @abstractmethod - def ts_col(self) -> str: - """ - :return: the name of the primary timeseries column (may or may not be the same as the name) - """ - @property @abstractmethod def unit(self) -> Optional[TimeUnit]: @@ -127,26 +121,33 @@ def _reverseOrNot( elif isinstance(expr, list): return [col.desc() for col in expr] # reverse all columns in the expression else: - raise TypeError(f"Type for expr argument must be either Column or List[Column], instead received: {type(expr)}") + raise TypeError( + f"Type for expr argument must be either Column or " + f"List[Column], instead received: {type(expr)}" + ) @abstractmethod def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: """ - Gets an expression that will order the :class:`TSDF` according to the timeseries index. + Gets an expression that will order the :class:`TSDF` + according to the timeseries index. :param reverse: whether the ordering should be reversed (backwards in time) - :return: an expression appropriate for ordering the :class:`TSDF` according to this index + :return: an expression appropriate for ordering the :class:`TSDF` + according to this index """ @abstractmethod def rangeExpr(self, reverse: bool = False) -> Column: """ - Gets an expression appropriate for performing range operations on the :class:`TSDF` records. + Gets an expression appropriate for performing range operations + on the :class:`TSDF` records. :param reverse: whether the ordering should be reversed (backwards in time) - :return: an expression appropriate for operforming range operations on the :class:`TSDF` records + :return: an expression appropriate for performing range operations + on the :class:`TSDF` records """ @@ -162,25 +163,23 @@ def __init__(self, ts_col: StructField) -> None: @property def _indexAttributes(self) -> dict[str, Any]: - return {"name": self.colname, "dataType": self.dataType} + return {"name": self.colname, "dataType": self.dataType, "unit": self.unit} @property def colname(self): return self.__name - @property - def ts_col(self) -> str: - return self.colname - def validate(self, df_schema: StructType) -> None: # the ts column must exist - assert self.colname in df_schema.fieldNames(), \ - f"The TSIndex column {self.colname} does not exist in the given DataFrame" + assert ( + self.colname in df_schema.fieldNames() + ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert isinstance(schema_ts_type, type(self.dataType)), \ - f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" + assert isinstance( + schema_ts_type, type(self.dataType) + ), f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name @@ -209,9 +208,10 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": # Simple TS Index types # + class OrdinalTSIndex(SimpleTSIndex): """ - Timeseries index based on a single column of a numeric or temporal type. + Timeseries index based on a single column of a numeric type. This index is "unitless", meaning that it is not associated with any particular unit of time. It can provide ordering of records, but not range operations. @@ -280,42 +280,71 @@ def rangeExpr(self, reverse: bool = False) -> Column: # -# Multi-Part TS Index types +# Complex (Multi-Field) TS Index types # -class MultiPartTSIndex(TSIndex, ABC): + +class CompositeTSIndex(TSIndex, ABC): """ Abstract base class for Timeseries Index types that reference multiple columns. Such columns are organized as a StructType column with multiple fields. + Some subset of these columns (at least 1) is considered to be a "component field", + the others are called "accessory fields". """ - def __init__(self, ts_struct: StructField) -> None: + def __init__(self, ts_struct: StructField, *component_fields: str) -> None: if not isinstance(ts_struct.dataType, StructType): raise TypeError( f"CompoundTSIndex must be of type StructType, but given " f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" ) + # validate the index fields + assert len(component_fields) > 0, ( + f"A MultiFieldTSIndex must have at least 1 index component field, " + f"but {len(component_fields)} were given" + ) + for ind_f in component_fields: + assert ( + ind_f in ts_struct.dataType.fieldNames() + ), f"Index field {ind_f} does not exist in the given TSIndex schema" + # assign local attributes self.__name: str = ts_struct.name self.schema: StructType = ts_struct.dataType + self.component_fields: List[str] = list(component_fields) @property def _indexAttributes(self) -> dict[str, Any]: - return {"name": self.colname, "schema": self.schema} + return { + "name": self.colname, + "schema": self.schema, + "unit": self.unit, + "component_fields": self.component_fields, + } @property def colname(self) -> str: return self.__name + @property + def fieldNames(self) -> List[str]: + return self.schema.fieldNames() + + @property + def accessory_fields(self) -> List[str]: + return list(set(self.fieldNames) - set(self.component_fields)) + def validate(self, df_schema: StructType) -> None: # validate that the composite field exists - assert self.colname in df_schema.fieldNames(),\ - f"The TSIndex column {self.colname} does not exist in the given DataFrame" + assert ( + self.colname in df_schema.fieldNames() + ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert schema_ts_type == self.schema,\ - f"The TSIndex column is of type {schema_ts_type}, "\ + assert schema_ts_type == self.schema, ( + f"The TSIndex column is of type {schema_ts_type}, " f"but the expected type is {self.schema}" + ) def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name @@ -327,143 +356,204 @@ def fieldPath(self, field: str) -> str: :return: A dot-separated path to the given field within the TSIndex column """ - assert field in self.schema.fieldNames(),\ - f"Field {field} does not exist in the TSIndex schema {self.schema}" + assert ( + field in self.fieldNames + ), f"Field {field} does not exist in the TSIndex schema {self.schema}" return f"{self.colname}.{field}" + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + # build an expression for each TS component, in order + exprs = [sfn.col(self.fieldPath(comp)) for comp in self.component_fields] + return self._reverseOrNot(exprs, reverse) + # # Parsed TS Index types # -class ParsedTSIndex(MultiPartTSIndex, ABC): +# class ParsedTSIndex(CompositeTSIndex, ABC): +# """ +# Abstract base class for timeseries indices that are parsed from a string column. +# Retains the original string form as well as the parsed column. +# """ +# +# def __init__( +# self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str +# ) -> None: +# super().__init__(ts_struct, parsed_ts_col) +# # validate the source string column +# src_str_field = self.schema[src_str_col] +# if not isinstance(src_str_field.dataType, StringType): +# raise TypeError( +# f"Source string column must be of StringType, " +# f"but given column {src_str_field.name} " +# f"is of type {src_str_field.dataType}" +# ) +# self._src_str_col = src_str_col +# # validate the parsed column +# assert parsed_ts_col in self.schema.fieldNames(), ( +# f"The parsed timestamp index field {parsed_ts_col} does not exist in the " +# f"MultiPart TSIndex schema {self.schema}" +# ) +# self._parsed_ts_col = parsed_ts_col +# +# @property +# def src_str_col(self): +# return self.fieldPath(self._src_str_col) +# +# @property +# def parsed_ts_col(self): +# return self.fieldPath(self._parsed_ts_col) +# +# @property +# def ts_col(self) -> str: +# return self.parsed_ts_col +# +# @property +# def _indexAttributes(self) -> dict[str, Any]: +# attrs = super()._indexAttributes +# attrs["parsed_ts_col"] = self.parsed_ts_col +# attrs["src_str_col"] = self.src_str_col +# return attrs +# +# def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: +# expr = sfn.col(self.parsed_ts_col) +# return self._reverseOrNot(expr, reverse) +# +# @classmethod +# def fromParsedTimestamp( +# cls, +# ts_struct: StructField, +# parsed_ts_col: str, +# src_str_col: str, +# double_ts_col: Optional[str] = None, +# num_precision_digits: int = 6, +# ) -> "ParsedTSIndex": +# """ +# Create a ParsedTimestampIndex from a string column containing timestamps or dates +# +# :param ts_struct: The StructField for the TSIndex column +# :param parsed_ts_col: The name of the parsed timestamp column +# :param src_str_col: The name of the source string column +# :param double_ts_col: The name of the double-precision timestamp column +# :param num_precision_digits: The number of digits that make up the precision of +# +# :return: A ParsedTSIndex object +# """ +# +# # if a double timestamp column is given +# # then we are building a SubMicrosecondPrecisionTimestampIndex +# if double_ts_col is not None: +# return SubMicrosecondPrecisionTimestampIndex( +# ts_struct, +# double_ts_col, +# parsed_ts_col, +# src_str_col, +# num_precision_digits, +# ) +# # otherwise, we base it on the standard timestamp type +# # find the schema of the ts_struct column +# ts_schema = ts_struct.dataType +# if not isinstance(ts_schema, StructType): +# raise TypeError( +# f"A ParsedTSIndex must be of type StructType, but given " +# f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" +# ) +# # get the type of the parsed timestamp column +# parsed_ts_type = ts_schema[parsed_ts_col].dataType +# if isinstance(parsed_ts_type, TimestampType): +# return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) +# elif isinstance(parsed_ts_type, DateType): +# return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) +# else: +# raise TypeError( +# f"ParsedTimestampIndex must be of TimestampType or DateType, " +# f"but given ts_col {parsed_ts_col} " +# f"has type {parsed_ts_type}" +# ) + + +class ParsedTimestampIndex(CompositeTSIndex): """ - Abstract base class for timeseries indices that are parsed from a string column. - Retains the original string form as well as the parsed column. + Timeseries index class for timestamps parsed from a string column """ def __init__( self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: - super().__init__(ts_struct) - # validate the source string column + super().__init__(ts_struct, parsed_ts_col) + # validate the parsed column as a timestamp column + parsed_ts_field = self.schema[parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, TimestampType): + raise TypeError( + f"ParsedTimestampIndex requires an index field of TimestampType, " + f"but the given parsed_ts_col {parsed_ts_col} " + f"has type {parsed_ts_field.dataType}" + ) + self.parsed_ts_col = parsed_ts_col + # validate the source column as a string column src_str_field = self.schema[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( - f"Source string column must be of StringType, " - f"but given column {src_str_field.name} " - f"is of type {src_str_field.dataType}" + f"ParsedTimestampIndex requires an source field of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" ) - self._src_str_col = src_str_col - # validate the parsed column - assert parsed_ts_col in self.schema.fieldNames(),\ - f"The parsed timestamp index field {parsed_ts_col} does not exist in the " \ - f"MultiPart TSIndex schema {self.schema}" - self._parsed_ts_col = parsed_ts_col - - @property - def src_str_col(self): - return self.fieldPath(self._src_str_col) - - @property - def parsed_ts_col(self): - return self.fieldPath(self._parsed_ts_col) - - @property - def ts_col(self) -> str: - return self.parsed_ts_col + self.src_str_col = src_str_col @property - def _indexAttributes(self) -> dict[str, Any]: - attrs = super()._indexAttributes - attrs["parsed_ts_col"] = self.parsed_ts_col - attrs["src_str_col"] = self.src_str_col - return attrs + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.SECONDS - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.parsed_ts_col) + def rangeExpr(self, reverse: bool = False) -> Column: + # cast timestamp to double (fractional seconds since epoch) + expr = sfn.col(self.fieldPath(self.parsed_ts_col)).cast("double") return self._reverseOrNot(expr, reverse) - @classmethod - def fromParsedTimestamp(cls, - ts_struct: StructField, - parsed_ts_col: str, - src_str_col: str, - double_ts_col: Optional[str] = None, - num_precision_digits: int = 6) -> "ParsedTSIndex": - """ - Create a ParsedTimestampIndex from a string column containing timestamps or dates - :param ts_struct: The StructField for the TSIndex column - :param parsed_ts_col: The name of the parsed timestamp column - :param src_str_col: The name of the source string column - :param double_ts_col: The name of the double-precision timestamp column - :param num_precision_digits: The number of digits that make up the precision of - - :return: A ParsedTSIndex object - """ - - # if a double timestamp column is given - # then we are building a SubMicrosecondPrecisionTimestampIndex - if double_ts_col is not None: - return SubMicrosecondPrecisionTimestampIndex(ts_struct, - double_ts_col, - parsed_ts_col, - src_str_col, - num_precision_digits) - # otherwise, we base it on the standard timestamp type - # find the schema of the ts_struct column - ts_schema = ts_struct.dataType - if not isinstance(ts_schema, StructType): - raise TypeError( - f"A ParsedTSIndex must be of type StructType, but given " - f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" - ) - # get the type of the parsed timestamp column - parsed_ts_type = ts_schema[parsed_ts_col].dataType - if isinstance(parsed_ts_type, TimestampType): - return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) - elif isinstance(parsed_ts_type, DateType): - return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) - else: - raise TypeError( - f"ParsedTimestampIndex must be of TimestampType or DateType, " - f"but given ts_col {parsed_ts_col} " - f"has type {parsed_ts_type}" - ) - - - -class ParsedTimestampIndex(ParsedTSIndex): +class ParsedDateIndex(CompositeTSIndex): """ - Timeseries index class for timestamps parsed from a string column + Timeseries index class for dates parsed from a string column """ def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str + self, ts_struct: StructField, parsed_date_col: str, src_str_col: str ) -> None: - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[self._parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, TimestampType): + super().__init__(ts_struct, parsed_date_col) + # validate the parsed column as a date column + parsed_date_field = self.schema[parsed_date_col] + if not isinstance(parsed_date_field.dataType, DateType): raise TypeError( - f"ParsedTimestampIndex must be of TimestampType, " - f"but given ts_col {self.parsed_ts_col} " - f"has type {parsed_ts_field.dataType}" + f"ParsedDateIndex requires an index field of DateType, " + f"but the given parsed_ts_col {parsed_date_col} " + f"has type {parsed_date_field.dataType}" ) - - def rangeExpr(self, reverse: bool = False) -> Column: - # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.parsed_ts_col).cast("double") - return self._reverseOrNot(expr, reverse) + self.parsed_date_col = parsed_date_col + # validate the source column as a string column + src_str_field = self.schema[src_str_col] + if not isinstance(src_str_field.dataType, StringType): + raise TypeError( + f"ParsedDateIndex requires an source field of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" + ) + self.src_str_col = src_str_col @property def unit(self) -> Optional[TimeUnit]: - return StandardTimeUnits.SECONDS + return StandardTimeUnits.DAYS + + def rangeExpr(self, reverse: bool = False) -> Column: + # convert date to number of days since the epoch + expr = sfn.datediff( + sfn.col(self.fieldPath(self.parsed_date_col)), + sfn.lit("1970-01-01").cast("date"), + ) + return self._reverseOrNot(expr, reverse) -class SubMicrosecondPrecisionTimestampIndex(ParsedTimestampIndex): +class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): """ Timeseries index class for timestamps with sub-microsecond precision parsed from a string column. Internally, the timestamps are stored as @@ -471,12 +561,14 @@ class SubMicrosecondPrecisionTimestampIndex(ParsedTimestampIndex): and a micro-second precision (standard) timestamp field. """ - def __init__(self, - ts_struct: StructField, - double_ts_col: str, - parsed_ts_col: str, - src_str_col: str, - num_precision_digits: int = 9) -> None: + def __init__( + self, + ts_struct: StructField, + double_ts_col: str, + parsed_ts_col: str, + src_str_col: str, + num_precision_digits: int = 9, + ) -> None: """ :param ts_struct: The StructField for the TSIndex column :param double_ts_col: The name of the double-precision timestamp column @@ -487,17 +579,16 @@ def __init__(self, You will receive a warning if this value is 6 or less, as this is the precision of the standard timestamp type. """ - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # set & validate the double timestamp column - self.double_ts_col = double_ts_col + super().__init__(ts_struct, double_ts_col) # validate the double timestamp column - double_ts_field = self.schema[self.double_ts_col] + double_ts_field = self.schema[double_ts_col] if not isinstance(double_ts_field.dataType, DoubleType): raise TypeError( f"The double_ts_col must be of DoubleType, " - f"but the given double_ts_col {self.double_ts_col} " + f"but the given double_ts_col {double_ts_col} " f"has type {double_ts_field.dataType}" ) + self.double_ts_col = double_ts_col # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( @@ -511,121 +602,42 @@ def __init__(self, 10 ** (-num_precision_digits), num_precision_digits, ) - - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.double_ts_col) - return self._reverseOrNot(expr, reverse) - - def rangeExpr(self, reverse: bool = False) -> Column: - # just use the order by expression, since this is the same - return self.orderByExpr(reverse) - - @property - def unit(self) -> Optional[TimeUnit]: - return self.__unit - - -class ParsedDateIndex(ParsedTSIndex): - """ - Timeseries index class for dates parsed from a string column - """ - - def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str - ) -> None: - super().__init__(ts_struct, parsed_ts_col, src_str_col) - # validate the parsed column as a date column - parsed_ts_field = self.schema[self._parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, DateType): + # validate the parsed column as a timestamp column + parsed_ts_field = self.schema[parsed_ts_col] + if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( - f"ParsedTimestampIndex must be of DateType, " - f"but given ts_col {self.parsed_ts_col} " + f"parsed_ts_col field must be of TimestampType, " + f"but the given parsed_ts_col {parsed_ts_col} " f"has type {parsed_ts_field.dataType}" ) + self.parsed_ts_col = parsed_ts_col + # validate the source column as a string column + src_str_field = self.schema[src_str_col] + if not isinstance(src_str_field.dataType, StringType): + raise TypeError( + f"src_str_col field must be of StringType, " + f"but the given src_str_col {src_str_col} " + f"has type {src_str_field.dataType}" + ) + self.src_str_col = src_str_col @property def unit(self) -> Optional[TimeUnit]: - return StandardTimeUnits.DAYS + return self.__unit - def rangeExpr(self, reverse: bool = False) -> Column: - # convert date to number of days since the epoch - expr = sfn.datediff( - sfn.col(self.parsed_ts_col), sfn.lit("1970-01-01").cast("date") - ) + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + expr = sfn.col(self.fieldPath(self.double_ts_col)) return self._reverseOrNot(expr, reverse) - -# -# Complex (Multi-Field) TS Index Types -# - - -class CompositeTSIndex(MultiPartTSIndex, ABC): - """ - Abstract base class for complex Timeseries Index classes - that involve two or more columns organized into a StructType column - """ - - def __init__(self, ts_struct: StructField, *ts_fields: str) -> None: - super().__init__(ts_struct) - # handle the timestamp fields - assert len(ts_fields) > 1,\ - f"CompositeTSIndex must have at least two timestamp fields, " \ - f"but only {len(ts_fields)} were given" - self.ts_components = \ - [SimpleTSIndex.fromTSCol(self.schema[field]) for field in ts_fields] - - @property - def _indexAttributes(self) -> dict[str, Any]: - attrs = super()._indexAttributes - attrs["ts_components"] = [str(c) for c in self.ts_components] - return attrs - - def primary_ts_col(self) -> str: - return self.get_ts_component(0) - - @property - def primary_ts_idx(self) -> TSIndex: - return self.ts_components[0] - - @property - def unit(self) -> Optional[TimeUnit]: - return self.primary_ts_idx.unit - - def validate(self, df_schema: StructType) -> None: - super().validate(df_schema) - # validate all the TS components - schema_ts_type = df_schema[self.colname].dataType - assert isinstance(schema_ts_type, StructType),\ - f"CompositeTSIndex must be of StructType, " \ - f"but given ts_col {self.colname} " \ - f"has type {schema_ts_type}" - for comp in self.ts_components: - comp.validate(schema_ts_type) - - def get_ts_component(self, component_index: int) -> str: - """ - Returns the full path to a component field that is a functional part of the timeseries. - - :param component_index: the index giving the ordering of the component field within the timeseries - - :return: a column name that can be used to reference the component field in PySpark expressions - """ - return self.fieldPath(self.ts_components[component_index].colname) - - def orderByExpr(self, reverse: bool = False) -> Column: - # build an expression for each TS component, in order - exprs = [sfn.col(self.fieldPath(comp.colname)) for comp in self.ts_components] - return self._reverseOrNot(exprs, reverse) - - def rangeExpr(self, reverse: bool = False) -> Column: - return self.primary_ts_idx.rangeExpr(reverse) - + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + # just use the order by expression, since this is the same + return self.orderByExpr(reverse) # # Window Builder Interface # + class WindowBuilder(ABC): """ Abstract base class for window builders. @@ -643,10 +655,9 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec: pass @abstractmethod - def rowsBetweenWindow(self, - start: int, - end: int, - reverse: bool = False) -> WindowSpec: + def rowsBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: """ build a row-based window with the given start and end offsets @@ -667,8 +678,7 @@ def allBeforeWindow(self, inclusive: bool = True) -> WindowSpec: :return: a WindowSpec object """ - return self.rowsBetweenWindow(Window.unboundedPreceding, - 0 if inclusive else -1) + return self.rowsBetweenWindow(Window.unboundedPreceding, 0 if inclusive else -1) def allAfterWindow(self, inclusive: bool = True) -> WindowSpec: """ @@ -679,14 +689,12 @@ def allAfterWindow(self, inclusive: bool = True) -> WindowSpec: :return: a WindowSpec object """ - return self.rowsBetweenWindow(0 if inclusive else 1, - Window.unboundedFollowing) + return self.rowsBetweenWindow(0 if inclusive else 1, Window.unboundedFollowing) @abstractmethod - def rangeBetweenWindow(self, - start: int, - end: int, - reverse: bool = False) -> WindowSpec: + def rangeBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: """ build a range-based window with the given start and end offsets @@ -767,8 +775,9 @@ def validate(self, df_schema: StructType) -> None: self.ts_idx.validate(df_schema) # check series IDs for sid in self.series_ids: - assert sid in df_schema.fieldNames(), \ - f"Series ID {sid} does not exist in the given DataFrame" + assert ( + sid in df_schema.fieldNames() + ), f"Series ID {sid} does not exist in the given DataFrame" def find_observational_columns(self, df_schema: StructType) -> list[str]: return list(set(df_schema.fieldNames()) - set(self.structural_columns)) @@ -796,14 +805,16 @@ def baseWindow(self, reverse: bool = False) -> WindowSpec: w = w.partitionBy([sfn.col(sid) for sid in self.series_ids]) return w - def rowsBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec: + def rowsBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: return self.baseWindow(reverse=reverse).rowsBetween(start, end) - def rangeBetweenWindow(self, start: int, end: int, reverse: bool = False) -> WindowSpec: + def rangeBetweenWindow( + self, start: int, end: int, reverse: bool = False + ) -> WindowSpec: return ( self.baseWindow(reverse=reverse) .orderBy(self.ts_idx.rangeExpr(reverse=reverse)) .rangeBetween(start, end) ) - - From d488d7c0b1416ad4683c6651d54809bfb81b9681 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Sun, 7 Jan 2024 00:38:28 -0800 Subject: [PATCH 08/13] checkpoint save - testing basic index functionality --- python/requirements.txt | 1 + python/tempo/tsdf.py | 72 +-- python/tempo/tsschema.py | 577 +++++++++++------- python/tests/tsschema_tests.py | 235 +++++++ python/tests/unit_test_data/tsdf_tests.json | 52 +- .../tests/unit_test_data/tsschema_tests.json | 26 + 6 files changed, 653 insertions(+), 310 deletions(-) create mode 100644 python/tests/tsschema_tests.py create mode 100644 python/tests/unit_test_data/tsschema_tests.json diff --git a/python/requirements.txt b/python/requirements.txt index 1a6844a9..0b801db8 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -17,3 +17,4 @@ sphinx-design==0.2.0 sphinx-panels==0.6.0 jsonref==1.1.0 python-dateutil==2.8.2 +parameterized==0.8.1 \ No newline at end of file diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index fe32e3e4..345f7c35 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -23,7 +23,8 @@ import tempo.resample as t_resample import tempo.utils as t_utils from tempo.intervals import IntervalsDF -from tempo.tsschema import CompositeTSIndex, TSIndex, TSSchema, WindowBuilder +from tempo.tsschema import DEFAULT_TIMESTAMP_FORMAT, is_time_format, \ + CompositeTSIndex, TSIndex, TSSchema, WindowBuilder logger = logging.getLogger(__name__) @@ -120,25 +121,39 @@ def fromSubsequenceCol( # construct & return the TSDF with appropriate schema return TSDF(with_subseq_struct_df, ts_schema=TSSchema(subseq_idx, series_ids)) + # default column name for parsed timeseries column + __DEFAULT_PARSED_TS_COL = "parsed_ts" + @classmethod - def fromTimestampString( + def fromStringTimestamp( cls, df: DataFrame, ts_col: str, series_ids: Collection[str] = None, - ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]", + ts_fmt: str = DEFAULT_TIMESTAMP_FORMAT, ) -> "TSDF": - pass - - @classmethod - def fromDateString( - cls, - df: DataFrame, - ts_col: str, - series_ids: Collection[str], - date_fmt: str = "YYYY-MM-DD", - ) -> "TSDF ": - pass + # parse the ts_col based on the pattern + if is_time_format(ts_fmt): + # if the ts_fmt is a time format, we can use to_timestamp + ts_expr = sfn.to_timestamp(sfn.col(ts_col), ts_fmt) + else: + # otherwise, we'll use to_date + ts_expr = sfn.to_date(sfn.col(ts_col), ts_fmt) + # parse the ts_col give the expression + parsed_ts_col = cls.__DEFAULT_PARSED_TS_COL + parsed_df = df.withColumn(cls.__DEFAULT_PARSED_TS_COL, ts_expr) + # move the ts cols into a struct + struct_col_name = cls.__DEFAULT_TS_IDX_COL + with_parsed_struct_df = cls.__makeStructFromCols(parsed_df, + struct_col_name, + [ts_col, parsed_ts_col]) + # construct an appropriate TSIndex + parsed_struct = with_parsed_struct_df.schema[struct_col_name] + parsed_ts_idx = ParsedTSIndex.fromParsedTimestamp(parsed_struct, + parsed_ts_col, + ts_col) + # construct & return the TSDF with appropriate schema + return TSDF(with_parsed_struct_df, ts_schema=TSSchema(parsed_ts_idx, series_ids)) @property def ts_index(self) -> "TSIndex": @@ -146,7 +161,8 @@ def ts_index(self) -> "TSIndex": @property def ts_col(self) -> str: - return self.ts_index.ts_col + # TODO - this should be replaced TSIndex expressions + pass @property def columns(self) -> List[str]: @@ -408,22 +424,6 @@ def where(self, condition: Union[Column, str]) -> "TSDF": where_df = self.df.where(condition) return self.__withTransformedDF(where_df) - def __slice(self, op: str, target_ts: Any) -> "TSDF": - """ - Private method to slice TSDF by time - - :param op: string symbol of the operation to perform - :type op: str - :param target_ts: timestamp on which to filter - - :return: a TSDF object containing only those records within the time slice specified - """ - # quote our timestamp if its a string - target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts - slice_expr = sfn.expr(f"{self.ts_col} {op} {target_expr}") - sliced_df = self.df.where(slice_expr) - return self.__withTransformedDF(sliced_df) - def at(self, ts: Any) -> "TSDF": """ Select only records at a given time @@ -432,7 +432,7 @@ def at(self, ts: Any) -> "TSDF": :return: a :class:`~tsdf.TSDF` object containing just the records at the given time """ - return self.__slice("==", ts) + return self.where(self.ts_index == ts) def before(self, ts: Any) -> "TSDF": """ @@ -442,7 +442,7 @@ def before(self, ts: Any) -> "TSDF": :return: a :class:`~tsdf.TSDF` object containing just the records before the given time """ - return self.__slice("<", ts) + return self.where(self.ts_index < ts) def atOrBefore(self, ts: Any) -> "TSDF": """ @@ -452,7 +452,7 @@ def atOrBefore(self, ts: Any) -> "TSDF": :return: a :class:`~tsdf.TSDF` object containing just the records at or before the given time """ - return self.__slice("<=", ts) + return self.where(self.ts_index <= ts) def after(self, ts: Any) -> "TSDF": """ @@ -462,7 +462,7 @@ def after(self, ts: Any) -> "TSDF": :return: a :class:`~tsdf.TSDF` object containing just the records after the given time """ - return self.__slice(">", ts) + return self.where(self.ts_index > ts) def atOrAfter(self, ts: Any) -> "TSDF": """ @@ -472,7 +472,7 @@ def atOrAfter(self, ts: Any) -> "TSDF": :return: a :class:`~tsdf.TSDF` object containing just the records at or after the given time """ - return self.__slice(">=", ts) + return self.where(self.ts_index >= ts) def between( self, start_ts: Any, end_ts: Any, inclusive: bool = True diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 2f2fb505..48cb1129 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -1,7 +1,7 @@ -import warnings import re +import warnings from abc import ABC, abstractmethod -from typing import Any, Collection, List, Optional, Union +from typing import Collection, List, Optional, Union, Callable import pyspark.sql.functions as sfn from pyspark.sql import Column, Window, WindowSpec @@ -14,6 +14,7 @@ # Timestamp parsing helpers # +EPOCH_START_DATE = "1970-01-01" DEFAULT_TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss" __time_pattern_components = "hHkKmsS" @@ -44,6 +45,52 @@ def sub_seconds_precision_digits(ts_fmt: str) -> int: return len(match.group(1)) +def _col_or_lit(other) -> Column: + """ + Helper function for managing unknown argument types into + Column expressions + + :param other: the argument to convert to a Column expression + + :return: a Column expression + """ + if isinstance(other, (list, tuple)): + if len(other) != 1: + raise ValueError( + "Cannot compare a TSIndex with a list or tuple " + f"of length {len(other)}: {other}" + ) + return _col_or_lit(other[0]) + if isinstance(other, Column): + return other + else: + return sfn.lit(other) + + +def _reverse_or_not( + expr: Union[Column, List[Column]], reverse: bool +) -> Union[Column, List[Column]]: + """ + Helper function for reversing the ordering of an expression, if necessary + + :param expr: the expression to reverse + :param reverse: whether to reverse the expression + + :return: the expression, reversed if necessary + """ + if not reverse: + return expr # just return the expression as-is if we're not reversing + elif isinstance(expr, Column): + return expr.desc() # reverse a single-expression + elif isinstance(expr, list): + return [col.desc() for col in expr] # reverse all columns in the expression + else: + raise TypeError( + "Type for expr argument must be either Column or " + f"List[Column], instead received: {type(expr)}" + ) + + # # Abstract Timeseries Index Classes # @@ -54,30 +101,18 @@ class TSIndex(ABC): Abstract base class for all Timeseries Index types """ - def __eq__(self, o: object) -> bool: - # must be a SimpleTSIndex - if not isinstance(o, TSIndex): - return False - return self._indexAttributes == o._indexAttributes - - def __repr__(self) -> str: - return self.__str__() - - def __str__(self) -> str: - return f"""{self.__class__.__name__}({self._indexAttributes})""" - @property @abstractmethod - def _indexAttributes(self) -> dict[str, Any]: + def colname(self) -> str: """ - :return: key attributes of this index + :return: the column name of the timeseries index """ @property @abstractmethod - def colname(self) -> str: + def dataType(self) -> DataType: """ - :return: the column name of the timeseries index + :return: the data type of the timeseries index """ @property @@ -111,20 +146,36 @@ def renamed(self, new_name: str) -> "TSIndex": :return: a copy of this :class:`TSIndex` object with the new name """ - def _reverseOrNot( - self, expr: Union[Column, List[Column]], reverse: bool - ) -> Union[Column, List[Column]]: - if not reverse: - return expr # just return the expression as-is if we're not reversing - elif isinstance(expr, Column): - return expr.desc() # reverse a single-expression - elif isinstance(expr, list): - return [col.desc() for col in expr] # reverse all columns in the expression - else: - raise TypeError( - f"Type for expr argument must be either Column or " - f"List[Column], instead received: {type(expr)}" - ) + # comparators + # Generate column expressions that compare the index + # with other columns, expressions or values + + @abstractmethod + def comparableExpr(self) -> Union[Column, List[Column]]: + """ + :return: an expression that can be used to compare an index with + other columns, expressions or values + """ + + def __eq__(self, other) -> Column: + return self.comparableExpr().eq(_col_or_lit(other)) + + def __ne__(self, other) -> Column: + return self.comparableExpr().neq(_col_or_lit(other)) + + def __lt__(self, other) -> Column: + return self.comparableExpr().lt(_col_or_lit(other)) + + def __le__(self, other) -> Column: + return self.comparableExpr().leq(_col_or_lit(other)) + + def __gt__(self, other) -> Column: + return self.comparableExpr().gt(_col_or_lit(other)) + + def __ge__(self, other) -> Column: + return self.comparableExpr().geq(_col_or_lit(other)) + + # other expression builder methods @abstractmethod def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: @@ -159,16 +210,22 @@ class SimpleTSIndex(TSIndex, ABC): def __init__(self, ts_col: StructField) -> None: self.__name = ts_col.name - self.dataType = ts_col.dataType + self.__dataType = ts_col.dataType - @property - def _indexAttributes(self) -> dict[str, Any]: - return {"name": self.colname, "dataType": self.dataType, "unit": self.unit} + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(name={self.colname}, " + f"type={self.dataType}, unit={self.unit})" + ) @property def colname(self): return self.__name + @property + def dataType(self) -> DataType: + return self.__dataType + def validate(self, df_schema: StructType) -> None: # the ts column must exist assert ( @@ -177,17 +234,20 @@ def validate(self, df_schema: StructType) -> None: schema_ts_col = df_schema[self.colname] # it must have the right type schema_ts_type = schema_ts_col.dataType - assert isinstance( - schema_ts_type, type(self.dataType) - ), f"The TSIndex column is of type {schema_ts_type}, but the expected type is {self.dataType}" + assert isinstance(schema_ts_type, type(self.dataType)), ( + f"The TSIndex column is of type {schema_ts_type}, but the expected type is" + f" {self.dataType}" + ) def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name return self + def comparableExpr(self) -> Column: + return sfn.col(self.colname) + def orderByExpr(self, reverse: bool = False) -> Column: - expr = sfn.col(self.colname) - return self._reverseOrNot(expr, reverse) + return _reverse_or_not(self.comparableExpr(), reverse) @classmethod def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": @@ -200,7 +260,8 @@ def fromTSCol(cls, ts_col: StructField) -> "SimpleTSIndex": return SimpleDateIndex(ts_col) else: raise TypeError( - f"A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column {ts_col.name} is of type {ts_col.dataType}" + "A SimpleTSIndex must be a Numeric, Timestamp or Date type, but column" + f" {ts_col.name} is of type {ts_col.dataType}" ) @@ -230,7 +291,9 @@ def unit(self) -> Optional[TimeUnit]: return None def rangeExpr(self, reverse: bool = False) -> Column: - raise TypeError("Cannot perform range operations on an OrdinalTSIndex") + raise NotImplementedError( + "Cannot perform range operations on an OrdinalTSIndex" + ) class SimpleTimestampIndex(SimpleTSIndex): @@ -241,7 +304,7 @@ class SimpleTimestampIndex(SimpleTSIndex): def __init__(self, ts_col: StructField) -> None: if not isinstance(ts_col.dataType, TimestampType): raise TypeError( - f"SimpleTimestampIndex must be of TimestampType, " + "SimpleTimestampIndex must be of TimestampType, " f"but given ts_col {ts_col.name} has type {ts_col.dataType}" ) super().__init__(ts_col) @@ -252,8 +315,8 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.colname).cast("double") - return self._reverseOrNot(expr, reverse) + expr = self.comparableExpr().cast("double") + return _reverse_or_not(expr, reverse) class SimpleDateIndex(SimpleTSIndex): @@ -264,7 +327,7 @@ class SimpleDateIndex(SimpleTSIndex): def __init__(self, ts_col: StructField) -> None: if not isinstance(ts_col.dataType, DateType): raise TypeError( - f"DateIndex must be of DateType, " + "DateIndex must be of DateType, " f"but given ts_col {ts_col.name} has type {ts_col.dataType}" ) super().__init__(ts_col) @@ -275,8 +338,10 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch - expr = sfn.datediff(sfn.col(self.colname), sfn.lit("1970-01-01").cast("date")) - return self._reverseOrNot(expr, reverse) + expr = sfn.datediff( + self.comparableExpr(), sfn.lit(EPOCH_START_DATE).cast("date") + ) + return _reverse_or_not(expr, reverse) # @@ -295,12 +360,12 @@ class CompositeTSIndex(TSIndex, ABC): def __init__(self, ts_struct: StructField, *component_fields: str) -> None: if not isinstance(ts_struct.dataType, StructType): raise TypeError( - f"CompoundTSIndex must be of type StructType, but given " + "CompoundTSIndex must be of type StructType, but given " f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" ) # validate the index fields assert len(component_fields) > 0, ( - f"A MultiFieldTSIndex must have at least 1 index component field, " + "A MultiFieldTSIndex must have at least 1 index component field, " f"but {len(component_fields)} were given" ) for ind_f in component_fields: @@ -312,19 +377,21 @@ def __init__(self, ts_struct: StructField, *component_fields: str) -> None: self.schema: StructType = ts_struct.dataType self.component_fields: List[str] = list(component_fields) - @property - def _indexAttributes(self) -> dict[str, Any]: - return { - "name": self.colname, - "schema": self.schema, - "unit": self.unit, - "component_fields": self.component_fields, - } + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(name={self.colname}, " + f"schema={self.schema}, unit={self.unit}, " + f"component_fields={self.component_fields})" + ) @property def colname(self) -> str: return self.__name + @property + def dataType(self) -> DataType: + return self.schema + @property def fieldNames(self) -> List[str]: return self.schema.fieldNames() @@ -333,19 +400,6 @@ def fieldNames(self) -> List[str]: def accessory_fields(self) -> List[str]: return list(set(self.fieldNames) - set(self.component_fields)) - def validate(self, df_schema: StructType) -> None: - # validate that the composite field exists - assert ( - self.colname in df_schema.fieldNames() - ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" - schema_ts_col = df_schema[self.colname] - # it must have the right type - schema_ts_type = schema_ts_col.dataType - assert schema_ts_type == self.schema, ( - f"The TSIndex column is of type {schema_ts_type}, " - f"but the expected type is {self.schema}" - ) - def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name return self @@ -361,10 +415,122 @@ def fieldPath(self, field: str) -> str: ), f"Field {field} does not exist in the TSIndex schema {self.schema}" return f"{self.colname}.{field}" + def validate(self, df_schema: StructType) -> None: + # validate that the composite field exists + assert ( + self.colname in df_schema.fieldNames() + ), f"The TSIndex column {self.colname} does not exist in the given DataFrame" + schema_ts_col = df_schema[self.colname] + # it must have the right type + schema_ts_type = schema_ts_col.dataType + assert schema_ts_type == self.schema, ( + f"The TSIndex column is of type {schema_ts_type}, " + f"but the expected type is {self.schema}" + ) + + # expression builder methods + + def comparableExpr(self) -> List[Column]: + return [sfn.col(self.fieldPath(comp)) for comp in self.component_fields] + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - # build an expression for each TS component, in order - exprs = [sfn.col(self.fieldPath(comp)) for comp in self.component_fields] - return self._reverseOrNot(exprs, reverse) + return _reverse_or_not(self.comparableExpr(), reverse) + + # comparators + + def _validate_other(self, other) -> None: + if len(other) != len(self.component_fields): + raise ValueError( + f"{self.__class__.__name__} has {len(self.component_fields)} " + "component fields, and requires this many arguments for comparison, " + f"but received {len(other)}" + ) + + def __eq__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__eq__([other]) + # validate the number of arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) + # build comparison expressions for each pair + comp_exprs = [c.eq(o) for (c, o) in comps] + # conjunction of all expressions (AND) + return sfn.expr(" AND ".join(comp_exprs)) + + def __ne__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__ne__([other]) + # validate the arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) + # build comparison expressions for each pair + comp_exprs = [c.neq(o) for (c, o) in comps] + # disjunction of all expressions (OR) + return sfn.expr(" OR ".join(comp_exprs)) + + def __lt__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__lt__([other]) + # validate the arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = list(zip(self.comparableExpr(), [_col_or_lit(o) for o in other])) + # do a leq for all but the last component + comp_exprs = [] + if len(comps) > 1: + comp_exprs = [c.leq(o) for (c, o) in comps[:-1]] + # strict lt for the last component + comp_exprs += [c.lt(o) for (c, o) in comps[-1:]] + # conjunction of all expressions (AND) + return sfn.expr(" AND ".join(comp_exprs)) + + def __le__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__le__([other]) + # validate the arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) + # build comparison expressions for each pair + comp_exprs = [c.leq(o) for (c, o) in comps] + # conjunction of all expressions (AND) + return sfn.expr(" AND ".join(comp_exprs)) + + def __gt__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__gt__([other]) + # validate the arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = list(zip(self.comparableExpr(), [_col_or_lit(o) for o in other])) + # do a geq for all but the last component + comp_exprs = [] + if len(comps) > 1: + comp_exprs = [c.geq(o) for (c, o) in comps[:-1]] + # strict gt for the last component + comp_exprs += [c.gt(o) for (c, o) in comps[-1:]] + # conjunction of all expressions (AND) + return sfn.expr(" AND ".join(comp_exprs)) + + def __ge__(self, other) -> Column: + # try to compare the whole index to a single value + if not isinstance(other, (tuple, list)): + return self.__ge__([other]) + # validate the arguments + self._validate_other(other) + # match each component field with its corresponding comparison value + comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) + # build comparison expressions for each pair + comp_exprs = [c.geq(o) for (c, o) in comps] + # conjunction of all expressions (AND) + return sfn.expr(" AND ".join(comp_exprs)) # @@ -372,135 +538,104 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # -# class ParsedTSIndex(CompositeTSIndex, ABC): -# """ -# Abstract base class for timeseries indices that are parsed from a string column. -# Retains the original string form as well as the parsed column. -# """ -# -# def __init__( -# self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str -# ) -> None: -# super().__init__(ts_struct, parsed_ts_col) -# # validate the source string column -# src_str_field = self.schema[src_str_col] -# if not isinstance(src_str_field.dataType, StringType): -# raise TypeError( -# f"Source string column must be of StringType, " -# f"but given column {src_str_field.name} " -# f"is of type {src_str_field.dataType}" -# ) -# self._src_str_col = src_str_col -# # validate the parsed column -# assert parsed_ts_col in self.schema.fieldNames(), ( -# f"The parsed timestamp index field {parsed_ts_col} does not exist in the " -# f"MultiPart TSIndex schema {self.schema}" -# ) -# self._parsed_ts_col = parsed_ts_col -# -# @property -# def src_str_col(self): -# return self.fieldPath(self._src_str_col) -# -# @property -# def parsed_ts_col(self): -# return self.fieldPath(self._parsed_ts_col) -# -# @property -# def ts_col(self) -> str: -# return self.parsed_ts_col -# -# @property -# def _indexAttributes(self) -> dict[str, Any]: -# attrs = super()._indexAttributes -# attrs["parsed_ts_col"] = self.parsed_ts_col -# attrs["src_str_col"] = self.src_str_col -# return attrs -# -# def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: -# expr = sfn.col(self.parsed_ts_col) -# return self._reverseOrNot(expr, reverse) -# -# @classmethod -# def fromParsedTimestamp( -# cls, -# ts_struct: StructField, -# parsed_ts_col: str, -# src_str_col: str, -# double_ts_col: Optional[str] = None, -# num_precision_digits: int = 6, -# ) -> "ParsedTSIndex": -# """ -# Create a ParsedTimestampIndex from a string column containing timestamps or dates -# -# :param ts_struct: The StructField for the TSIndex column -# :param parsed_ts_col: The name of the parsed timestamp column -# :param src_str_col: The name of the source string column -# :param double_ts_col: The name of the double-precision timestamp column -# :param num_precision_digits: The number of digits that make up the precision of -# -# :return: A ParsedTSIndex object -# """ -# -# # if a double timestamp column is given -# # then we are building a SubMicrosecondPrecisionTimestampIndex -# if double_ts_col is not None: -# return SubMicrosecondPrecisionTimestampIndex( -# ts_struct, -# double_ts_col, -# parsed_ts_col, -# src_str_col, -# num_precision_digits, -# ) -# # otherwise, we base it on the standard timestamp type -# # find the schema of the ts_struct column -# ts_schema = ts_struct.dataType -# if not isinstance(ts_schema, StructType): -# raise TypeError( -# f"A ParsedTSIndex must be of type StructType, but given " -# f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" -# ) -# # get the type of the parsed timestamp column -# parsed_ts_type = ts_schema[parsed_ts_col].dataType -# if isinstance(parsed_ts_type, TimestampType): -# return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) -# elif isinstance(parsed_ts_type, DateType): -# return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) -# else: -# raise TypeError( -# f"ParsedTimestampIndex must be of TimestampType or DateType, " -# f"but given ts_col {parsed_ts_col} " -# f"has type {parsed_ts_type}" -# ) - - -class ParsedTimestampIndex(CompositeTSIndex): +class ParsedTSIndex(CompositeTSIndex, ABC): """ - Timeseries index class for timestamps parsed from a string column + Abstract base class for timeseries indices that are parsed from a string column. + Retains the original string form as well as the parsed column. """ def __init__( self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str ) -> None: super().__init__(ts_struct, parsed_ts_col) - # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, TimestampType): - raise TypeError( - f"ParsedTimestampIndex requires an index field of TimestampType, " - f"but the given parsed_ts_col {parsed_ts_col} " - f"has type {parsed_ts_field.dataType}" - ) - self.parsed_ts_col = parsed_ts_col - # validate the source column as a string column + # validate the source string column src_str_field = self.schema[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( - f"ParsedTimestampIndex requires an source field of StringType, " - f"but the given src_str_col {src_str_col} " - f"has type {src_str_field.dataType}" + "Source string column must be of StringType, " + f"but given column {src_str_field.name} " + f"is of type {src_str_field.dataType}" ) - self.src_str_col = src_str_col + self._src_str_col = src_str_col + # validate the parsed column + assert parsed_ts_col in self.schema.fieldNames(), ( + f"The parsed timestamp index field {parsed_ts_col} does not exist in the " + f"MultiPart TSIndex schema {self.schema}" + ) + self._parsed_ts_col = parsed_ts_col + + @property + def src_str_col(self): + return self.fieldPath(self._src_str_col) + + @property + def parsed_ts_col(self): + return self.fieldPath(self._parsed_ts_col) + + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + expr = sfn.col(self.parsed_ts_col) + return _reverse_or_not(expr, reverse) + + def comparableExpr(self) -> Column: + return sfn.col(self.parsed_ts_col) + + @classmethod + def fromParsedTimestamp( + cls, + ts_struct: StructField, + parsed_ts_col: str, + src_str_col: str, + double_ts_col: Optional[str] = None, + num_precision_digits: int = 6, + ) -> "ParsedTSIndex": + """ + Create a ParsedTimestampIndex from a string column containing timestamps or dates + + :param ts_struct: The StructField for the TSIndex column + :param parsed_ts_col: The name of the parsed timestamp column + :param src_str_col: The name of the source string column + :param double_ts_col: The name of the double-precision timestamp column + :param num_precision_digits: The number of digits that make up the precision of + + :return: A ParsedTSIndex object + """ + + # if a double timestamp column is given + # then we are building a SubMicrosecondPrecisionTimestampIndex + if double_ts_col is not None: + return SubMicrosecondPrecisionTimestampIndex( + ts_struct, + double_ts_col, + parsed_ts_col, + src_str_col, + num_precision_digits, + ) + # otherwise, we base it on the standard timestamp type + # find the schema of the ts_struct column + ts_schema = ts_struct.dataType + if not isinstance(ts_schema, StructType): + raise TypeError( + "A ParsedTSIndex must be of type StructType, but given " + f"ts_struct {ts_struct.name} has type {ts_struct.dataType}" + ) + # get the type of the parsed timestamp column + parsed_ts_type = ts_schema[parsed_ts_col].dataType + if isinstance(parsed_ts_type, TimestampType): + return ParsedTimestampIndex(ts_struct, parsed_ts_col, src_str_col) + elif isinstance(parsed_ts_type, DateType): + return ParsedDateIndex(ts_struct, parsed_ts_col, src_str_col) + else: + raise TypeError( + "ParsedTimestampIndex must be of TimestampType or DateType, " + f"but given ts_col {parsed_ts_col} " + f"has type {parsed_ts_type}" + ) + + +class ParsedTimestampIndex(ParsedTSIndex): + """ + Timeseries index class for timestamps parsed from a string column + """ @property def unit(self) -> Optional[TimeUnit]: @@ -508,38 +643,15 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.fieldPath(self.parsed_ts_col)).cast("double") - return self._reverseOrNot(expr, reverse) + expr = sfn.col(self.parsed_ts_col).cast("double") + return _reverse_or_not(expr, reverse) -class ParsedDateIndex(CompositeTSIndex): +class ParsedDateIndex(ParsedTSIndex): """ Timeseries index class for dates parsed from a string column """ - def __init__( - self, ts_struct: StructField, parsed_date_col: str, src_str_col: str - ) -> None: - super().__init__(ts_struct, parsed_date_col) - # validate the parsed column as a date column - parsed_date_field = self.schema[parsed_date_col] - if not isinstance(parsed_date_field.dataType, DateType): - raise TypeError( - f"ParsedDateIndex requires an index field of DateType, " - f"but the given parsed_ts_col {parsed_date_col} " - f"has type {parsed_date_field.dataType}" - ) - self.parsed_date_col = parsed_date_col - # validate the source column as a string column - src_str_field = self.schema[src_str_col] - if not isinstance(src_str_field.dataType, StringType): - raise TypeError( - f"ParsedDateIndex requires an source field of StringType, " - f"but the given src_str_col {src_str_col} " - f"has type {src_str_field.dataType}" - ) - self.src_str_col = src_str_col - @property def unit(self) -> Optional[TimeUnit]: return StandardTimeUnits.DAYS @@ -547,10 +659,10 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch expr = sfn.datediff( - sfn.col(self.fieldPath(self.parsed_date_col)), - sfn.lit("1970-01-01").cast("date"), + sfn.col(self.parsed_ts_col), + sfn.lit(EPOCH_START_DATE).cast("date"), ) - return self._reverseOrNot(expr, reverse) + return _reverse_or_not(expr, reverse) class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): @@ -584,7 +696,7 @@ def __init__( double_ts_field = self.schema[double_ts_col] if not isinstance(double_ts_field.dataType, DoubleType): raise TypeError( - f"The double_ts_col must be of DoubleType, " + "The double_ts_col must be of DoubleType, " f"but the given double_ts_col {double_ts_col} " f"has type {double_ts_field.dataType}" ) @@ -592,10 +704,10 @@ def __init__( # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( - f"SubMicrosecondPrecisionTimestampIndex has a num_precision_digits " + "SubMicrosecondPrecisionTimestampIndex has a num_precision_digits " f"of {num_precision_digits} which is within the range of the " - f"standard timestamp precision of 6 digits (microseconds). " - f"Consider using a ParsedTimestampIndex instead." + "standard timestamp precision of 6 digits (microseconds). " + "Consider using a ParsedTimestampIndex instead." ) self.__unit = TimeUnit( f"custom_subsecond_unit (precision: {num_precision_digits})", @@ -606,7 +718,7 @@ def __init__( parsed_ts_field = self.schema[parsed_ts_col] if not isinstance(parsed_ts_field.dataType, TimestampType): raise TypeError( - f"parsed_ts_col field must be of TimestampType, " + "parsed_ts_col field must be of TimestampType, " f"but the given parsed_ts_col {parsed_ts_col} " f"has type {parsed_ts_field.dataType}" ) @@ -615,7 +727,7 @@ def __init__( src_str_field = self.schema[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( - f"src_str_col field must be of StringType, " + "src_str_col field must be of StringType, " f"but the given src_str_col {src_str_col} " f"has type {src_str_field.dataType}" ) @@ -625,14 +737,17 @@ def __init__( def unit(self) -> Optional[TimeUnit]: return self.__unit + def comparableExpr(self) -> Column: + return sfn.col(self.fieldPath(self.double_ts_col)) + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.fieldPath(self.double_ts_col)) - return self._reverseOrNot(expr, reverse) + return _reverse_or_not(self.comparableExpr(), reverse) - def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: + def rangeExpr(self, reverse: bool = False) -> Column: # just use the order by expression, since this is the same return self.orderByExpr(reverse) + # # Window Builder Interface # diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py new file mode 100644 index 00000000..e040d62a --- /dev/null +++ b/python/tests/tsschema_tests.py @@ -0,0 +1,235 @@ +import unittest +from parameterized import parameterized_class + +from pyspark.sql import Column +from pyspark.sql import functions as sfn +from pyspark.sql.types import ( + StructField, + StructType, + StringType, + TimestampType, + DoubleType, + IntegerType, + DateType, +) + +from tempo.tsschema import ( + TSIndex, + SimpleTimestampIndex, + OrdinalTSIndex, + SimpleDateIndex, + StandardTimeUnits, + ParsedTimestampIndex, + ParsedDateIndex +) +from tests.base import SparkTest + + +@parameterized_class( + ( + "name", + "ts_field", + "idx_class", + "ts_unit", + "expected_comp_expr", + "expected_range_expr", + ), + [ + ( + "simple_timestamp_index", + StructField("event_ts", TimestampType()), + SimpleTimestampIndex, + StandardTimeUnits.SECONDS, + "Column<'event_ts'>", + "Column<'CAST(event_ts AS DOUBLE)'>", + ), + ( + "ordinal_double_index", + StructField("event_ts_dbl", DoubleType()), + OrdinalTSIndex, + None, + "Column<'event_ts_dbl'>", + None, + ), + ( + "ordinal_int_index", + StructField("order", IntegerType()), + OrdinalTSIndex, + None, + "Column<'order'>", + None, + ), + ( + "simple_date_index", + StructField("date", DateType()), + SimpleDateIndex, + StandardTimeUnits.DAYS, + "Column<'date'>", + "Column<'datediff(date, CAST(1970-01-01 AS DATE))'>", + ), + ], +) +class SimpleTSIndexTests(SparkTest): + def test_constructor(self): + # create a timestamp index + ts_idx = self.idx_class(self.ts_field) + # must be a valid TSIndex object + self.assertIsNotNone(ts_idx) + self.assertIsInstance(ts_idx, self.idx_class) + # must have the correct field name and type + self.assertEqual(ts_idx.colname, self.ts_field.name) + self.assertEqual(ts_idx.dataType, self.ts_field.dataType) + # validate the unit + if self.ts_unit is None: + self.assertFalse(ts_idx.has_unit) + else: + self.assertTrue(ts_idx.has_unit) + self.assertEqual(ts_idx.unit, self.ts_unit) + + def test_comparable_expression(self): + # create a timestamp index + ts_idx: TSIndex = self.idx_class(self.ts_field) + # get the expressions + compbl_expr = ts_idx.comparableExpr() + # validate the expression + self.assertIsNotNone(compbl_expr) + self.assertIsInstance(compbl_expr, Column) + self.assertEqual(repr(compbl_expr), self.expected_comp_expr) + + def test_orderby_expression(self): + # create a timestamp index + ts_idx: TSIndex = self.idx_class(self.ts_field) + # get the expressions + orderby_expr = ts_idx.orderByExpr() + # validate the expression + self.assertIsNotNone(orderby_expr) + self.assertIsInstance(orderby_expr, Column) + self.assertEqual(repr(orderby_expr), self.expected_comp_expr) + + def test_range_expression(self): + # create a timestamp index + ts_idx = self.idx_class(self.ts_field) + # get the expressions + if isinstance(ts_idx, OrdinalTSIndex): + self.assertRaises(NotImplementedError, ts_idx.rangeExpr) + else: + range_expr = ts_idx.rangeExpr() + # validate the expression + self.assertIsNotNone(range_expr) + self.assertIsInstance(range_expr, Column) + self.assertEqual(repr(range_expr), self.expected_range_expr) + + +@parameterized_class( + ( + "name", + "ts_field", + "constr_args", + "idx_class", + "ts_unit", + "expected_comp_expr", + "expected_range_expr" + ), + [ + ( + "parsed_timestamp_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + {"parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, + ParsedTimestampIndex, + StandardTimeUnits.SECONDS, + "Column<'ts_idx.parsed_ts'>", + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + ), + ( + "parsed_date_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_date", DateType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + {"parsed_ts_col": "parsed_date", "src_str_col": "src_str"}, + ParsedDateIndex, + StandardTimeUnits.DAYS, + "Column<'ts_idx.parsed_date'>", + "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>" + ), + ( + "sub_ms_index", + StructField( + "ts_idx", + StructType([ + StructField("double_ts", TimestampType(), True), + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + {"double_ts_col": "double_ts", "parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, + ParsedTimestampIndex, + StandardTimeUnits.SECONDS, + "Column<'ts_idx.parsed_ts'>", + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + ), + ]) +class ParsedTSIndexTests(SparkTest): + def test_constructor(self): + # create a timestamp index + ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + # must be a valid TSIndex object + self.assertIsNotNone(ts_idx) + self.assertIsInstance(ts_idx, self.idx_class) + # must have the correct field name and type + self.assertEqual(ts_idx.colname, self.ts_field.name) + self.assertEqual(ts_idx.dataType, self.ts_field.dataType) + # validate the unit + self.assertTrue(ts_idx.has_unit) + self.assertEqual(ts_idx.unit, self.ts_unit) + + def test_comparable_expression(self): + # create a timestamp index + ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + # get the expressions + compbl_expr = ts_idx.comparableExpr() + # validate the expression + self.assertIsNotNone(compbl_expr) + self.assertIsInstance(compbl_expr, Column) + self.assertEqual(repr(compbl_expr), self.expected_comp_expr) + + def test_orderby_expression(self): + # create a timestamp index + ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + # get the expressions + orderby_expr = ts_idx.orderByExpr() + # validate the expression + self.assertIsNotNone(orderby_expr) + self.assertIsInstance(orderby_expr, Column) + self.assertEqual(repr(orderby_expr), self.expected_comp_expr) + + def test_range_expression(self): + # create a timestamp index + ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + # get the expressions + range_expr = ts_idx.rangeExpr() + # validate the expression + self.assertIsNotNone(range_expr) + self.assertIsInstance(range_expr, Column) + self.assertEqual(repr(range_expr), self.expected_range_expr) + + +# class TSSchemaTests(SparkTest): +# def test_simple_tsIndex(self): +# schema_str = "event_ts timestamp, symbol string, trade_pr double" +# schema = _parse_datatype_string(schema_str) +# ts_idx = TSSchema.fromDFSchema(schema, "event_ts", ["symbol"]) +# +# print(ts_idx) diff --git a/python/tests/unit_test_data/tsdf_tests.json b/python/tests/unit_test_data/tsdf_tests.json index 521de84d..9e54e63f 100644 --- a/python/tests/unit_test_data/tsdf_tests.json +++ b/python/tests/unit_test_data/tsdf_tests.json @@ -3,50 +3,16 @@ "temp_slice_init_data": { "schema": "symbol string, event_ts string, trade_pr float", "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], + "series_ids": ["symbol"], "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33 - ] + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33] ] } }, diff --git a/python/tests/unit_test_data/tsschema_tests.json b/python/tests/unit_test_data/tsschema_tests.json new file mode 100644 index 00000000..bc820168 --- /dev/null +++ b/python/tests/unit_test_data/tsschema_tests.json @@ -0,0 +1,26 @@ +{ + "__SharedData": { + "simple_ts_idx": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_col": "event_ts", + "series_ids": ["symbol"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33] + ] + } + }, + "TSSchemaTests": { + "test_simple_tsIndex": { + "simple_ts_idx": { + "$ref": "#/__SharedData/simple_ts_idx" + } + } + } +} \ No newline at end of file From 2127f3bbad52b5fc687653deb673d1bc7f399a24 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 15 Jan 2024 15:39:27 -0800 Subject: [PATCH 09/13] checkpoint save - more advanced schema testing --- python/tempo/tsschema.py | 142 ++++--- python/tests/base.py | 58 ++- python/tests/tsschema_tests.py | 361 +++++++++++++++--- .../tests/unit_test_data/tsschema_tests.json | 37 +- 4 files changed, 440 insertions(+), 158 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 48cb1129..165cefe7 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -545,39 +545,39 @@ class ParsedTSIndex(CompositeTSIndex, ABC): """ def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str + self, ts_struct: StructField, parsed_ts_field: str, src_str_field: str ) -> None: - super().__init__(ts_struct, parsed_ts_col) + super().__init__(ts_struct, parsed_ts_field) # validate the source string column - src_str_field = self.schema[src_str_col] - if not isinstance(src_str_field.dataType, StringType): + src_str_type = self.schema[src_str_field].dataType + if not isinstance(src_str_type, StringType): raise TypeError( "Source string column must be of StringType, " - f"but given column {src_str_field.name} " - f"is of type {src_str_field.dataType}" + f"but given column {src_str_field} " + f"is of type {src_str_type}" ) - self._src_str_col = src_str_col + self._src_str_field = src_str_field # validate the parsed column - assert parsed_ts_col in self.schema.fieldNames(), ( - f"The parsed timestamp index field {parsed_ts_col} does not exist in the " + assert parsed_ts_field in self.schema.fieldNames(), ( + f"The parsed timestamp index field {parsed_ts_field} does not exist in the " f"MultiPart TSIndex schema {self.schema}" ) - self._parsed_ts_col = parsed_ts_col + self._parsed_ts_field = parsed_ts_field @property - def src_str_col(self): - return self.fieldPath(self._src_str_col) + def src_str_field(self): + return self.fieldPath(self._src_str_field) @property - def parsed_ts_col(self): - return self.fieldPath(self._parsed_ts_col) + def parsed_ts_field(self): + return self.fieldPath(self._parsed_ts_field) def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.parsed_ts_col) + expr = sfn.col(self.parsed_ts_field) return _reverse_or_not(expr, reverse) def comparableExpr(self) -> Column: - return sfn.col(self.parsed_ts_col) + return sfn.col(self.parsed_ts_field) @classmethod def fromParsedTimestamp( @@ -643,7 +643,7 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.parsed_ts_col).cast("double") + expr = sfn.col(self.parsed_ts_field).cast("double") return _reverse_or_not(expr, reverse) @@ -659,7 +659,7 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch expr = sfn.datediff( - sfn.col(self.parsed_ts_col), + sfn.col(self.parsed_ts_field), sfn.lit(EPOCH_START_DATE).cast("date"), ) return _reverse_or_not(expr, reverse) @@ -676,31 +676,31 @@ class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): def __init__( self, ts_struct: StructField, - double_ts_col: str, - parsed_ts_col: str, - src_str_col: str, + double_ts_field: str, + secondary_parsed_ts_field: str, + src_str_field: str, num_precision_digits: int = 9, ) -> None: """ :param ts_struct: The StructField for the TSIndex column - :param double_ts_col: The name of the double-precision timestamp column - :param parsed_ts_col: The name of the parsed timestamp column - :param src_str_col: The name of the source string column + :param double_ts_field: The name of the double-precision timestamp column + :param secondary_parsed_ts_field: The name of the parsed timestamp column + :param src_str_field: The name of the source string column :param num_precision_digits: The number of digits that make up the precision of the timestamp. Ie. 9 for nanoseconds (default), 12 for picoseconds, etc. You will receive a warning if this value is 6 or less, as this is the precision of the standard timestamp type. """ - super().__init__(ts_struct, double_ts_col) + super().__init__(ts_struct, double_ts_field) # validate the double timestamp column - double_ts_field = self.schema[double_ts_col] - if not isinstance(double_ts_field.dataType, DoubleType): + double_ts_type = self.schema[double_ts_field].dataType + if not isinstance(double_ts_type, DoubleType): raise TypeError( "The double_ts_col must be of DoubleType, " - f"but the given double_ts_col {double_ts_col} " - f"has type {double_ts_field.dataType}" + f"but the given double_ts_col {double_ts_field} " + f"has type {double_ts_type}" ) - self.double_ts_col = double_ts_col + self.double_ts_field = double_ts_field # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( @@ -715,30 +715,30 @@ def __init__( num_precision_digits, ) # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, TimestampType): + parsed_ts_type = self.schema[secondary_parsed_ts_field].dataType + if not isinstance(parsed_ts_type, TimestampType): raise TypeError( "parsed_ts_col field must be of TimestampType, " - f"but the given parsed_ts_col {parsed_ts_col} " - f"has type {parsed_ts_field.dataType}" + f"but the given parsed_ts_col {secondary_parsed_ts_field} " + f"has type {parsed_ts_type}" ) - self.parsed_ts_col = parsed_ts_col + self.parsed_ts_field = secondary_parsed_ts_field # validate the source column as a string column - src_str_field = self.schema[src_str_col] + src_str_field = self.schema[src_str_field] if not isinstance(src_str_field.dataType, StringType): raise TypeError( "src_str_col field must be of StringType, " - f"but the given src_str_col {src_str_col} " + f"but the given src_str_col {src_str_field} " f"has type {src_str_field.dataType}" ) - self.src_str_col = src_str_col + self.src_str_col = src_str_field @property def unit(self) -> Optional[TimeUnit]: return self.__unit def comparableExpr(self) -> Column: - return sfn.col(self.fieldPath(self.double_ts_col)) + return sfn.col(self.fieldPath(self.double_ts_field)) def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: return _reverse_or_not(self.comparableExpr(), reverse) @@ -851,8 +851,10 @@ def __eq__(self, o: object) -> bool: # must be of TSSchema type if not isinstance(o, TSSchema): return False - # must have same TSIndex - if self.ts_idx != o.ts_idx: + # must have TSIndices with the same unit + if not ((self.ts_idx.has_unit == o.ts_idx.has_unit) + and + (self.ts_idx.unit == o.ts_idx.unit)): return False # must have the same series IDs if self.series_ids != o.series_ids: @@ -860,21 +862,57 @@ def __eq__(self, o: object) -> bool: return True def __repr__(self) -> str: - return self.__str__() - - def __str__(self) -> str: - return f"""TSSchema({id(self)}) - TSIndex: {self.ts_idx} - Series IDs: {self.series_ids}""" + return f"{self.__class__.__name__}(ts_idx={self.ts_idx}, series_ids={self.series_ids})" @classmethod - def fromDFSchema( - cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None - ) -> "TSSchema": + def fromDFSchema(cls, + df_schema: StructType, + ts_col: str, + series_ids: Optional[Collection[str]] = None) -> "TSSchema": # construct a TSIndex for the given ts_col ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col]) return cls(ts_idx, series_ids) + @classmethod + def fromParsedTSIndex(cls, + df_schema: StructType, + ts_col: str, + parsed_field: str, + src_str_field: str, + series_ids: Optional[Collection[str]] = None, + secondary_parsed_field: Optional[str] = None) -> "TSSchema": + ts_idx_schema = df_schema[ts_col].dataType + assert isinstance(ts_idx_schema, StructType), \ + f"Expected a StructType for ts_col {ts_col}, but got {ts_idx_schema}" + # construct the TSIndex + parsed_type = ts_idx_schema[parsed_field].dataType + if isinstance(parsed_type, DoubleType): + ts_idx = SubMicrosecondPrecisionTimestampIndex( + df_schema[ts_col], + parsed_field, + secondary_parsed_field, + src_str_field, + ) + elif isinstance(parsed_type, TimestampType): + ts_idx = ParsedTimestampIndex( + df_schema[ts_col], + parsed_field, + src_str_field, + ) + elif isinstance(parsed_type, DateType): + ts_idx = ParsedDateIndex( + df_schema[ts_col], + parsed_field, + src_str_field, + ) + else: + raise TypeError( + f"Expected a DoubleType, TimestampType or DateType for parsed_field {parsed_field}, " + f"but got {parsed_type}" + ) + # construct the TSSchema + return cls(ts_idx, series_ids) + @property def structural_columns(self) -> list[str]: """ @@ -898,7 +936,7 @@ def find_observational_columns(self, df_schema: StructType) -> list[str]: return list(set(df_schema.fieldNames()) - set(self.structural_columns)) @classmethod - def is_metric_col(cls, col: StructField) -> bool: + def __is_metric_col_type(cls, col: StructField) -> bool: return isinstance(col.dataType, NumericType) or isinstance( col.dataType, BooleanType ) @@ -907,7 +945,7 @@ def find_metric_columns(self, df_schema: StructType) -> list[str]: return [ col.name for col in df_schema.fields - if self.is_metric_col(col) + if self.__is_metric_col_type(col) and (col.name in self.find_observational_columns(df_schema)) ] diff --git a/python/tests/base.py b/python/tests/base.py index a0c37d33..e6aaac2c 100644 --- a/python/tests/base.py +++ b/python/tests/base.py @@ -58,10 +58,8 @@ def tearDownClass(cls) -> None: cls.spark.stop() def setUp(self) -> None: - self.test_data = self.__loadTestData(self.id()) - - def tearDown(self) -> None: - del self.test_data + if self.test_data is None: + self.test_data = self.__loadTestData(self.id()) # # Utility Functions @@ -73,16 +71,15 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True): if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])): ts_cols = [td["ts_col"]] if "ts_col" in td else [] ts_cols.extend(td.get("other_ts_cols", [])) - return self.buildTestDF(td["schema"], td["data"], ts_cols) + return self.buildTestDF(td["df"]) def get_data_as_tsdf(self, name: str, convert_ts_col=True): df = self.get_data_as_sdf(name, convert_ts_col) td = self.test_data[name] if "sequence_col" in td: - tsdf = TSDF.fromSubsequenceCol(df, - td["ts_col"], - td["sequence_col"], - td.get("series_ids", None)) + tsdf = TSDF.fromSubsequenceCol( + df, td["ts_col"], td["sequence_col"], td.get("series_ids", None) + ) else: tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None)) return tsdf @@ -112,7 +109,8 @@ def __getTestDataFilePath(self, test_file_name: str) -> str: dir_path = "./tests" elif cwd != "tests": raise RuntimeError( - f"Cannot locate test data file {test_file_name}, running from dir {os.getcwd()}" + f"Cannot locate test data file {test_file_name}, running from dir" + f" {os.getcwd()}" ) # return appropriate path @@ -140,35 +138,27 @@ def __loadTestData(self, test_case_path: str) -> dict: if class_name not in data_metadata_from_json: warnings.warn(f"Could not load test data for {file_name}.{class_name}") return {} - if func_name not in data_metadata_from_json[class_name]: - warnings.warn( - f"Could not load test data for {file_name}.{class_name}.{func_name}" - ) - return {} - return data_metadata_from_json[class_name][func_name] - - def buildTestDF(self, schema, data, ts_cols=["event_ts"]): + # if func_name not in data_metadata_from_json[class_name]: + # warnings.warn( + # f"Could not load test data for {file_name}.{class_name}.{func_name}" + # ) + # return {} + # return data_metadata_from_json[class_name][func_name] + return data_metadata_from_json[class_name] + + def buildTestDF(self, df_spec): """ Constructs a Spark Dataframe from the given components - :param schema: the schema to use for the Dataframe - :param data: values to use for the Dataframe - :param ts_cols: list of column names to be converted to Timestamp values + :param df_spec: a dictionary containing the following keys: schema, data, ts_convert :return: a Spark Dataframe, constructed from the given schema and values """ # build dataframe - df = self.spark.createDataFrame(data, schema) - - # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds - for tsc in ts_cols: - ts_value = str(df.select(ts_cols).limit(1).collect()[0][0]) - ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$" - decimal_pattern = r"[.]\d+" - if re.match(ts_pattern, str(ts_value)) is not None: - if ( - re.search(decimal_pattern, ts_value) is None - or len(re.search(decimal_pattern, ts_value)[0]) <= 4 - ): - df = df.withColumn(tsc, sfn.to_timestamp(sfn.col(tsc))) + df = self.spark.createDataFrame(df_spec['data'], df_spec['schema']) + + # convert timestamp columns + if 'ts_convert' in df_spec: + for ts_col in df_spec['ts_convert']: + df = df.withColumn(ts_col, sfn.to_timestamp(ts_col)) return df # diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py index e040d62a..0a5b0f4a 100644 --- a/python/tests/tsschema_tests.py +++ b/python/tests/tsschema_tests.py @@ -1,7 +1,8 @@ import unittest -from parameterized import parameterized_class +from abc import ABC, abstractmethod +from parameterized import parameterized, parameterized_class -from pyspark.sql import Column +from pyspark.sql import Column, WindowSpec from pyspark.sql import functions as sfn from pyspark.sql.types import ( StructField, @@ -11,6 +12,7 @@ DoubleType, IntegerType, DateType, + NumericType, ) from tempo.tsschema import ( @@ -20,16 +22,39 @@ SimpleDateIndex, StandardTimeUnits, ParsedTimestampIndex, - ParsedDateIndex + ParsedDateIndex, + SubMicrosecondPrecisionTimestampIndex, + TSSchema, ) from tests.base import SparkTest +class TSIndexTests(SparkTest, ABC): + @abstractmethod + def _create_index(self) -> TSIndex: + pass + + def _test_index(self, ts_idx: TSIndex): + # must be a valid TSIndex object + self.assertIsNotNone(ts_idx) + self.assertIsInstance(ts_idx, self.idx_class) + # must have the correct field name and type + self.assertEqual(ts_idx.colname, self.ts_field.name) + self.assertEqual(ts_idx.dataType, self.ts_field.dataType) + # validate the unit + if self.ts_unit is None: + self.assertFalse(ts_idx.has_unit) + else: + self.assertTrue(ts_idx.has_unit) + self.assertEqual(ts_idx.unit, self.ts_unit) + + @parameterized_class( ( "name", "ts_field", "idx_class", + "extra_constr_args", "ts_unit", "expected_comp_expr", "expected_range_expr", @@ -39,6 +64,7 @@ "simple_timestamp_index", StructField("event_ts", TimestampType()), SimpleTimestampIndex, + None, StandardTimeUnits.SECONDS, "Column<'event_ts'>", "Column<'CAST(event_ts AS DOUBLE)'>", @@ -48,6 +74,7 @@ StructField("event_ts_dbl", DoubleType()), OrdinalTSIndex, None, + None, "Column<'event_ts_dbl'>", None, ), @@ -56,6 +83,7 @@ StructField("order", IntegerType()), OrdinalTSIndex, None, + None, "Column<'order'>", None, ), @@ -63,32 +91,81 @@ "simple_date_index", StructField("date", DateType()), SimpleDateIndex, + None, StandardTimeUnits.DAYS, "Column<'date'>", "Column<'datediff(date, CAST(1970-01-01 AS DATE))'>", ), + ( + "parsed_timestamp_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + ParsedTimestampIndex, + {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, + StandardTimeUnits.SECONDS, + "Column<'ts_idx.parsed_ts'>", + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", + ), + ( + "parsed_date_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_date", DateType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + ParsedDateIndex, + {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, + StandardTimeUnits.DAYS, + "Column<'ts_idx.parsed_date'>", + "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", + ), + ( + "sub_ms_index", + StructField( + "ts_idx", + StructType([ + StructField("double_ts", DoubleType(), True), + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + SubMicrosecondPrecisionTimestampIndex, + { + "double_ts_field": "double_ts", + "secondary_parsed_ts_field": "parsed_ts", + "src_str_field": "src_str", + }, + StandardTimeUnits.NANOSECONDS, + "Column<'ts_idx.double_ts'>", + "Column<'ts_idx.double_ts'>", + ), ], ) -class SimpleTSIndexTests(SparkTest): - def test_constructor(self): +class SimpleTSIndexTests(TSIndexTests): + def _create_index(self) -> TSIndex: + if self.extra_constr_args: + return self.idx_class(self.ts_field, **self.extra_constr_args) + return self.idx_class(self.ts_field) + + def test_index(self): # create a timestamp index - ts_idx = self.idx_class(self.ts_field) - # must be a valid TSIndex object - self.assertIsNotNone(ts_idx) - self.assertIsInstance(ts_idx, self.idx_class) - # must have the correct field name and type - self.assertEqual(ts_idx.colname, self.ts_field.name) - self.assertEqual(ts_idx.dataType, self.ts_field.dataType) - # validate the unit - if self.ts_unit is None: - self.assertFalse(ts_idx.has_unit) - else: - self.assertTrue(ts_idx.has_unit) - self.assertEqual(ts_idx.unit, self.ts_unit) + ts_idx = self._create_index() + # test the index + self._test_index(ts_idx) def test_comparable_expression(self): # create a timestamp index - ts_idx: TSIndex = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions compbl_expr = ts_idx.comparableExpr() # validate the expression @@ -98,7 +175,7 @@ def test_comparable_expression(self): def test_orderby_expression(self): # create a timestamp index - ts_idx: TSIndex = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions orderby_expr = ts_idx.orderByExpr() # validate the expression @@ -108,7 +185,7 @@ def test_orderby_expression(self): def test_range_expression(self): # create a timestamp index - ts_idx = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions if isinstance(ts_idx, OrdinalTSIndex): self.assertRaises(NotImplementedError, ts_idx.rangeExpr) @@ -128,7 +205,7 @@ def test_range_expression(self): "idx_class", "ts_unit", "expected_comp_expr", - "expected_range_expr" + "expected_range_expr", ), [ ( @@ -141,11 +218,11 @@ def test_range_expression(self): ]), True, ), - {"parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, + {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, ParsedTimestampIndex, StandardTimeUnits.SECONDS, "Column<'ts_idx.parsed_ts'>", - "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", ), ( "parsed_date_index", @@ -157,47 +234,48 @@ def test_range_expression(self): ]), True, ), - {"parsed_ts_col": "parsed_date", "src_str_col": "src_str"}, + {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, ParsedDateIndex, StandardTimeUnits.DAYS, "Column<'ts_idx.parsed_date'>", - "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>" + "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", ), ( "sub_ms_index", StructField( "ts_idx", StructType([ - StructField("double_ts", TimestampType(), True), + StructField("double_ts", DoubleType(), True), StructField("parsed_ts", TimestampType(), True), StructField("src_str", StringType(), True), ]), True, ), - {"double_ts_col": "double_ts", "parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, - ParsedTimestampIndex, - StandardTimeUnits.SECONDS, - "Column<'ts_idx.parsed_ts'>", - "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + { + "double_ts_field": "double_ts", + "secondary_parsed_ts_field": "parsed_ts", + "src_str_field": "src_str", + }, + SubMicrosecondPrecisionTimestampIndex, + StandardTimeUnits.NANOSECONDS, + "Column<'ts_idx.double_ts'>", + "Column<'ts_idx.double_ts'>", ), - ]) -class ParsedTSIndexTests(SparkTest): - def test_constructor(self): + ], +) +class ParsedTSIndexTests(TSIndexTests): + def _create_index(self): + return self.idx_class(ts_struct=self.ts_field, **self.constr_args) + + def test_index(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) - # must be a valid TSIndex object - self.assertIsNotNone(ts_idx) - self.assertIsInstance(ts_idx, self.idx_class) - # must have the correct field name and type - self.assertEqual(ts_idx.colname, self.ts_field.name) - self.assertEqual(ts_idx.dataType, self.ts_field.dataType) - # validate the unit - self.assertTrue(ts_idx.has_unit) - self.assertEqual(ts_idx.unit, self.ts_unit) + ts_idx = self._create_index() + # test the index + self._test_index(ts_idx) def test_comparable_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions compbl_expr = ts_idx.comparableExpr() # validate the expression @@ -207,7 +285,7 @@ def test_comparable_expression(self): def test_orderby_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions orderby_expr = ts_idx.orderByExpr() # validate the expression @@ -217,7 +295,7 @@ def test_orderby_expression(self): def test_range_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions range_expr = ts_idx.rangeExpr() # validate the expression @@ -226,10 +304,183 @@ def test_range_expression(self): self.assertEqual(repr(range_expr), self.expected_range_expr) -# class TSSchemaTests(SparkTest): -# def test_simple_tsIndex(self): -# schema_str = "event_ts timestamp, symbol string, trade_pr double" -# schema = _parse_datatype_string(schema_str) -# ts_idx = TSSchema.fromDFSchema(schema, "event_ts", ["symbol"]) -# -# print(ts_idx) +class TSSchemaTests(TSIndexTests, ABC): + @abstractmethod + def _create_ts_schema(self) -> TSSchema: + pass + + +@parameterized_class( + ("name", "df_schema", "ts_col", "series_ids", "idx_class", "ts_unit"), + [ + ( + "simple_timestamp_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("event_ts", TimestampType(), True), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "event_ts", + ["symbol"], + SimpleTimestampIndex, + StandardTimeUnits.SECONDS, + ), + ( + "simple_ts_no_series", + StructType([ + StructField("event_ts", TimestampType(), True), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "event_ts", + [], + SimpleTimestampIndex, + StandardTimeUnits.SECONDS, + ), + ( + "ordinal_double_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("event_ts_dbl", DoubleType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "event_ts_dbl", + ["symbol"], + OrdinalTSIndex, + None, + ), + ( + "ordinal_int_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("order", IntegerType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "order", + ["symbol"], + OrdinalTSIndex, + None, + ), + ( + "simple_date_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("date", DateType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "date", + ["symbol"], + SimpleDateIndex, + StandardTimeUnits.DAYS, + ), + ], +) +class SimpleIndexTSSchemaTests(TSSchemaTests): + def _create_index(self) -> TSIndex: + pass + + def _create_ts_schema(self) -> TSSchema: + return TSSchema.fromDFSchema(self.df_schema, self.ts_col, self.series_ids) + + def setUp(self) -> None: + super().setUp() + self.ts_field = self.df_schema[self.ts_col] + + def test_schema(self): + print(self.id()) + # create a TSSchema + ts_schema = self._create_ts_schema() + # make sure it's a valid TSSchema instance + self.assertIsNotNone(ts_schema) + self.assertIsInstance(ts_schema, TSSchema) + # test the index + self._test_index(ts_schema.ts_idx) + # test the series ids + self.assertEqual(ts_schema.series_ids, self.series_ids) + # validate the index + ts_schema.validate(self.df_schema) + + def test_structural_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the structural columns + struct_cols = [self.ts_col] + self.series_ids + self.assertEqual(set(ts_schema.structural_columns), set(struct_cols)) + + def test_observational_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the structural columns + struct_cols = [self.ts_col] + self.series_ids + obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) + self.assertEqual( + ts_schema.find_observational_columns(self.df_schema), list(obs_cols) + ) + + def test_metric_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the metric columns + struct_cols = [self.ts_col] + self.series_ids + obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) + metric_cols = { + mc + for mc in obs_cols + if isinstance(self.df_schema[mc].dataType, NumericType) + } + self.assertEqual( + set(ts_schema.find_metric_columns(self.df_schema)), metric_cols + ) + + def test_base_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + bw = ts_schema.baseWindow() + self.assertIsNotNone(bw) + self.assertIsInstance(bw, WindowSpec) + # test it in reverse + bw_rev = ts_schema.baseWindow(reverse=True) + self.assertIsNotNone(bw_rev) + self.assertIsInstance(bw_rev, WindowSpec) + + def test_rows_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + rows_win = ts_schema.rowsBetweenWindow(0, 10) + self.assertIsNotNone(rows_win) + self.assertIsInstance(rows_win, WindowSpec) + # test it in reverse + rows_win_rev = ts_schema.rowsBetweenWindow(0, 10, reverse=True) + self.assertIsNotNone(rows_win_rev) + self.assertIsInstance(rows_win_rev, WindowSpec) + + def test_range_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + if ts_schema.ts_idx.has_unit: + range_win = ts_schema.rangeBetweenWindow(0, 10) + self.assertIsNotNone(range_win) + self.assertIsInstance(range_win, WindowSpec) + else: + self.assertRaises(NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10) + # test it in reverse + if ts_schema.ts_idx.has_unit: + range_win_rev = ts_schema.rangeBetweenWindow(0, 10, reverse=True) + self.assertIsNotNone(range_win_rev) + self.assertIsInstance(range_win_rev, WindowSpec) + else: + self.assertRaises( + NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10, reverse=True + ) + + +class ParsedIndexTSSchemaTests(TSSchemaTests): + def _create_index(self) -> TSIndex: + pass + + def _create_ts_schema(self) -> TSSchema: + pass diff --git a/python/tests/unit_test_data/tsschema_tests.json b/python/tests/unit_test_data/tsschema_tests.json index bc820168..3e78def3 100644 --- a/python/tests/unit_test_data/tsschema_tests.json +++ b/python/tests/unit_test_data/tsschema_tests.json @@ -1,26 +1,29 @@ { "__SharedData": { "simple_ts_idx": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": ["symbol"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-09-01 00:20:42", 762.33] - ] + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"] + }, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33] + ] + } } }, "TSSchemaTests": { - "test_simple_tsIndex": { - "simple_ts_idx": { - "$ref": "#/__SharedData/simple_ts_idx" - } + "simple_ts_idx": { + "$ref": "#/__SharedData/simple_ts_idx" } } } \ No newline at end of file From 4d4a8f3001e058bff1f4fd9f5108f8ee872fa7a7 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 15 Jan 2024 16:50:55 -0800 Subject: [PATCH 10/13] got index & schema test code completed! --- python/tempo/tsschema.py | 18 +- python/tests/tsschema_tests.py | 345 +++++++++++++++++---------------- 2 files changed, 189 insertions(+), 174 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 165cefe7..55cd10af 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -874,13 +874,13 @@ def fromDFSchema(cls, return cls(ts_idx, series_ids) @classmethod - def fromParsedTSIndex(cls, - df_schema: StructType, - ts_col: str, - parsed_field: str, - src_str_field: str, - series_ids: Optional[Collection[str]] = None, - secondary_parsed_field: Optional[str] = None) -> "TSSchema": + def fromParsedTimestamp(cls, + df_schema: StructType, + ts_col: str, + parsed_field: str, + src_str_field: str, + series_ids: Optional[Collection[str]] = None, + secondary_parsed_field: Optional[str] = None) -> "TSSchema": ts_idx_schema = df_schema[ts_col].dataType assert isinstance(ts_idx_schema, StructType), \ f"Expected a StructType for ts_col {ts_col}, but got {ts_idx_schema}" @@ -907,8 +907,8 @@ def fromParsedTSIndex(cls, ) else: raise TypeError( - f"Expected a DoubleType, TimestampType or DateType for parsed_field {parsed_field}, " - f"but got {parsed_type}" + f"Expected a DoubleType, TimestampType or DateType " + f"for parsed_field {parsed_field}, but got {parsed_type}" ) # construct the TSSchema return cls(ts_idx, series_ids) diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py index 0a5b0f4a..273fd993 100644 --- a/python/tests/tsschema_tests.py +++ b/python/tests/tsschema_tests.py @@ -29,11 +29,7 @@ from tests.base import SparkTest -class TSIndexTests(SparkTest, ABC): - @abstractmethod - def _create_index(self) -> TSIndex: - pass - +class TSIndexTester(unittest.TestCase, ABC): def _test_index(self, ts_idx: TSIndex): # must be a valid TSIndex object self.assertIsNotNone(ts_idx) @@ -151,7 +147,7 @@ def _test_index(self, ts_idx: TSIndex): ), ], ) -class SimpleTSIndexTests(TSIndexTests): +class TSIndexTests(SparkTest, TSIndexTester): def _create_index(self) -> TSIndex: if self.extra_constr_args: return self.idx_class(self.ts_field, **self.extra_constr_args) @@ -200,118 +196,17 @@ def test_range_expression(self): @parameterized_class( ( "name", - "ts_field", + "df_schema", + "constr_method", "constr_args", "idx_class", "ts_unit", - "expected_comp_expr", - "expected_range_expr", + "expected_ts_field", + "expected_series_ids", + "expected_structural_cols", + "expected_obs_cols", + "expected_metric_cols", ), - [ - ( - "parsed_timestamp_index", - StructField( - "ts_idx", - StructType([ - StructField("parsed_ts", TimestampType(), True), - StructField("src_str", StringType(), True), - ]), - True, - ), - {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, - ParsedTimestampIndex, - StandardTimeUnits.SECONDS, - "Column<'ts_idx.parsed_ts'>", - "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", - ), - ( - "parsed_date_index", - StructField( - "ts_idx", - StructType([ - StructField("parsed_date", DateType(), True), - StructField("src_str", StringType(), True), - ]), - True, - ), - {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, - ParsedDateIndex, - StandardTimeUnits.DAYS, - "Column<'ts_idx.parsed_date'>", - "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", - ), - ( - "sub_ms_index", - StructField( - "ts_idx", - StructType([ - StructField("double_ts", DoubleType(), True), - StructField("parsed_ts", TimestampType(), True), - StructField("src_str", StringType(), True), - ]), - True, - ), - { - "double_ts_field": "double_ts", - "secondary_parsed_ts_field": "parsed_ts", - "src_str_field": "src_str", - }, - SubMicrosecondPrecisionTimestampIndex, - StandardTimeUnits.NANOSECONDS, - "Column<'ts_idx.double_ts'>", - "Column<'ts_idx.double_ts'>", - ), - ], -) -class ParsedTSIndexTests(TSIndexTests): - def _create_index(self): - return self.idx_class(ts_struct=self.ts_field, **self.constr_args) - - def test_index(self): - # create a timestamp index - ts_idx = self._create_index() - # test the index - self._test_index(ts_idx) - - def test_comparable_expression(self): - # create a timestamp index - ts_idx = self._create_index() - # get the expressions - compbl_expr = ts_idx.comparableExpr() - # validate the expression - self.assertIsNotNone(compbl_expr) - self.assertIsInstance(compbl_expr, Column) - self.assertEqual(repr(compbl_expr), self.expected_comp_expr) - - def test_orderby_expression(self): - # create a timestamp index - ts_idx = self._create_index() - # get the expressions - orderby_expr = ts_idx.orderByExpr() - # validate the expression - self.assertIsNotNone(orderby_expr) - self.assertIsInstance(orderby_expr, Column) - self.assertEqual(repr(orderby_expr), self.expected_comp_expr) - - def test_range_expression(self): - # create a timestamp index - ts_idx = self._create_index() - # get the expressions - range_expr = ts_idx.rangeExpr() - # validate the expression - self.assertIsNotNone(range_expr) - self.assertIsInstance(range_expr, Column) - self.assertEqual(repr(range_expr), self.expected_range_expr) - - -class TSSchemaTests(TSIndexTests, ABC): - @abstractmethod - def _create_ts_schema(self) -> TSSchema: - pass - - -@parameterized_class( - ("name", "df_schema", "ts_col", "series_ids", "idx_class", "ts_unit"), [ ( "simple_timestamp_index", @@ -321,10 +216,15 @@ def _create_ts_schema(self) -> TSSchema: StructField("trade_pr", DoubleType(), True), StructField("trade_vol", IntegerType(), True), ]), - "event_ts", - ["symbol"], + "fromDFSchema", + {"ts_col": "event_ts", "series_ids": ["symbol"]}, SimpleTimestampIndex, StandardTimeUnits.SECONDS, + "event_ts", + ["symbol"], + ["event_ts", "symbol"], + ["trade_pr", "trade_vol"], + ["trade_pr", "trade_vol"], ), ( "simple_ts_no_series", @@ -333,10 +233,15 @@ def _create_ts_schema(self) -> TSSchema: StructField("trade_pr", DoubleType(), True), StructField("trade_vol", IntegerType(), True), ]), - "event_ts", - [], + "fromDFSchema", + {"ts_col": "event_ts", "series_ids": []}, SimpleTimestampIndex, StandardTimeUnits.SECONDS, + "event_ts", + [], + ["event_ts"], + ["trade_pr", "trade_vol"], + ["trade_pr", "trade_vol"], ), ( "ordinal_double_index", @@ -345,10 +250,15 @@ def _create_ts_schema(self) -> TSSchema: StructField("event_ts_dbl", DoubleType(), True), StructField("trade_pr", DoubleType(), True), ]), - "event_ts_dbl", - ["symbol"], + "fromDFSchema", + {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, OrdinalTSIndex, None, + "event_ts_dbl", + ["symbol"], + ["event_ts_dbl", "symbol"], + ["trade_pr"], + ["trade_pr"], ), ( "ordinal_int_index", @@ -357,10 +267,15 @@ def _create_ts_schema(self) -> TSSchema: StructField("order", IntegerType(), True), StructField("trade_pr", DoubleType(), True), ]), - "order", - ["symbol"], + "fromDFSchema", + {"ts_col": "order", "series_ids": ["symbol"]}, OrdinalTSIndex, None, + "order", + ["symbol"], + ["order", "symbol"], + ["trade_pr"], + ["trade_pr"], ), ( "simple_date_index", @@ -369,68 +284,176 @@ def _create_ts_schema(self) -> TSSchema: StructField("date", DateType(), True), StructField("trade_pr", DoubleType(), True), ]), + "fromDFSchema", + {"ts_col": "date", "series_ids": ["symbol"]}, + SimpleDateIndex, + StandardTimeUnits.DAYS, "date", ["symbol"], - SimpleDateIndex, + ["date", "symbol"], + ["trade_pr"], + ["trade_pr"], + ), + ( + "parsed_timestamp_index", + StructType([ + StructField("symbol", StringType(), True), + StructField( + "ts_idx", + StructType([ + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "fromParsedTimestamp", + { + "ts_col": "ts_idx", + "parsed_field": "parsed_ts", + "src_str_field": "src_str", + "series_ids": ["symbol"], + }, + ParsedTimestampIndex, + StandardTimeUnits.SECONDS, + "ts_idx", + ["symbol"], + ["ts_idx", "symbol"], + ["trade_pr", "trade_vol"], + ["trade_pr", "trade_vol"], + ), + ( + "parsed_ts_no_series", + StructType([ + StructField( + "ts_idx", + StructType([ + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "fromParsedTimestamp", + { + "ts_col": "ts_idx", + "parsed_field": "parsed_ts", + "src_str_field": "src_str", + }, + ParsedTimestampIndex, + StandardTimeUnits.SECONDS, + "ts_idx", + [], + ["ts_idx"], + ["trade_pr", "trade_vol"], + ["trade_pr", "trade_vol"], + ), + ( + "parsed_date_index", + StructType([ + StructField( + "ts_idx", + StructType([ + StructField("parsed_date", DateType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + StructField("symbol", StringType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "fromParsedTimestamp", + { + "ts_col": "ts_idx", + "parsed_field": "parsed_date", + "src_str_field": "src_str", + "series_ids": ["symbol"], + }, + ParsedDateIndex, StandardTimeUnits.DAYS, + "ts_idx", + ["symbol"], + ["ts_idx", "symbol"], + ["trade_pr"], + ["trade_pr"], + ), + ( + "sub_ms_index", + StructType([ + StructField( + "ts_idx", + StructType([ + StructField("double_ts", DoubleType(), True), + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + StructField("symbol", StringType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "fromParsedTimestamp", + { + "ts_col": "ts_idx", + "parsed_field": "double_ts", + "src_str_field": "src_str", + "secondary_parsed_field": "parsed_ts", + "series_ids": ["symbol"], + }, + SubMicrosecondPrecisionTimestampIndex, + StandardTimeUnits.NANOSECONDS, + "ts_idx", + ["symbol"], + ["ts_idx", "symbol"], + ["trade_pr"], + ["trade_pr"], ), ], ) -class SimpleIndexTSSchemaTests(TSSchemaTests): - def _create_index(self) -> TSIndex: - pass - +class TSSchemaTests(SparkTest, TSIndexTester): def _create_ts_schema(self) -> TSSchema: - return TSSchema.fromDFSchema(self.df_schema, self.ts_col, self.series_ids) + return getattr(TSSchema, self.constr_method)(self.df_schema, **self.constr_args) def setUp(self) -> None: super().setUp() + self.ts_col = self.constr_args["ts_col"] self.ts_field = self.df_schema[self.ts_col] + self.ts_schema = self._create_ts_schema() def test_schema(self): - print(self.id()) - # create a TSSchema - ts_schema = self._create_ts_schema() # make sure it's a valid TSSchema instance - self.assertIsNotNone(ts_schema) - self.assertIsInstance(ts_schema, TSSchema) + self.assertIsNotNone(self.ts_schema) + self.assertIsInstance(self.ts_schema, TSSchema) # test the index - self._test_index(ts_schema.ts_idx) - # test the series ids - self.assertEqual(ts_schema.series_ids, self.series_ids) + self._test_index(self.ts_schema.ts_idx) # validate the index - ts_schema.validate(self.df_schema) + self.ts_schema.validate(self.df_schema) + + def test_series_ids(self): + # test the series ids + self.assertEqual(self.ts_schema.series_ids, self.expected_series_ids) def test_structural_cols(self): - # create a TSSchema - ts_schema = self._create_ts_schema() # test the structural columns - struct_cols = [self.ts_col] + self.series_ids - self.assertEqual(set(ts_schema.structural_columns), set(struct_cols)) + self.assertEqual(set(self.ts_schema.structural_columns), + set(self.expected_structural_cols)) def test_observational_cols(self): - # create a TSSchema - ts_schema = self._create_ts_schema() - # test the structural columns - struct_cols = [self.ts_col] + self.series_ids - obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) + # test the observational columns self.assertEqual( - ts_schema.find_observational_columns(self.df_schema), list(obs_cols) + set(self.ts_schema.find_observational_columns(self.df_schema)), + set(self.expected_obs_cols) ) def test_metric_cols(self): - # create a TSSchema - ts_schema = self._create_ts_schema() # test the metric columns - struct_cols = [self.ts_col] + self.series_ids - obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) - metric_cols = { - mc - for mc in obs_cols - if isinstance(self.df_schema[mc].dataType, NumericType) - } self.assertEqual( - set(ts_schema.find_metric_columns(self.df_schema)), metric_cols + set(self.ts_schema.find_metric_columns(self.df_schema)), + set(self.expected_metric_cols) ) def test_base_window(self): @@ -476,11 +499,3 @@ def test_range_window(self): self.assertRaises( NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10, reverse=True ) - - -class ParsedIndexTSSchemaTests(TSSchemaTests): - def _create_index(self) -> TSIndex: - pass - - def _create_ts_schema(self) -> TSSchema: - pass From 8898741d7e5a1d55e2707761cdb47e6f4d6de10c Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 17 Jan 2024 10:57:36 -0800 Subject: [PATCH 11/13] big changes to test code framework new TSDF Tests --- python/tempo/tsdf.py | 14 +- python/tempo/tsschema.py | 12 +- python/tests/base.py | 202 +- python/tests/interpol_tests.py | 2 +- python/tests/intervals_tests.py | 2 +- python/tests/tsdf_tests.py | 1575 +++------ python/tests/unit_test_data/tsdf_tests.json | 3129 +---------------- .../tests/unit_test_data/tsschema_tests.json | 29 - python/tests/utils_tests.py | 2 +- 9 files changed, 659 insertions(+), 4308 deletions(-) delete mode 100644 python/tests/unit_test_data/tsschema_tests.json diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 345f7c35..945d9f07 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -49,14 +49,12 @@ def __init__(self, self.ts_schema.validate(df.schema) def __repr__(self) -> str: - return self.__str__() - - def __str__(self) -> str: - return f"""TSDF({id(self)}): - TS Index: {self.ts_index} - Series IDs: {self.series_ids} - Observational Cols: {self.observational_cols} - DataFrame: {self.df.schema}""" + return f"{self.__class__.__name__}(df={self.df}, ts_schema={self.ts_schema})" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, TSDF): + return False + return self.ts_schema == other.ts_schema and self.df == other.df def __withTransformedDF(self, new_df: DataFrame) -> "TSDF": """ diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 55cd10af..a0d03b02 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -158,22 +158,22 @@ def comparableExpr(self) -> Union[Column, List[Column]]: """ def __eq__(self, other) -> Column: - return self.comparableExpr().eq(_col_or_lit(other)) + return self.comparableExpr() == _col_or_lit(other) def __ne__(self, other) -> Column: - return self.comparableExpr().neq(_col_or_lit(other)) + return self.comparableExpr() != _col_or_lit(other) def __lt__(self, other) -> Column: - return self.comparableExpr().lt(_col_or_lit(other)) + return self.comparableExpr() < _col_or_lit(other) def __le__(self, other) -> Column: - return self.comparableExpr().leq(_col_or_lit(other)) + return self.comparableExpr() <= _col_or_lit(other) def __gt__(self, other) -> Column: - return self.comparableExpr().gt(_col_or_lit(other)) + return self.comparableExpr() > _col_or_lit(other) def __ge__(self, other) -> Column: - return self.comparableExpr().geq(_col_or_lit(other)) + return self.comparableExpr() >= _col_or_lit(other) # other expression builder methods diff --git a/python/tests/base.py b/python/tests/base.py index e6aaac2c..a450bf59 100644 --- a/python/tests/base.py +++ b/python/tests/base.py @@ -2,7 +2,7 @@ import re import unittest import warnings -from typing import Union +from typing import Union, Optional import jsonref from chispa import assert_df_equality @@ -15,6 +15,82 @@ from tempo.tsdf import TSDF +class TestDataFrame: + """ + A class to hold metadata about a Spark DataFrame + """ + + def __init__(self, spark: SparkSession, test_data: dict): + """ + :param spark: the SparkSession to use + :param test_data: a dictionary containing the test data & metadata + """ + self.spark = spark + self.__test_data = test_data + + @property + def df(self): + """ + :return: the DataFrame component of the test data + """ + return self.__test_data["df"] + + @property + def schema(self): + """ + :return: the schema component of the test data + """ + return self.df["schema"] + + def data(self): + """ + :return: the data component of the test data + """ + return self.df["data"] + + @property + def ts_idx(self): + """ + :return: the timestamp index metadata component of the test data + """ + return self.__test_data["ts_idx"] + + @property + def tsdf_constructor(self) -> Optional[str]: + """ + :return: the name of the TSDF constructor to use + """ + return self.__test_data.get("tsdf_constructor", None) + + def as_sdf(self) -> DataFrame: + """ + Constructs a Spark Dataframe from the test data + """ + # build dataframe + df = self.spark.createDataFrame(self.data(), self.schema) + + # convert timestamp columns + if "ts_convert" in self.df: + for ts_col in self.df["ts_convert"]: + df = df.withColumn(ts_col, sfn.to_timestamp(ts_col)) + # convert date columns + if "date_convert" in self.df: + for date_col in self.df["date_convert"]: + df = df.withColumn(date_col, sfn.to_date(date_col)) + + return df + + def as_tsdf(self) -> TSDF: + """ + Constructs a TSDF from the test data + """ + sdf = self.as_sdf() + if self.tsdf_constructor is not None: + return getattr(TSDF, self.tsdf_constructor)(sdf, **self.ts_idx) + else: + return TSDF(sdf, **self.ts_idx) + + class SparkTest(unittest.TestCase): # # Fixtures @@ -62,38 +138,40 @@ def setUp(self) -> None: self.test_data = self.__loadTestData(self.id()) # - # Utility Functions + # Test Data Loading Functions # - def get_data_as_sdf(self, name: str, convert_ts_col=True): - td = self.test_data[name] - ts_cols = [] - if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])): - ts_cols = [td["ts_col"]] if "ts_col" in td else [] - ts_cols.extend(td.get("other_ts_cols", [])) - return self.buildTestDF(td["df"]) - - def get_data_as_tsdf(self, name: str, convert_ts_col=True): - df = self.get_data_as_sdf(name, convert_ts_col) - td = self.test_data[name] - if "sequence_col" in td: - tsdf = TSDF.fromSubsequenceCol( - df, td["ts_col"], td["sequence_col"], td.get("series_ids", None) - ) - else: - tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None)) - return tsdf - - def get_data_as_idf(self, name: str, convert_ts_col=True): - df = self.get_data_as_sdf(name, convert_ts_col) - td = self.test_data[name] - idf = IntervalsDF( - df, - start_ts=td["start_ts"], - end_ts=td["end_ts"], - series_ids=td.get("series", None), - ) - return idf + # def get_data_as_sdf(self, name: str) -> DataFrame: + # td = self.test_data[name] + # return self.buildTestDF(td["df"]) + + # def get_ts_idx_metadata(self, name: str) -> dict: + # td = self.test_data[name] + # return td["ts_idx"] + # + # def get_tsdf_constructor_fn(self, name: str) -> str: + # td = self.test_data[name] + # return td.get("constructor", None) + # + # def get_data_as_tsdf(self, name: str) -> TSDF: + # sdf = self.get_data_as_sdf(name) + # ts_idx_meta = self.get_ts_idx_metadata(name) + # tsdf_constructor = self.get_tsdf_constructor_fn(name) + # if tsdf_constructor is not None: + # return getattr(TSDF, tsdf_constructor)(sdf, **ts_idx_meta) + # else: + # return TSDF(sdf, **ts_idx_meta) + # + # def get_data_as_idf(self, name: str, convert_ts_col=True): + # df = self.get_data_as_sdf(name, convert_ts_col) + # td = self.test_data[name] + # idf = IntervalsDF( + # df, + # start_ts=td["start_ts"], + # end_ts=td["end_ts"], + # series_ids=td.get("series", None), + # ) + # return idf TEST_DATA_FOLDER = "unit_test_data" @@ -134,32 +212,26 @@ def __loadTestData(self, test_case_path: str) -> dict: # proces the data file with open(test_data_file, "r") as f: data_metadata_from_json = jsonref.load(f) - # warn if data not present - if class_name not in data_metadata_from_json: - warnings.warn(f"Could not load test data for {file_name}.{class_name}") - return {} - # if func_name not in data_metadata_from_json[class_name]: - # warnings.warn( - # f"Could not load test data for {file_name}.{class_name}.{func_name}" - # ) - # return {} - # return data_metadata_from_json[class_name][func_name] - return data_metadata_from_json[class_name] - - def buildTestDF(self, df_spec): - """ - Constructs a Spark Dataframe from the given components - :param df_spec: a dictionary containing the following keys: schema, data, ts_convert - :return: a Spark Dataframe, constructed from the given schema and values - """ - # build dataframe - df = self.spark.createDataFrame(df_spec['data'], df_spec['schema']) + # return the data + return data_metadata_from_json - # convert timestamp columns - if 'ts_convert' in df_spec: - for ts_col in df_spec['ts_convert']: - df = df.withColumn(ts_col, sfn.to_timestamp(ts_col)) - return df + def get_test_data(self, name: str) -> TestDataFrame: + return TestDataFrame(self.spark, self.test_data[name]) + + # def buildTestDF(self, df_spec) -> DataFrame: + # """ + # Constructs a Spark Dataframe from the given components + # :param df_spec: a dictionary containing the following keys: schema, data, ts_convert + # :return: a Spark Dataframe, constructed from the given schema and values + # """ + # # build dataframe + # df = self.spark.createDataFrame(df_spec['data'], df_spec['schema']) + # + # # convert timestamp columns + # if 'ts_convert' in df_spec: + # for ts_col in df_spec['ts_convert']: + # df = df.withColumn(ts_col, sfn.to_timestamp(ts_col)) + # return df # # Assertion Functions @@ -191,12 +263,11 @@ def assertSchemaContainsField(self, schema, field): # the attributes of the fields must be equal self.assertFieldsEqual(field, schema[field.name]) - @staticmethod + def assertDataFrameEquality( - df1: Union[IntervalsDF, TSDF, DataFrame], - df2: Union[IntervalsDF, TSDF, DataFrame], - from_tsdf: bool = False, - from_idf: bool = False, + self, + df1: Union[TSDF, DataFrame], + df2: Union[TSDF, DataFrame], ignore_row_order: bool = False, ignore_column_order: bool = True, ignore_nullable: bool = True, @@ -206,10 +277,17 @@ def assertDataFrameEquality( That is, they have equivalent schemas, and both contain the same values """ - if from_tsdf or from_idf: + # handle TSDFs + if isinstance(df1, TSDF): + # df2 must also be a TSDF + self.assertIsInstance(df2, TSDF) + # should have the same schemas + self.assertEqual(df1.ts_schema, df2.ts_schema) + # get the underlying Spark DataFrames df1 = df1.df df2 = df2.df + # handle DataFrames assert_df_equality( df1, df2, diff --git a/python/tests/interpol_tests.py b/python/tests/interpol_tests.py index 0235a011..bc3465ac 100644 --- a/python/tests/interpol_tests.py +++ b/python/tests/interpol_tests.py @@ -4,7 +4,7 @@ from tempo.interpol import Interpolation from tempo.tsdf import TSDF -from tests.tsdf_tests import SparkTest +from tests.base import SparkTest class InterpolationUnitTest(SparkTest): diff --git a/python/tests/intervals_tests.py b/python/tests/intervals_tests.py index e39ef4ee..dbbdce20 100644 --- a/python/tests/intervals_tests.py +++ b/python/tests/intervals_tests.py @@ -3,7 +3,7 @@ from pyspark.sql.utils import AnalysisException from tempo.intervals import IntervalsDF -from tests.tsdf_tests import SparkTest +from tests.base import SparkTest class IntervalsDFTests(SparkTest): diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index 10ce8948..12d23c02 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -1,1128 +1,451 @@ -import os -import sys -import unittest -from io import StringIO -from unittest import mock -from unittest.mock import patch - -from dateutil import parser as dt_parser - -import pyspark.sql.functions as sfn -from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame -from pyspark.sql.window import WindowSpec +from parameterized import parameterized from tempo.tsdf import TSDF -from tests.base import SparkTest - - -class TSDFBaseTests(SparkTest): - def test_TSDF_init(self): - tsdf_init = self.get_data_as_tsdf("init") - - self.assertIsInstance(tsdf_init.df, DataFrame) - self.assertEqual(tsdf_init.ts_col, "event_ts") - self.assertEqual(tsdf_init.series_ids, ["symbol"]) - - def test_describe(self): - """AS-OF Join without a time-partition test""" - - # Construct dataframes - tsdf_init = self.get_data_as_tsdf("init") - - # generate description dataframe - res = tsdf_init.describe() - - # joined dataframe should equal the expected dataframe - # self.assertDataFrameEquality(res, dfExpected) - assert res.count() == 7 - assert ( - res.filter(sfn.col("unique_time_series_count") != " ") - .select(sfn.max(sfn.col("unique_time_series_count"))) - .collect()[0][0] - == "1" - ) - assert ( - res.filter(sfn.col("min_ts") != " ") - .select(sfn.col("min_ts").cast("string")) - .collect()[0][0] - == "2020-08-01 00:00:10" - ) - assert ( - res.filter(sfn.col("max_ts") != " ") - .select(sfn.col("max_ts").cast("string")) - .collect()[0][0] - == "2020-09-01 00:19:12" - ) - - def test__getSparkPlan(self): - init_tsdf = self.get_data_as_tsdf("init") - - plan = init_tsdf._TSDF__getSparkPlan(init_tsdf.df, self.spark) - - self.assertIsInstance(plan, str) - self.assertIn("Optimized Logical Plan", plan) - self.assertIn("Physical Plan", plan) - self.assertIn("sizeInBytes", plan) - - def test__getBytesFromPlan(self): - init_tsdf = self.get_data_as_tsdf("init") - - _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark) - - self.assertEqual(_bytes, 6.2) - - @patch("tempo.tsdf.TSDF._TSDF__getSparkPlan") - def test__getBytesFromPlan_search_result_is_None(self, mock__getSparkPlan): - mock__getSparkPlan.return_value = "will not match search value" - - init_tsdf = self.get_data_as_tsdf("init") - - self.assertRaises( - ValueError, - init_tsdf._TSDF__getBytesFromPlan, - init_tsdf.df, - self.spark, - ) - - @patch("tempo.tsdf.TSDF._TSDF__getSparkPlan") - def test__getBytesFromPlan_size_in_MiB(self, mock__getSparkPlan): - mock__getSparkPlan.return_value = "' Statistics(sizeInBytes=1.0 MiB) '" - - init_tsdf = self.get_data_as_tsdf("init") - - _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark) - expected = 1 * 1024 * 1024 - - self.assertEqual(_bytes, expected) - - @patch("tempo.tsdf.TSDF._TSDF__getSparkPlan") - def test__getBytesFromPlan_size_in_KiB(self, mock__getSparkPlan): - mock__getSparkPlan.return_value = "' Statistics(sizeInBytes=1.0 KiB) '" - - init_tsdf = self.get_data_as_tsdf("init") - - _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark) - - self.assertEqual(_bytes, 1 * 1024) - - @patch("tempo.tsdf.TSDF._TSDF__getSparkPlan") - def test__getBytesFromPlan_size_in_GiB(self, mock__getSparkPlan): - mock__getSparkPlan.return_value = "' Statistics(sizeInBytes=1.0 GiB) '" - - init_tsdf = self.get_data_as_tsdf("init") - - _bytes = init_tsdf._TSDF__getBytesFromPlan(init_tsdf.df, self.spark) - - self.assertEqual(_bytes, 1 * 1024 * 1024 * 1024) - - @staticmethod - @mock.patch.dict(os.environ, {"TZ": "UTC"}) - def __timestamp_to_double(ts: str) -> float: - return dt_parser.isoparse(ts).timestamp() - - @staticmethod - def __tsdf_with_double_tscol(tsdf: TSDF) -> TSDF: - return tsdf.withColumn(tsdf.ts_col, - sfn.col(tsdf.ts_col).cast("double")) - - # TODO: write equivalent test for a double ts column - # def test__add_double_ts(self): - # init_tsdf = self.get_data_as_tsdf("init") - # df = init_tsdf._TSDF__add_double_ts() - # - # schema_string = df.schema.simpleString() - # - # self.assertIn("double_ts:double", schema_string) - - # TODO: write equivalent test for TSDFs with string timestamps - # def test__validate_ts_string_valid(self): - # valid_timestamp_string = "2020-09-01 00:02:10" - # - # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string)) - # - # def test__validate_ts_string_alt_format_valid(self): - # valid_timestamp_string = "2020-09-01T00:02:10" - # - # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string)) - # - # def test__validate_ts_string_with_microseconds_valid(self): - # valid_timestamp_string = "2020-09-01 00:02:10.00000000" - # - # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string)) - # - # def test__validate_ts_string_alt_format_with_microseconds_valid(self): - # valid_timestamp_string = "2020-09-01T00:02:10.00000000" - # - # self.assertIsNone(TSDF._TSDF__validate_ts_string(valid_timestamp_string)) - # - # def test__validate_ts_string_invalid(self): - # invalid_timestamp_string = "this will not work" - # - # self.assertRaises( - # ValueError, TSDF._TSDF__validate_ts_string, invalid_timestamp_string - # ) - - # TODO: write equivalent test for testing TSDF initialization - # def test__validated_column_not_string(self): - # init_df = self.get_data_as_tsdf("init").df - # - # self.assertRaises(TypeError, TSDF._TSDF__validated_column, init_df, 0) - # - # def test__validated_column_not_found(self): - # init_df = self.get_data_as_tsdf("init").df - # - # self.assertRaises( - # ValueError, - # TSDF._TSDF__validated_column, - # init_df, - # "does not exist", - # ) - # - # def test__validated_column(self): - # init_df = self.get_data_as_tsdf("init").df - # - # self.assertEqual( - # TSDF._TSDF__validated_column(init_df, "symbol"), - # "symbol", - # ) - # - # def test__validated_columns_string(self): - # init_tsdf = self.get_data_as_tsdf("init") - # - # self.assertEqual( - # init_tsdf._TSDF__validated_columns(init_tsdf.df, "symbol"), - # ["symbol"], - # ) - # - # def test__validated_columns_none(self): - # init_tsdf = self.get_data_as_tsdf("init") - # - # self.assertEqual( - # init_tsdf._TSDF__validated_columns(init_tsdf.df, None), - # [], - # ) - # - # def test__validated_columns_tuple(self): - # init_tsdf = self.get_data_as_tsdf("init") - # - # self.assertRaises( - # TypeError, - # init_tsdf._TSDF__validated_columns, - # init_tsdf.df, - # ("symbol",), - # ) - # - # def test__validated_columns_list_multiple_elems(self): - # init_tsdf = self.get_data_as_tsdf("init") - # - # self.assertEqual( - # init_tsdf._TSDF__validated_columns( - # init_tsdf.df, - # ["symbol", "event_ts", "trade_pr"], - # ), - # ["symbol", "event_ts", "trade_pr"], - # ) - - def test__checkPartitionCols(self): - init_tsdf = self.get_data_as_tsdf("init") - right_tsdf = self.get_data_as_tsdf("right_tsdf") - - self.assertRaises(ValueError, init_tsdf._TSDF__checkPartitionCols, right_tsdf) - - def test__validateTsColMatch(self): - init_tsdf = self.get_data_as_tsdf("init") - right_tsdf = self.get_data_as_tsdf("right_tsdf") - - self.assertRaises(ValueError, init_tsdf._TSDF__validateTsColMatch, right_tsdf) - - def test__addPrefixToColumns_non_empty_string(self): - init_tsdf = self.get_data_as_tsdf("init") - - df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "prefix").df - - schema_string = df.schema.simpleString() - - self.assertIn("prefix_event_ts", schema_string) - - def test__addPrefixToColumns_empty_string(self): - init_tsdf = self.get_data_as_tsdf("init") - - df = init_tsdf._TSDF__addPrefixToColumns(["event_ts"], "").df - - schema_string = df.schema.simpleString() - - # comma included (,event_ts) to ensure we don't match if there is a prefix added - self.assertIn(",event_ts", schema_string) - - def test__addColumnsFromOtherDF(self): - init_tsdf = self.get_data_as_tsdf("init") - - df = init_tsdf._TSDF__addColumnsFromOtherDF(["another_col"]).df - - schema_string = df.schema.simpleString() - - self.assertIn("another_col", schema_string) - - def test__combineTSDF(self): - init1_tsdf = self.get_data_as_tsdf("init") - init2_tsdf = self.get_data_as_tsdf("init") - - union_tsdf = init1_tsdf._TSDF__combineTSDF(init2_tsdf, "combined_ts_col") - df = union_tsdf.df - - schema_string = df.schema.simpleString() - - self.assertEqual(init1_tsdf.df.count() + init2_tsdf.df.count(), df.count()) - self.assertIn("combined_ts_col", schema_string) - - def test__getLastRightRow(self): - # TODO: several errors and hard-coded columns that throw AnalysisException - pass - - def test__getTimePartitions(self): - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - actual_tsdf = init_tsdf._TSDF__getTimePartitions(10) - - self.assertDataFrameEquality( - actual_tsdf, - expected_tsdf, - from_tsdf=True, - ) - - def test__getTimePartitions_with_fraction(self): - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - actual_tsdf = init_tsdf._TSDF__getTimePartitions(10, 0.25) - - self.assertDataFrameEquality( - actual_tsdf, - expected_tsdf, - from_tsdf=True, - ) - - def test_select_empty(self): - # TODO: Can we narrow down to types of Exception? - init_tsdf = self.get_data_as_tsdf("init") - - self.assertRaises(Exception, init_tsdf.select) - - def test_select_only_required_cols(self): - init_tsdf = self.get_data_as_tsdf("init") - - tsdf = init_tsdf.select("event_ts", "symbol") - - self.assertEqual(tsdf.df.columns, ["event_ts", "symbol"]) - - def test_select_all_cols(self): - init_tsdf = self.get_data_as_tsdf("init") - - tsdf = init_tsdf.select("event_ts", "symbol", "trade_pr") - - self.assertEqual(tsdf.df.columns, ["event_ts", "symbol", "trade_pr"]) - - def test_show(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show() - self.assertEqual( - captured_output.getvalue(), - ( - "+------+-------------------+--------+\n" - "|symbol| event_ts|trade_pr|\n" - "+------+-------------------+--------+\n" - "| S1|2020-08-01 00:00:10| 349.21|\n" - "| S1|2020-08-01 00:01:12| 351.32|\n" - "| S1|2020-09-01 00:02:10| 361.1|\n" - "| S1|2020-09-01 00:19:12| 362.1|\n" - "| S2|2020-08-01 00:01:10| 743.01|\n" - "| S2|2020-08-01 00:01:24| 751.92|\n" - "| S2|2020-09-01 00:02:10| 761.1|\n" - "| S2|2020-09-01 00:20:42| 762.33|\n" - "+------+-------------------+--------+\n" - "\n" - ), - ) - - def test_show_n_5(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show(5) - self.assertEqual( - captured_output.getvalue(), - ( - "+------+-------------------+--------+\n" - "|symbol| event_ts|trade_pr|\n" - "+------+-------------------+--------+\n" - "| S1|2020-08-01 00:00:10| 349.21|\n" - "| S1|2020-08-01 00:01:12| 351.32|\n" - "| S1|2020-09-01 00:02:10| 361.1|\n" - "| S1|2020-09-01 00:19:12| 362.1|\n" - "| S2|2020-08-01 00:01:10| 743.01|\n" - "+------+-------------------+--------+\n" - "only showing top 5 rows\n" - "\n" - ), - ) - - def test_show_k_gt_n(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - self.assertRaises(ValueError, init_tsdf.show, 5, 10) - - def test_show_truncate_false(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show(truncate=False) - self.assertEqual( - captured_output.getvalue(), - ( - "+------+-------------------+--------+\n" - "|symbol|event_ts |trade_pr|\n" - "+------+-------------------+--------+\n" - "|S1 |2020-08-01 00:00:10|349.21 |\n" - "|S1 |2020-08-01 00:01:12|351.32 |\n" - "|S1 |2020-09-01 00:02:10|361.1 |\n" - "|S1 |2020-09-01 00:19:12|362.1 |\n" - "|S2 |2020-08-01 00:01:10|743.01 |\n" - "|S2 |2020-08-01 00:01:24|751.92 |\n" - "|S2 |2020-09-01 00:02:10|761.1 |\n" - "|S2 |2020-09-01 00:20:42|762.33 |\n" - "+------+-------------------+--------+\n" - "\n" - ), - ) - - def test_show_vertical_true(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show(vertical=True) - self.assertEqual( - captured_output.getvalue(), - ( - "-RECORD 0-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:00:10 \n" - " trade_pr | 349.21 \n" - "-RECORD 1-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:01:12 \n" - " trade_pr | 351.32 \n" - "-RECORD 2-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:02:10 \n" - " trade_pr | 361.1 \n" - "-RECORD 3-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:19:12 \n" - " trade_pr | 362.1 \n" - "-RECORD 4-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-08-01 00:01:10 \n" - " trade_pr | 743.01 \n" - "-RECORD 5-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-08-01 00:01:24 \n" - " trade_pr | 751.92 \n" - "-RECORD 6-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-09-01 00:02:10 \n" - " trade_pr | 761.1 \n" - "-RECORD 7-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-09-01 00:20:42 \n" - " trade_pr | 762.33 \n" - "\n" - ), - ) - - def test_show_vertical_true_n_5(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show(5, vertical=True) - self.assertEqual( - captured_output.getvalue(), - ( - "-RECORD 0-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:00:10 \n" - " trade_pr | 349.21 \n" - "-RECORD 1-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:01:12 \n" - " trade_pr | 351.32 \n" - "-RECORD 2-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:02:10 \n" - " trade_pr | 361.1 \n" - "-RECORD 3-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:19:12 \n" - " trade_pr | 362.1 \n" - "-RECORD 4-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-08-01 00:01:10 \n" - " trade_pr | 743.01 \n" - "only showing top 5 rows\n" - "\n" - ), - ) - - def test_show_truncate_false_vertical_true(self): - init_tsdf = self.get_data_as_tsdf("init") - - captured_output = StringIO() - sys.stdout = captured_output - init_tsdf.show(truncate=False, vertical=True) - self.assertEqual( - captured_output.getvalue(), - ( - "-RECORD 0-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:00:10 \n" - " trade_pr | 349.21 \n" - "-RECORD 1-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-08-01 00:01:12 \n" - " trade_pr | 351.32 \n" - "-RECORD 2-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:02:10 \n" - " trade_pr | 361.1 \n" - "-RECORD 3-----------------------\n" - " symbol | S1 \n" - " event_ts | 2020-09-01 00:19:12 \n" - " trade_pr | 362.1 \n" - "-RECORD 4-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-08-01 00:01:10 \n" - " trade_pr | 743.01 \n" - "-RECORD 5-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-08-01 00:01:24 \n" - " trade_pr | 751.92 \n" - "-RECORD 6-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-09-01 00:02:10 \n" - " trade_pr | 761.1 \n" - "-RECORD 7-----------------------\n" - " symbol | S2 \n" - " event_ts | 2020-09-01 00:20:42 \n" - " trade_pr | 762.33 \n" - "\n" - ), - ) - - def test_at_string_timestamp(self): - """ - Test of time-slicing at(..) function using a string timestamp - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - at_tsdf = init_tsdf.at(target_ts) - - self.assertDataFrameEquality(at_tsdf, expected_tsdf, from_tsdf=True) - - def test_at_numeric_timestamp(self): - """ - Test of time-slicint at(..) function using a numeric timestamp - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_ts = "2020-09-01 00:02:10" - target_dbl = self.__timestamp_to_double(target_ts) - at_dbl_tsdf = init_dbl_tsdf.at(target_dbl) - - self.assertDataFrameEquality(at_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_before_string_timestamp(self): - """ - Test of time-slicing before(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - before_tsdf = init_tsdf.before(target_ts) - - self.assertDataFrameEquality(before_tsdf, expected_tsdf, from_tsdf=True) - - def test_before_numeric_timestamp(self): - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_ts = "2020-09-01 00:02:10" - target_dbl = self.__timestamp_to_double(target_ts) - before_dbl_tsdf = init_dbl_tsdf.before(target_dbl) - - self.assertDataFrameEquality(before_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_atOrBefore_string_timestamp(self): - """ - Test of time-slicing atOrBefore(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - before_tsdf = init_tsdf.atOrBefore(target_ts) - - self.assertDataFrameEquality(before_tsdf, expected_tsdf, from_tsdf=True) - - def test_atOrBefore_numeric_timestamp(self): - """ - Test of time-slicing atOrBefore(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_dbl = self.__timestamp_to_double(target_ts) - before_dbl_tsdf = init_dbl_tsdf.atOrBefore(target_dbl) - - self.assertDataFrameEquality(before_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_after_string_timestamp(self): - """ - Test of time-slicing after(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - after_tsdf = init_tsdf.after(target_ts) - - self.assertDataFrameEquality(after_tsdf, expected_tsdf, from_tsdf=True) - - def test_after_numeric_timestamp(self): - """ - Test of time-slicing after(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_dbl = self.__timestamp_to_double(target_ts) - after_dbl_tsdf = init_dbl_tsdf.after(target_dbl) - - self.assertDataFrameEquality(after_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_atOrAfter_string_timestamp(self): - """ - Test of time-slicing atOrAfter(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - after_tsdf = init_tsdf.atOrAfter(target_ts) - - self.assertDataFrameEquality(after_tsdf, expected_tsdf, from_tsdf=True) - - def test_atOrAfter_numeric_timestamp(self): - """ - Test of time-slicing atOrAfter(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:10" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_dbl = self.__timestamp_to_double(target_ts) - after_dbl_tsdf = init_dbl_tsdf.atOrAfter(target_dbl) - - self.assertDataFrameEquality(after_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_between_string_timestamp(self): - """ - Test of time-slicing between(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - ts1 = "2020-08-01 00:01:10" - ts2 = "2020-09-01 00:18:00" - between_tsdf = init_tsdf.between(ts1, ts2) - - self.assertDataFrameEquality(between_tsdf, expected_tsdf, from_tsdf=True) - - def test_between_numeric_timestamp(self): - """ - Test of time-slicing between(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - ts1 = "2020-08-01 00:01:10" - ts2 = "2020-09-01 00:18:00" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - ts1_dbl = self.__timestamp_to_double(ts1) - ts2_dbl = self.__timestamp_to_double(ts2) - between_dbl_tsdf = init_dbl_tsdf.between(ts1_dbl, ts2_dbl) - - self.assertDataFrameEquality( - between_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True - ) - - def test_between_exclusive_string_timestamp(self): - """ - Test of time-slicing between(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - ts1 = "2020-08-01 00:01:10" - ts2 = "2020-09-01 00:18:00" - between_tsdf = init_tsdf.between(ts1, ts2, inclusive=False) - - self.assertDataFrameEquality(between_tsdf, expected_tsdf, from_tsdf=True) - - def test_between_exclusive_numeric_timestamp(self): - """ - Test of time-slicing between(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - ts1 = "2020-08-01 00:01:10" - ts2 = "2020-09-01 00:18:00" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - ts1_dbl = self.__timestamp_to_double(ts1) - ts2_dbl = self.__timestamp_to_double(ts2) - between_dbl_tsdf = init_dbl_tsdf.between(ts1_dbl, ts2_dbl, inclusive=False) - - self.assertDataFrameEquality( - between_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True - ) - - def test_earliest_string_timestamp(self): - """ - Test of time-slicing earliest(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - earliest_tsdf = init_tsdf.earliest(n=3) - - self.assertDataFrameEquality(earliest_tsdf, expected_tsdf, from_tsdf=True) - - def test_earliest_numeric_timestamp(self): - """ - Test of time-slicing earliest(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - earliest_dbl_tsdf = init_dbl_tsdf.earliest(n=3) - - self.assertDataFrameEquality( - earliest_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True - ) - - def test_latest_string_timestamp(self): - """ - Test of time-slicing latest(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - latest_tsdf = init_tsdf.latest(n=3) - - self.assertDataFrameEquality( - latest_tsdf, expected_tsdf, ignore_row_order=True, from_tsdf=True - ) - - def test_latest_numeric_timestamp(self): - """ - Test of time-slicing latest(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - latest_dbl_tsdf = init_dbl_tsdf.latest(n=3) - - self.assertDataFrameEquality( - latest_dbl_tsdf, expected_dbl_tsdf, ignore_row_order=True, from_tsdf=True - ) - - def test_priorTo_string_timestamp(self): - """ - Test of time-slicing priorTo(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:00" - prior_tsdf = init_tsdf.priorTo(target_ts) - - self.assertDataFrameEquality(prior_tsdf, expected_tsdf, from_tsdf=True) - - def test_priorTo_numeric_timestamp(self): - """ - Test of time-slicing priorTo(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:00" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_dbl = self.__timestamp_to_double(target_ts) - prior_dbl_tsdf = init_dbl_tsdf.priorTo(target_dbl) - - self.assertDataFrameEquality(prior_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True) - - def test_subsequentTo_string_timestamp(self): - """ - Test of time-slicing subsequentTo(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:00" - subsequent_tsdf = init_tsdf.subsequentTo(target_ts) - - self.assertDataFrameEquality(subsequent_tsdf, expected_tsdf, from_tsdf=True) - - def test_subsequentTo_numeric_timestamp(self): - """ - Test of time-slicing subsequentTo(..) function - """ - init_tsdf = self.get_data_as_tsdf("init") - expected_tsdf = self.get_data_as_tsdf("expected") - - target_ts = "2020-09-01 00:02:00" - - # test with numeric ts_col - init_dbl_tsdf = self.__tsdf_with_double_tscol(init_tsdf) - expected_dbl_tsdf = self.__tsdf_with_double_tscol(expected_tsdf) - - target_dbl = self.__timestamp_to_double(target_ts) - subsequent_dbl_tsdf = init_dbl_tsdf.subsequentTo(target_dbl) - - self.assertDataFrameEquality( - subsequent_dbl_tsdf, expected_dbl_tsdf, from_tsdf=True - ) - - def test__rowsBetweenWindow(self): - init_tsdf = self.get_data_as_tsdf("init") - - self.assertIsInstance(init_tsdf.rowsBetweenWindow(1, 1), WindowSpec) - - def test_tsdf_interpolate(self): - ... - - -class ExtractStateIntervalsTest(SparkTest): - """Test of finding time ranges for metrics with constant state.""" - - def test_eq_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_eq_1_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3" - ) - intervals_eq_2_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="==" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_eq_1_df, expected_df) - self.assertDataFrameEquality(intervals_eq_2_df, expected_df) - - def test_eq_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_eq_1_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3" - ) - intervals_eq_2_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="==" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_eq_1_df, expected_df) - self.assertDataFrameEquality(intervals_eq_2_df, expected_df) - - def test_ne_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_ne_0_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="!=" - ) - intervals_ne_1_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<>" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_ne_0_df, expected_df) - self.assertDataFrameEquality(intervals_ne_1_df, expected_df) - - def test_ne_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_ne_0_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="!=" - ) - intervals_ne_1_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<>" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_ne_0_df, expected_df) - self.assertDataFrameEquality(intervals_ne_1_df, expected_df) - - def test_gt_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_gt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=">" - ) - - self.assertDataFrameEquality(intervals_gt_df, expected_df) - - def test_gt_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_gt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=">" - ) - - self.assertDataFrameEquality(intervals_gt_df, expected_df) - - def test_lt_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_lt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_lt_df, expected_df) - - def test_lt_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_lt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<" - ) - - # test intervals_tsdf summary - self.assertDataFrameEquality(intervals_lt_df, expected_df) - - def test_gte_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_gt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=">=" - ) - - self.assertDataFrameEquality(intervals_gt_df, expected_df) - - def test_gte_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_gt_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=">=" - ) - - self.assertDataFrameEquality(intervals_gt_df, expected_df) - - def test_lte_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_lte_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<=" - ) - - # test intervals_tsdf summary - self.assertDataFrameEquality(intervals_lte_df, expected_df) - - def test_lte_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # call extractStateIntervals method - intervals_lte_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<=" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_lte_df, expected_df) - - def test_threshold_fn(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - # threshold state function - def threshold_fn(a: Column, b: Column) -> Column: - return sfn.abs(a - b) < sfn.lit(0.5) - - # call extractStateIntervals method - extracted_intervals_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=threshold_fn - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(extracted_intervals_df, expected_df) - - def test_null_safe_eq_0(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - intervals_eq_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<=>" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality( - intervals_eq_df, expected_df, ignore_nullable=False - ) - - def test_null_safe_eq_1(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - intervals_eq_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="<=>" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality( - intervals_eq_df, expected_df, ignore_nullable=False - ) - - def test_adjacent_intervals(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - expected_df: DataFrame = self.get_data_as_sdf("expected") - - intervals_eq_df: DataFrame = input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3" - ) - - # test extractStateIntervals_tsdf summary - self.assertDataFrameEquality(intervals_eq_df, expected_df) - - def test_invalid_state_definition_str(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - - try: - input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition="N/A" - ) - except ValueError as e: - self.assertEqual(type(e), ValueError) - - def test_invalid_state_definition_type(self): - # construct dataframes - input_tsdf: TSDF = self.get_data_as_tsdf("input") - - try: - input_tsdf.extractStateIntervals( - "metric_1", "metric_2", "metric_3", state_definition=0 - ) - except TypeError as e: - self.assertEqual(type(e), TypeError) - - -# MAIN -if __name__ == "__main__": - unittest.main() +from tempo.tsschema import SimpleTimestampIndex, OrdinalTSIndex, TSSchema +from tests.base import TestDataFrame, SparkTest + + +class TSDFBaseTest(SparkTest): + @parameterized.expand([ + ("simple_ts_idx", SimpleTimestampIndex), + ("simple_ts_no_series", SimpleTimestampIndex), + ("ordinal_double_index", OrdinalTSIndex), + ("ordinal_int_index", OrdinalTSIndex), + ]) + def test_tsdf_constructor(self, init_tsdf_id, expected_idx_class): + # load the test data + test_data = self.get_test_data(init_tsdf_id) + # load Spark DataFrame + init_sdf = test_data.as_sdf() + # create TSDF + init_tsdf = TSDF(init_sdf, **test_data.ts_idx) + # check that TSDF was created correctly + self.assertIsNotNone(init_tsdf) + self.assertIsInstance(init_tsdf, TSDF) + # validate the TSSchema + self.assertIsNotNone(init_tsdf.ts_schema) + self.assertIsInstance(init_tsdf.ts_schema, TSSchema) + # validate the TSIndex + self.assertIsNotNone(init_tsdf.ts_index) + self.assertIsInstance(init_tsdf.ts_index, expected_idx_class) + + @parameterized.expand([ + ("simple_ts_idx", ["symbol"]), + ("simple_ts_no_series", []), + ("ordinal_double_index", ["symbol"]), + ("ordinal_int_index", ["symbol"]), + ]) + def test_series_ids(self, init_tsdf_id, expected_series_ids): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # validate series ids + self.assertEqual(set(tsdf.series_ids), set(expected_series_ids)) + + @parameterized.expand([ + ("simple_ts_idx", ["event_ts", "symbol"]), + ("simple_ts_no_series", ["event_ts"]), + ("ordinal_double_index", ["event_ts_dbl", "symbol"]), + ("ordinal_int_index", ["order", "symbol"]), + ]) + def test_structural_cols(self, init_tsdf_id, expected_structural_cols): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # validate structural cols + self.assertEqual(set(tsdf.structural_cols), set(expected_structural_cols)) + + @parameterized.expand([ + ("simple_ts_idx", ["trade_pr"]), + ("simple_ts_no_series", ["trade_pr"]), + ("ordinal_double_index", ["trade_pr"]), + ("ordinal_int_index", ["trade_pr"]), + ]) + def test_obs_cols(self, init_tsdf_id, expected_obs_cols): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # validate obs cols + self.assertEqual(set(tsdf.observational_cols), set(expected_obs_cols)) + + @parameterized.expand([ + ("simple_ts_idx", ["trade_pr"]), + ("simple_ts_no_series", ["trade_pr"]), + ("ordinal_double_index", ["trade_pr"]), + ("ordinal_int_index", ["trade_pr"]), + ]) + def test_metric_cols(self, init_tsdf_id, expected_metric_cols): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # validate metric cols + self.assertEqual(set(tsdf.metric_cols), set(expected_metric_cols)) + + +class TimeSlicingTests(SparkTest): + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S2", 1, 751.92], + ], + }, + }, + ), + ]) + def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + at_tsdf = tsdf.at(ts) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(at_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S2", 0, 743.01]], + }, + }, + ), + ]) + def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + at_tsdf = tsdf.before(ts) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(at_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S1", 1, 349.21], ["S2", 0, 743.01], ["S2", 1, 751.92]], + }, + }, + ), + ]) + def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + at_tsdf = tsdf.atOrBefore(ts) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(at_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-09-01 00:20:42", 762.33] + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:08:12", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:20:42", 762.33] + ], + }, + }, + ), + ( + "ordinal_double_index", + 1.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33] + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S1", 243, 362.1], + ["S2", 100, 762.33] + ], + }, + }, + ), + ]) + def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + at_tsdf = tsdf.after(ts) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(at_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33] + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:24", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:20:42", 762.33] + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33] + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S1", 243, 362.1], + ["S2", 10, 761.10], + ["S2", 100, 762.33] + ], + }, + }, + ), + ]) + def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + at_tsdf = tsdf.atOrAfter(ts) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(at_tsdf, expected_tsdf) diff --git a/python/tests/unit_test_data/tsdf_tests.json b/python/tests/unit_test_data/tsdf_tests.json index 9e54e63f..b7f14a33 100644 --- a/python/tests/unit_test_data/tsdf_tests.json +++ b/python/tests/unit_test_data/tsdf_tests.json @@ -1,9 +1,12 @@ { - "__SharedData": { - "temp_slice_init_data": { - "schema": "symbol string, event_ts string, trade_pr float", + "simple_ts_idx": { + "ts_idx": { "ts_col": "event_ts", - "series_ids": ["symbol"], + "series_ids": ["symbol"] + }, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], "data": [ ["S1", "2020-08-01 00:00:10", 349.21], ["S1", "2020-08-01 00:01:12", 351.32], @@ -16,3083 +19,61 @@ ] } }, - "TSDFBaseTests": { - "test_TSDF_init": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__add_double_ts": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_column_not_string": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_column_not_found": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_column": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_columns_string": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_columns_none": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_columns_tuple": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__validated_columns_list_multiple_elems": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__checkPartitionCols": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "right_tsdf": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "event_ts" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ] - ] - } - }, - "test__validateTsColMatch": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "right_tsdf": { - "schema": "symbol string, event_ts int, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - 1596240010, - 349.21 - ] - ] - } - }, - "test__addPrefixToColumns_non_empty_string": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__addPrefixToColumns_empty_string": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__addColumnsFromOtherDF": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__combineTSDF": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__getLastRightRow": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test__getTimePartitions": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float, ts_partition int, is_original int", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21, - 1596240010, - 1 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32, - 1596240070, - 1 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1, - 1598918530, - 1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1, - 1598919550, - 1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01, - 1596240070, - 1 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92, - 1596240080, - 1 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.1, - 1598918530, - 1 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33, - 1598919640, - 1 - ] - ] - } - }, - "test__getTimePartitions_with_fraction": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float, ts_partition int, is_original int", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21, - 1596240010, - 1 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32, - 1596240070, - 1 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1, - 1598918530, - 1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1, - 1598919550, - 1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01, - 1596240070, - 1 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92, - 1596240080, - 1 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.1, - 1598918530, - 1 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33, - 1598919640, - 1 - ] - ] - } - }, - "test_select_empty": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_select_only_required_cols": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_select_all_cols": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_n_5": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_k_gt_n": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_truncate_false": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_vertical_true": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_vertical_true_n_5": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_show_truncate_false_vertical_true": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } - }, - "test_describe": { - "init": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ] - ] - } - }, - "test__getSparkPlan": { - "init": { - "$ref": "#/TSDFBaseTests/test__getBytesFromPlan/init" - } - }, - "test__getBytesFromPlan": { - "init": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ] - ] - } - }, - "test__getBytesFromPlan_search_result_is_None": { - "init": { - "$ref": "#/TSDFBaseTests/test__getBytesFromPlan/init" - } - }, - "test__getBytesFromPlan_size_in_GiB": { - "init": { - "$ref": "#/TSDFBaseTests/test__getBytesFromPlan/init" - } - }, - "test__getBytesFromPlan_size_in_MiB": { - "init": { - "$ref": "#/TSDFBaseTests/test__getBytesFromPlan/init" - } - }, - "test__getBytesFromPlan_size_in_KiB": { - "init": { - "$ref": "#/TSDFBaseTests/test__getBytesFromPlan/init" - } - }, - "test_at_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_at_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_at_string_timestamp/expected" - } - }, - "test_before_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ] - ] - } - }, - "test_before_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_before_string_timestamp/expected" - } - }, - "test_atOrBefore_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_atOrBefore_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_atOrBefore_string_timestamp/expected" - } - }, - "test_after_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33 - ] - ] - } - }, - "test_after_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_after_string_timestamp/expected" - } - }, - "test_atOrAfter_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33 - ] - ] - } - }, - "test_atOrAfter_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_atOrAfter_string_timestamp/expected" - } - }, - "test_between_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_between_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_between_string_timestamp/expected" - } - }, - "test_between_exclusive_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_between_exclusive_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_between_exclusive_string_timestamp/expected" - } - }, - "test_earliest_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10", - 349.21 - ], - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-08-01 00:01:10", - 743.01 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_earliest_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_earliest_string_timestamp/expected" - } - }, - "test_latest_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S1", - "2020-09-01 00:19:12", - 362.1 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ], - [ - "S2", - "2020-09-01 00:20:42", - 762.33 - ] - ] - } - }, - "test_latest_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_latest_string_timestamp/expected" - } - }, - "test_priorTo_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:01:12", - 351.32 - ], - [ - "S2", - "2020-08-01 00:01:24", - 751.92 - ] - ] - } - }, - "test_priorTo_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_priorTo_string_timestamp/expected" - } - }, - "test_subsequentTo_string_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-09-01 00:02:10", - 361.1 - ], - [ - "S2", - "2020-09-01 00:02:10", - 761.10 - ] - ] - } - }, - "test_subsequentTo_numeric_timestamp": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - }, - "expected": { - "$ref": "#/TSDFBaseTests/test_subsequentTo_string_timestamp/expected" - } - }, - "test__rowsBetweenWindow": { - "init": { - "$ref": "#/__SharedData/temp_slice_init_data" - } + "simple_ts_no_series": { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": [] }, - "test_withPartitionCols": { - "init": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "data": { - "$ref": "#/__SharedData/temp_slice_init_data/data" - } - } + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:20:42", 762.33] + ] } }, - "ResampleTest": { - "test_resample": { - "input": { - "schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:10", - 349.21, - 10.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:11", - 340.21, - 9.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:12", - 353.32, - 8.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:13", - 351.32, - 7.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:14", - 350.32, - 6.0 - ], - [ - "S1", - "SAME_DT", - "2020-09-01 00:01:12", - 361.1, - 5.0 - ], - [ - "S1", - "SAME_DT", - "2020-09-01 00:19:12", - 362.1, - 4.0 - ] - ] - }, - "expected": { - "schema": "symbol string, event_ts string, floor_trade_pr float, floor_date string, floor_trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - 349.21, - "SAME_DT", - 10.0 - ], - [ - "S1", - "2020-08-01 00:01:00", - 353.32, - "SAME_DT", - 8.0 - ], - [ - "S1", - "2020-09-01 00:01:00", - 361.1, - "SAME_DT", - 5.0 - ], - [ - "S1", - "2020-09-01 00:19:00", - 362.1, - "SAME_DT", - 4.0 - ] - ] - }, - "expected30m": { - "schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - null, - 348.88, - 8.0 - ], - [ - "S1", - "2020-09-01 00:00:00", - null, - 361.1, - 5.0 - ], - [ - "S1", - "2020-09-01 00:15:00", - null, - 362.1, - 4.0 - ] - ] - }, - "expectedbars": { - "schema": "symbol string, event_ts string, close_trade_pr float, close_trade_pr_2 float, high_trade_pr float, high_trade_pr_2 float, low_trade_pr float, low_trade_pr_2 float, open_trade_pr float, open_trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - 340.21, - 9.0, - 349.21, - 10.0, - 340.21, - 9.0, - 349.21, - 10.0 - ], - [ - "S1", - "2020-08-01 00:01:00", - 350.32, - 6.0, - 353.32, - 8.0, - 350.32, - 6.0, - 353.32, - 8.0 - ], - [ - "S1", - "2020-09-01 00:01:00", - 361.1, - 5.0, - 361.1, - 5.0, - 361.1, - 5.0, - 361.1, - 5.0 - ], - [ - "S1", - "2020-09-01 00:19:00", - 362.1, - 4.0, - 362.1, - 4.0, - 362.1, - 4.0, - 362.1, - 4.0 - ] - ] - } + "ordinal_double_index": { + "ts_idx": { + "ts_col": "event_ts_dbl", + "series_ids": ["symbol"] }, - "test_resample_millis": { - "init": { - "schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:10.12345", - 349.21, - 10.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:10.123", - 340.21, - 9.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:10.124", - 353.32, - 8.0 - ] - ] - }, - "expectedms": { - "schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:10.123", - null, - 344.71, - 9.5 - ], - [ - "S1", - "2020-08-01 00:00:10.124", - null, - 353.32, - 8.0 - ] - ] - } - }, - "test_upsample": { - "input": { - "schema": "symbol string, date string, event_ts string, trade_pr float, trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:10", - 349.21, - 10.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:00:11", - 340.21, - 9.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:12", - 353.32, - 8.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:13", - 351.32, - 7.0 - ], - [ - "S1", - "SAME_DT", - "2020-08-01 00:01:14", - 350.32, - 6.0 - ], - [ - "S1", - "SAME_DT", - "2020-09-01 00:01:12", - 361.1, - 5.0 - ], - [ - "S1", - "SAME_DT", - "2020-09-01 00:19:12", - 362.1, - 4.0 - ] - ] - }, - "expected": { - "schema": "symbol string, event_ts string, floor_trade_pr float, floor_date string, floor_trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - 349.21, - "SAME_DT", - 10.0 - ], - [ - "S1", - "2020-08-01 00:01:00", - 353.32, - "SAME_DT", - 8.0 - ], - [ - "S1", - "2020-09-01 00:01:00", - 361.1, - "SAME_DT", - 5.0 - ], - [ - "S1", - "2020-09-01 00:19:00", - 362.1, - "SAME_DT", - 4.0 - ] - ] - }, - "expected30m": { - "schema": "symbol string, event_ts string, date double, trade_pr double, trade_pr_2 double", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - 0.0, - 348.88, - 8.0 - ], - [ - "S1", - "2020-08-01 00:05:00", - 0.0, - 0.0, - 0.0 - ], - [ - "S1", - "2020-09-01 00:00:00", - 0.0, - 361.1, - 5.0 - ], - [ - "S1", - "2020-09-01 00:15:00", - 0.0, - 362.1, - 4.0 - ] - ] - }, - "expectedbars": { - "schema": "symbol string, event_ts string, close_trade_pr float, close_trade_pr_2 float, high_trade_pr float, high_trade_pr_2 float, low_trade_pr float, low_trade_pr_2 float, open_trade_pr float, open_trade_pr_2 float", - "ts_col": "event_ts", - "series_ids": [ - "symbol" - ], - "data": [ - [ - "S1", - "2020-08-01 00:00:00", - 340.21, - 9.0, - 349.21, - 10.0, - 340.21, - 9.0, - 349.21, - 10.0 - ], - [ - "S1", - "2020-08-01 00:01:00", - 350.32, - 6.0, - 353.32, - 8.0, - 350.32, - 6.0, - 353.32, - 8.0 - ], - [ - "S1", - "2020-09-01 00:01:00", - 361.1, - 5.0, - 361.1, - 5.0, - 361.1, - 5.0, - 361.1, - 5.0 - ], - [ - "S1", - "2020-09-01 00:19:00", - 362.1, - 4.0, - 362.1, - 4.0, - 362.1, - 4.0, - 362.1, - 4.0 - ] - ] - } + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ["S2", 10.0, 762.33] + ] } }, - "ExtractStateIntervalsTest": { - "test_eq_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:00:10", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:14", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_eq_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - null, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - null, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - null - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:13", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_ne_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:01:12", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:14", - "2020-09-01 00:19:12", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_ne_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.0, - 4.2 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 4.3, - 4.1, - 4.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:00:11", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_gt_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:01:12", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:14", - "2020-08-01 00:01:15", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:16", - "2020-08-01 00:01:17", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_gt_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.3, - 4.1, - 4.7 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.4, - 4.0, - 4.6 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 4.5, - 4.1, - 4.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:00:11", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_lt_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:01:15", - "2020-08-01 00:01:16", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:17", - "2020-09-01 00:19:12", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_lt_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.3, - 4.1, - 4.7 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.2, - 4.2, - 4.8 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:00:11", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_gte_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:01:15", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:16", - "2020-08-01 00:01:17", - "v1", - "foo", - "bar" - ] - ] - } + "ordinal_int_index": { + "ts_idx": { + "ts_col": "order", + "series_ids": ["symbol"] }, - "test_gte_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.3, - 4.1, - 4.7 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.4, - 4.0, - 4.6 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 4.5, - 4.0, - 4.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:00:11", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_lte_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:00:10", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:14", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:15", - "2020-08-01 00:01:16", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:17", - "2020-09-01 00:19:12", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_lte_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.3, - 4.1, - 4.7 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.2, - 4.2, - 4.8 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 4.1, - 4.2, - 4.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:10", - "2020-08-01 00:00:11", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_threshold_fn": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts: STRING, end_ts: STRING, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL ,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:00:10", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:14", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_null_safe_eq_0": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - null, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - null, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - null, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - null, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - null, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:00:10", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:14", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_null_safe_eq_1": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - null, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - null - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - null, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - null, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - null, - 10.7 - ], - [ - "2020-08-01 00:01:15", - "v1", - "foo", - "bar", - 42.3, - 42.3, - 42.3 - ], - [ - "2020-08-01 00:01:16", - "v1", - "foo", - "bar", - 37.6, - 37.6, - 37.6 - ], - [ - "2020-08-01 00:01:17", - "v1", - "foo", - "bar", - 61.5, - 61.5, - 61.5 - ], - [ - "2020-09-01 00:01:12", - "v1", - "foo", - "bar", - 28.9, - 28.9, - 28.9 - ], - [ - "2020-09-01 00:19:12", - "v1", - "foo", - "bar", - 0.1, - 0.1, - 0.1 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:13", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_adjacent_intervals": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT, metric_2 FLOAT, metric_3 FLOAT", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:10", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ], - [ - "2020-08-01 00:00:11", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:00:12", - "v1", - "foo", - "bar", - 5.0, - 5.0, - 5.0 - ], - [ - "2020-08-01 00:01:12", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:13", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ], - [ - "2020-08-01 00:01:14", - "v1", - "foo", - "bar", - 10.7, - 10.7, - 10.7 - ] - ] - }, - "expected": { - "schema": "start_ts STRING NOT NULL, end_ts STRING NOT NULL,identifier_1 STRING NOT NULL,identifier_2 STRING NOT NULL,identifier_3 STRING NOT NULL", - "other_ts_cols": [ - "start_ts", - "end_ts" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "2020-08-01 00:00:10", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:00:11", - "2020-08-01 00:00:12", - "v1", - "foo", - "bar" - ], - [ - "2020-08-01 00:01:12", - "2020-08-01 00:01:14", - "v1", - "foo", - "bar" - ] - ] - } - }, - "test_invalid_state_definition_str": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ] - ] - } - }, - "test_invalid_state_definition_type": { - "input": { - "schema": "event_ts STRING NOT NULL, identifier_1 STRING NOT NULL, identifier_2 STRING NOT NULL, identifier_3 STRING NOT NULL, metric_1 FLOAT NOT NULL, metric_2 FLOAT NOT NULL, metric_3 FLOAT NOT NULL", - "ts_col": "event_ts", - "series_ids": [ - "identifier_1", - "identifier_2", - "identifier_3" - ], - "data": [ - [ - "2020-08-01 00:00:09", - "v1", - "foo", - "bar", - 4.1, - 4.1, - 4.1 - ] - ] - } + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S1", 243, 362.1], + ["S2", 0, 743.01], + ["S2", 1, 751.92], + ["S2", 10, 761.10], + ["S2", 100, 762.33] + ] } } } \ No newline at end of file diff --git a/python/tests/unit_test_data/tsschema_tests.json b/python/tests/unit_test_data/tsschema_tests.json deleted file mode 100644 index 3e78def3..00000000 --- a/python/tests/unit_test_data/tsschema_tests.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "__SharedData": { - "simple_ts_idx": { - "ts_idx": { - "ts_col": "event_ts", - "series_ids": ["symbol"] - }, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-09-01 00:20:42", 762.33] - ] - } - } - }, - "TSSchemaTests": { - "simple_ts_idx": { - "$ref": "#/__SharedData/simple_ts_idx" - } - } -} \ No newline at end of file diff --git a/python/tests/utils_tests.py b/python/tests/utils_tests.py index a2a0523c..e0878062 100644 --- a/python/tests/utils_tests.py +++ b/python/tests/utils_tests.py @@ -4,7 +4,7 @@ from unittest import mock from tempo.utils import * # noqa: F403 -from tests.tsdf_tests import SparkTest +from tests.base import SparkTest class UtilsTest(SparkTest): From c49901e2aba91499f17a8b1eb466a8402f698c20 Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Wed, 17 Jan 2024 18:27:46 -0800 Subject: [PATCH 12/13] completed time slicing tests for simple ts indexes --- python/tests/tsdf_tests.py | 500 +++++++++++++++++++++++++++++++++++-- 1 file changed, 484 insertions(+), 16 deletions(-) diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index 12d23c02..bd8ece76 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -216,10 +216,10 @@ def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_data(init_tsdf_id).as_tsdf() # slice at timestamp - at_tsdf = tsdf.before(ts) + before_tsdf = tsdf.before(ts) # validate the slice expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() - self.assertDataFrameEquality(at_tsdf, expected_tsdf) + self.assertDataFrameEquality(before_tsdf, expected_tsdf) @parameterized.expand([ ( @@ -295,10 +295,10 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_data(init_tsdf_id).as_tsdf() # slice at timestamp - at_tsdf = tsdf.atOrBefore(ts) + at_before_tsdf = tsdf.atOrBefore(ts) # validate the slice expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() - self.assertDataFrameEquality(at_tsdf, expected_tsdf) + self.assertDataFrameEquality(at_before_tsdf, expected_tsdf) @parameterized.expand([ ( @@ -311,7 +311,7 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): "ts_convert": ["event_ts"], "data": [ ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-09-01 00:20:42", 762.33] + ["S2", "2020-09-01 00:20:42", 762.33], ], }, }, @@ -326,7 +326,7 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): "ts_convert": ["event_ts"], "data": [ ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:20:42", 762.33] + ["2020-09-01 00:20:42", 762.33], ], }, }, @@ -342,7 +342,7 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): ["S1", 1.207, 351.32], ["S1", 10.0, 361.1], ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33] + ["S2", 10.0, 762.33], ], }, }, @@ -358,7 +358,7 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): ["S1", 20, 351.32], ["S1", 127, 361.1], ["S1", 243, 362.1], - ["S2", 100, 762.33] + ["S2", 100, 762.33], ], }, }, @@ -368,10 +368,10 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_data(init_tsdf_id).as_tsdf() # slice at timestamp - at_tsdf = tsdf.after(ts) + after_tsdf = tsdf.after(ts) # validate the slice expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() - self.assertDataFrameEquality(at_tsdf, expected_tsdf) + self.assertDataFrameEquality(after_tsdf, expected_tsdf) @parameterized.expand([ ( @@ -386,7 +386,7 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): ["S1", "2020-09-01 00:02:10", 361.1], ["S1", "2020-09-01 00:19:12", 362.1], ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-09-01 00:20:42", 762.33] + ["S2", "2020-09-01 00:20:42", 762.33], ], }, }, @@ -403,7 +403,7 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): ["2020-08-01 00:01:24", 751.92], ["2020-09-01 00:02:10", 361.1], ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:20:42", 762.33] + ["2020-09-01 00:20:42", 762.33], ], }, }, @@ -418,7 +418,7 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): "data": [ ["S1", 10.0, 361.1], ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33] + ["S2", 10.0, 762.33], ], }, }, @@ -435,7 +435,7 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): ["S1", 127, 361.1], ["S1", 243, 362.1], ["S2", 10, 761.10], - ["S2", 100, 762.33] + ["S2", 100, 762.33], ], }, }, @@ -445,7 +445,475 @@ def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_data(init_tsdf_id).as_tsdf() # slice at timestamp - at_tsdf = tsdf.atOrAfter(ts) + at_after_tsdf = tsdf.atOrAfter(ts) # validate the slice expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() - self.assertDataFrameEquality(at_tsdf, expected_tsdf) + self.assertDataFrameEquality(at_after_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "ordinal_double_index", + 0.1, + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 1.0, 761.10], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 100, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S1", 20, 351.32], ["S2", 10, 761.10]], + }, + }, + ), + ]) + def test_between_non_inclusive( + self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict + ): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + between_tsdf = tsdf.between(start_ts, end_ts, inclusive=False) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(between_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ], + }, + }, + ), + ( + "ordinal_double_index", + 0.1, + 10.0, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 100, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S1", 20, 351.32], + ["S2", 1, 751.92], + ["S2", 10, 761.10], + ["S2", 100, 762.33], + ], + }, + }, + ), + ]) + def test_between_inclusive( + self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict + ): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + between_tsdf = tsdf.between(start_ts, end_ts, inclusive=True) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(between_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + 2, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + 2, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ], + }, + }, + ), + ( + "ordinal_double_index", + 2, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ], + }, + }, + ), + ( + "ordinal_int_index", + 2, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S1", 20, 351.32], + ["S2", 0, 743.01], + ["S2", 1, 751.92], + ], + }, + }, + ), + ]) + def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # get earliest timestamp + earliest_ts = tsdf.earliest(n=num_records) + # validate the timestamp + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(earliest_ts, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + 2, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:19:12", 362.1], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:20:42", 762.33], + ["S2", "2020-09-01 00:02:10", 761.10] + ], + }, + }, + ), + ( + "simple_ts_no_series", + 4, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:20:42", 762.33], + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:02:10", 361.1], + ["2020-08-01 00:01:24", 751.92] + ], + }, + }, + ), + ( + "ordinal_double_index", + 1, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33] + ], + }, + }, + ), + ( + "ordinal_int_index", + 3, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 243, 362.1], + ["S1", 127, 361.1], + ["S1", 20, 351.32], + ["S2", 100, 762.33], + ["S2", 10, 761.10], + ["S2", 1, 751.92], + ], + }, + }, + ), + ]) + def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # get earliest timestamp + latest_ts = tsdf.latest(n=num_records) + # validate the timestamp + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(latest_ts, expected_tsdf) + + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + 2, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-08-01 00:01:24", 751.92] + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + 3, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:02:10", 361.1], + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + 4, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 1.207, 351.32], + ["S1", 0.13, 349.21], + ["S2", 10.0, 762.33], + ["S2", 1.0, 761.10], + ["S2", 0.1, 751.92], + ["S2", 0.005, 743.01], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 1, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S1", 1, 349.21], ["S2", 1, 751.92]], + }, + }, + ), + ]) + def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + prior_tsdf = tsdf.priorTo(ts, n=n) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(prior_tsdf, expected_tsdf) + + @parameterized.expand([ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + 1, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:02:10", 761.10] + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:24", + 3, + { + "ts_idx": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1] + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + 2, + { + "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + 2, + { + "ts_idx": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S2", 10, 761.10], + ["S2", 100, 762.33], + ], + }, + }, + ), + ]) + def test_subsequentTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): + # load TSDF + tsdf = self.get_test_data(init_tsdf_id).as_tsdf() + # slice at timestamp + subseq_tsdf = tsdf.subsequentTo(ts, n=n) + # validate the slice + expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() + self.assertDataFrameEquality(subseq_tsdf, expected_tsdf) From 5c31d3a5ea7a0f41d735e9affcfd389da70d55ce Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Mon, 22 Jan 2024 20:28:42 -0800 Subject: [PATCH 13/13] completed time slicing tests for parsed ts indexes --- python/tempo/timeunit.py | 8 +- python/tempo/tsdf.py | 6 +- python/tempo/tsschema.py | 110 +-- python/tempo/typing.py | 2 - python/tests/tsdf_tests.py | 712 +++++++++++++++++++- python/tests/tsschema_tests.py | 21 +- python/tests/unit_test_data/tsdf_tests.json | 62 ++ 7 files changed, 829 insertions(+), 92 deletions(-) diff --git a/python/tempo/timeunit.py b/python/tempo/timeunit.py index f472a07c..5506443e 100644 --- a/python/tempo/timeunit.py +++ b/python/tempo/timeunit.py @@ -6,7 +6,6 @@ class TimeUnit(NamedTuple): name: str approx_seconds: float - sub_second_precision: int = 0 """ Represents a unit of time, with a name, an approximate number of seconds, @@ -40,7 +39,8 @@ def __lt__(self, other): TimeUnit("hour", 60 * 60), TimeUnit("minute", 60), TimeUnit("second", 1), - TimeUnit("millisecond", 1e-03, 3), - TimeUnit("microsecond", 1e-06, 6), - TimeUnit("nanosecond", 1e-09, 9) + TimeUnit("millisecond", 1e-03), + TimeUnit("microsecond", 1e-06), + TimeUnit("nanosecond", 1e-09) ) + diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 5c487c78..a21bcc48 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -23,8 +23,8 @@ import tempo.resample as t_resample import tempo.utils as t_utils from tempo.intervals import IntervalsDF -from tempo.tsschema import DEFAULT_TIMESTAMP_FORMAT, is_time_format, \ - CompositeTSIndex, TSIndex, TSSchema, WindowBuilder +from tempo.tsschema import DEFAULT_TIMESTAMP_FORMAT, is_time_format, sub_seconds_precision_digits, \ + CompositeTSIndex, ParsedTSIndex, TSIndex, TSSchema, WindowBuilder from tempo.typing import ColumnOrName, PandasMapIterFunction, PandasGroupedMapFunction logger = logging.getLogger(__name__) @@ -146,7 +146,7 @@ def fromStringTimestamp( ts_expr = sfn.to_date(sfn.col(ts_col), ts_fmt) # parse the ts_col give the expression parsed_ts_col = cls.__DEFAULT_PARSED_TS_COL - parsed_df = df.withColumn(cls.__DEFAULT_PARSED_TS_COL, ts_expr) + parsed_df = df.withColumn(parsed_ts_col, ts_expr) # move the ts cols into a struct struct_col_name = cls.__DEFAULT_TS_IDX_COL with_parsed_struct_df = cls.__makeStructFromCols(parsed_df, diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index a0d03b02..696489a0 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -1,7 +1,7 @@ import re import warnings from abc import ABC, abstractmethod -from typing import Collection, List, Optional, Union, Callable +from typing import Collection, List, Optional, Union import pyspark.sql.functions as sfn from pyspark.sql import Column, Window, WindowSpec @@ -455,9 +455,12 @@ def __eq__(self, other) -> Column: # match each component field with its corresponding comparison value comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) # build comparison expressions for each pair - comp_exprs = [c.eq(o) for (c, o) in comps] + comp_exprs: list[Column] = [(c == o) for (c, o) in comps] # conjunction of all expressions (AND) - return sfn.expr(" AND ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" AND ".join(comp_exprs)) + else: + return comp_exprs[0] def __ne__(self, other) -> Column: # try to compare the whole index to a single value @@ -468,9 +471,12 @@ def __ne__(self, other) -> Column: # match each component field with its corresponding comparison value comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) # build comparison expressions for each pair - comp_exprs = [c.neq(o) for (c, o) in comps] + comp_exprs = [(c != o) for (c, o) in comps] # disjunction of all expressions (OR) - return sfn.expr(" OR ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" OR ".join(comp_exprs)) + else: + return comp_exprs[0] def __lt__(self, other) -> Column: # try to compare the whole index to a single value @@ -483,11 +489,14 @@ def __lt__(self, other) -> Column: # do a leq for all but the last component comp_exprs = [] if len(comps) > 1: - comp_exprs = [c.leq(o) for (c, o) in comps[:-1]] + comp_exprs = [(c <= o) for (c, o) in comps[:-1]] # strict lt for the last component - comp_exprs += [c.lt(o) for (c, o) in comps[-1:]] + comp_exprs += [(c < o) for (c, o) in comps[-1:]] # conjunction of all expressions (AND) - return sfn.expr(" AND ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" AND ".join(comp_exprs)) + else: + return comp_exprs[0] def __le__(self, other) -> Column: # try to compare the whole index to a single value @@ -498,9 +507,12 @@ def __le__(self, other) -> Column: # match each component field with its corresponding comparison value comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) # build comparison expressions for each pair - comp_exprs = [c.leq(o) for (c, o) in comps] + comp_exprs = [(c <= o) for (c, o) in comps] # conjunction of all expressions (AND) - return sfn.expr(" AND ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" AND ".join(comp_exprs)) + else: + return comp_exprs[0] def __gt__(self, other) -> Column: # try to compare the whole index to a single value @@ -513,11 +525,14 @@ def __gt__(self, other) -> Column: # do a geq for all but the last component comp_exprs = [] if len(comps) > 1: - comp_exprs = [c.geq(o) for (c, o) in comps[:-1]] + comp_exprs = [(c >= o) for (c, o) in comps[:-1]] # strict gt for the last component - comp_exprs += [c.gt(o) for (c, o) in comps[-1:]] + comp_exprs += [(c > o) for (c, o) in comps[-1:]] # conjunction of all expressions (AND) - return sfn.expr(" AND ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" AND ".join(comp_exprs)) + else: + return comp_exprs[0] def __ge__(self, other) -> Column: # try to compare the whole index to a single value @@ -528,9 +543,12 @@ def __ge__(self, other) -> Column: # match each component field with its corresponding comparison value comps = zip(self.comparableExpr(), [_col_or_lit(o) for o in other]) # build comparison expressions for each pair - comp_exprs = [c.geq(o) for (c, o) in comps] + comp_exprs = [(c >= o) for (c, o) in comps] # conjunction of all expressions (AND) - return sfn.expr(" AND ".join(comp_exprs)) + if len(comp_exprs) > 1: + return sfn.expr(" AND ".join(comp_exprs)) + else: + return comp_exprs[0] # @@ -572,13 +590,6 @@ def src_str_field(self): def parsed_ts_field(self): return self.fieldPath(self._parsed_ts_field) - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.parsed_ts_field) - return _reverse_or_not(expr, reverse) - - def comparableExpr(self) -> Column: - return sfn.col(self.parsed_ts_field) - @classmethod def fromParsedTimestamp( cls, @@ -602,14 +613,14 @@ def fromParsedTimestamp( # if a double timestamp column is given # then we are building a SubMicrosecondPrecisionTimestampIndex - if double_ts_col is not None: - return SubMicrosecondPrecisionTimestampIndex( - ts_struct, - double_ts_col, - parsed_ts_col, - src_str_col, - num_precision_digits, - ) + # if double_ts_col is not None: + # return SubMicrosecondPrecisionTimestampIndex( + # ts_struct, + # double_ts_col, + # parsed_ts_col, + # src_str_col, + # num_precision_digits, + # ) # otherwise, we base it on the standard timestamp type # find the schema of the ts_struct column ts_schema = ts_struct.dataType @@ -665,7 +676,7 @@ def rangeExpr(self, reverse: bool = False) -> Column: return _reverse_or_not(expr, reverse) -class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): +class SubMicrosecondPrecisionTimestampIndex(ParsedTSIndex): """ Timeseries index class for timestamps with sub-microsecond precision parsed from a string column. Internally, the timestamps are stored as @@ -691,7 +702,7 @@ def __init__( You will receive a warning if this value is 6 or less, as this is the precision of the standard timestamp type. """ - super().__init__(ts_struct, double_ts_field) + super().__init__(ts_struct, double_ts_field, src_str_field) # validate the double timestamp column double_ts_type = self.schema[double_ts_field].dataType if not isinstance(double_ts_type, DoubleType): @@ -700,7 +711,7 @@ def __init__( f"but the given double_ts_col {double_ts_field} " f"has type {double_ts_type}" ) - self.double_ts_field = double_ts_field + self._double_ts_field = double_ts_field # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( @@ -709,11 +720,7 @@ def __init__( "standard timestamp precision of 6 digits (microseconds). " "Consider using a ParsedTimestampIndex instead." ) - self.__unit = TimeUnit( - f"custom_subsecond_unit (precision: {num_precision_digits})", - 10 ** (-num_precision_digits), - num_precision_digits, - ) + self._num_precision_digits = num_precision_digits # validate the parsed column as a timestamp column parsed_ts_type = self.schema[secondary_parsed_ts_field].dataType if not isinstance(parsed_ts_type, TimestampType): @@ -722,30 +729,23 @@ def __init__( f"but the given parsed_ts_col {secondary_parsed_ts_field} " f"has type {parsed_ts_type}" ) - self.parsed_ts_field = secondary_parsed_ts_field - # validate the source column as a string column - src_str_field = self.schema[src_str_field] - if not isinstance(src_str_field.dataType, StringType): - raise TypeError( - "src_str_col field must be of StringType, " - f"but the given src_str_col {src_str_field} " - f"has type {src_str_field.dataType}" - ) - self.src_str_col = src_str_field + self.secondary_parsed_ts_field = secondary_parsed_ts_field @property - def unit(self) -> Optional[TimeUnit]: - return self.__unit + def double_ts_field(self): + return self.fieldPath(self._double_ts_field) - def comparableExpr(self) -> Column: - return sfn.col(self.fieldPath(self.double_ts_field)) + @property + def num_precision_digits(self): + return self._num_precision_digits - def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - return _reverse_or_not(self.comparableExpr(), reverse) + @property + def unit(self) -> Optional[TimeUnit]: + return StandardTimeUnits.SECONDS def rangeExpr(self, reverse: bool = False) -> Column: # just use the order by expression, since this is the same - return self.orderByExpr(reverse) + return _reverse_or_not(sfn.col(self.double_ts_field), reverse) # diff --git a/python/tempo/typing.py b/python/tempo/typing.py index a29c3d9f..c6176fcb 100644 --- a/python/tempo/typing.py +++ b/python/tempo/typing.py @@ -4,8 +4,6 @@ from pandas.core.frame import DataFrame as PandasDataFrame -from pyspark.sql.pandas._typing import PandasMapIterFunction, PandasGroupedMapFunction - # These definitions were copied from private pypark modules: # - pyspark.sql._typing # - pyspark.sql.pandas._typing diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index bd8ece76..b3abee88 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -1,24 +1,30 @@ from parameterized import parameterized from tempo.tsdf import TSDF -from tempo.tsschema import SimpleTimestampIndex, OrdinalTSIndex, TSSchema +from tempo.tsschema import ( + SimpleTimestampIndex, + SimpleDateIndex, + OrdinalTSIndex, + ParsedTimestampIndex, + ParsedDateIndex, + TSSchema, +) from tests.base import TestDataFrame, SparkTest -class TSDFBaseTest(SparkTest): +class TSDFBaseTests(SparkTest): @parameterized.expand([ ("simple_ts_idx", SimpleTimestampIndex), ("simple_ts_no_series", SimpleTimestampIndex), + ("simple_date_idx", SimpleDateIndex), ("ordinal_double_index", OrdinalTSIndex), ("ordinal_int_index", OrdinalTSIndex), + ("parsed_ts_idx", ParsedTimestampIndex), + ("parsed_date_idx", ParsedDateIndex), ]) def test_tsdf_constructor(self, init_tsdf_id, expected_idx_class): - # load the test data - test_data = self.get_test_data(init_tsdf_id) - # load Spark DataFrame - init_sdf = test_data.as_sdf() # create TSDF - init_tsdf = TSDF(init_sdf, **test_data.ts_idx) + init_tsdf = self.get_test_data(init_tsdf_id).as_tsdf() # check that TSDF was created correctly self.assertIsNotNone(init_tsdf) self.assertIsInstance(init_tsdf, TSDF) @@ -32,8 +38,11 @@ def test_tsdf_constructor(self, init_tsdf_id, expected_idx_class): @parameterized.expand([ ("simple_ts_idx", ["symbol"]), ("simple_ts_no_series", []), + ("simple_date_idx", ["station"]), ("ordinal_double_index", ["symbol"]), ("ordinal_int_index", ["symbol"]), + ("parsed_ts_idx", ["symbol"]), + ("parsed_date_idx", ["station"]), ]) def test_series_ids(self, init_tsdf_id, expected_series_ids): # load TSDF @@ -44,8 +53,11 @@ def test_series_ids(self, init_tsdf_id, expected_series_ids): @parameterized.expand([ ("simple_ts_idx", ["event_ts", "symbol"]), ("simple_ts_no_series", ["event_ts"]), + ("simple_date_idx", ["date", "station"]), ("ordinal_double_index", ["event_ts_dbl", "symbol"]), ("ordinal_int_index", ["order", "symbol"]), + ("parsed_ts_idx", ["ts_idx", "symbol"]), + ("parsed_date_idx", ["ts_idx", "station"]), ]) def test_structural_cols(self, init_tsdf_id, expected_structural_cols): # load TSDF @@ -56,8 +68,11 @@ def test_structural_cols(self, init_tsdf_id, expected_structural_cols): @parameterized.expand([ ("simple_ts_idx", ["trade_pr"]), ("simple_ts_no_series", ["trade_pr"]), + ("simple_date_idx", ["temp"]), ("ordinal_double_index", ["trade_pr"]), ("ordinal_int_index", ["trade_pr"]), + ("parsed_ts_idx", ["trade_pr"]), + ("parsed_date_idx", ["temp"]), ]) def test_obs_cols(self, init_tsdf_id, expected_obs_cols): # load TSDF @@ -68,8 +83,11 @@ def test_obs_cols(self, init_tsdf_id, expected_obs_cols): @parameterized.expand([ ("simple_ts_idx", ["trade_pr"]), ("simple_ts_no_series", ["trade_pr"]), + ("simple_date_idx", ["temp"]), ("ordinal_double_index", ["trade_pr"]), ("ordinal_int_index", ["trade_pr"]), + ("parsed_ts_idx", ["trade_pr"]), + ("parsed_date_idx", ["temp"]), ]) def test_metric_cols(self, init_tsdf_id, expected_metric_cols): # load TSDF @@ -109,6 +127,21 @@ class TimeSlicingTests(SparkTest): }, }, ), + ( + "simple_date_idx", + "2020-08-02", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ( "ordinal_double_index", 10.0, @@ -137,6 +170,43 @@ class TimeSlicingTests(SparkTest): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.032", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-04", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ]) def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF @@ -183,6 +253,23 @@ def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-03", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ( "ordinal_double_index", 10.0, @@ -211,6 +298,48 @@ def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ]) def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF @@ -260,6 +389,25 @@ def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-03", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), ( "ordinal_double_index", 10.0, @@ -290,6 +438,50 @@ def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), ]) def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF @@ -331,6 +523,23 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-02", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ( "ordinal_double_index", 1.0, @@ -363,6 +572,46 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ]) def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF @@ -408,6 +657,23 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-03", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ( "ordinal_double_index", 10.0, @@ -440,6 +706,48 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ]) def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF @@ -483,6 +791,22 @@ def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-01", + "2020-08-03", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ( "ordinal_double_index", 0.1, @@ -511,6 +835,48 @@ def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-08-01 00:00:10.010", + "2020-09-01 00:02:10.076", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-01", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ]) def test_between_non_inclusive( self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict @@ -561,6 +927,26 @@ def test_between_non_inclusive( }, }, ), + ( + "simple_date_idx", + "2020-08-01", + "2020-08-03", + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), ( "ordinal_double_index", 0.1, @@ -598,6 +984,54 @@ def test_between_non_inclusive( }, }, ), + ( + "parsed_ts_idx", + "2020-08-01 00:00:10.010", + "2020-09-01 00:02:10.076", + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-01", + "2020-08-03", + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), ]) def test_between_inclusive( self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict @@ -643,6 +1077,23 @@ def test_between_inclusive( }, }, ), + ( + "simple_date_idx", + 2, + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ( "ordinal_double_index", 2, @@ -675,6 +1126,48 @@ def test_between_inclusive( }, }, ), + ( + "parsed_ts_idx", + 2, + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + 2, + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ]) def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): # load TSDF @@ -698,7 +1191,7 @@ def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): ["S1", "2020-09-01 00:19:12", 362.1], ["S1", "2020-09-01 00:02:10", 361.1], ["S2", "2020-09-01 00:20:42", 762.33], - ["S2", "2020-09-01 00:02:10", 761.10] + ["S2", "2020-09-01 00:02:10", 761.10], ], }, }, @@ -715,7 +1208,26 @@ def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): ["2020-09-01 00:20:42", 762.33], ["2020-09-01 00:19:12", 362.1], ["2020-09-01 00:02:10", 361.1], - ["2020-08-01 00:01:24", 751.92] + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_date_idx", + 3, + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-04", 25.57], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-04", 20.65], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], ], }, }, @@ -727,10 +1239,7 @@ def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): "ts_idx": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, "df": { "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33] - ], + "data": [["S1", 24.357, 362.1], ["S2", 10.0, 762.33]], }, }, ), @@ -752,6 +1261,48 @@ def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + 3, + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + 1, + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ]) def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): # load TSDF @@ -762,7 +1313,6 @@ def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): expected_tsdf = TestDataFrame(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(latest_ts, expected_tsdf) - @parameterized.expand([ ( "simple_ts_idx", @@ -777,7 +1327,7 @@ def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): ["S1", "2020-09-01 00:02:10", 361.1], ["S1", "2020-08-01 00:01:12", 351.32], ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-08-01 00:01:24", 751.92] + ["S2", "2020-08-01 00:01:24", 751.92], ], }, }, @@ -799,6 +1349,24 @@ def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): }, }, ), + ( + "simple_date_idx", + "2020-08-03", + 2, + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), ( "ordinal_double_index", 10.0, @@ -831,6 +1399,52 @@ def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + 2, + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + 3, + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-01", 27.58], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-01", 24.16], + ], + }, + }, + ), ]) def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): # load TSDF @@ -853,7 +1467,7 @@ def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): "ts_convert": ["event_ts"], "data": [ ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-09-01 00:02:10", 761.10] + ["S2", "2020-09-01 00:02:10", 761.10], ], }, }, @@ -870,7 +1484,27 @@ def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): "data": [ ["2020-08-01 00:01:24", 751.92], ["2020-09-01 00:02:10", 361.1], - ["2020-09-01 00:19:12", 362.1] + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-02", + 5, + { + "ts_idx": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], ], }, }, @@ -908,6 +1542,50 @@ def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): }, }, ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10", + 3, + { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + 2, + { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), ]) def test_subsequentTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): # load TSDF diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py index 273fd993..7cbf58df 100644 --- a/python/tests/tsschema_tests.py +++ b/python/tests/tsschema_tests.py @@ -1,9 +1,9 @@ import unittest -from abc import ABC, abstractmethod -from parameterized import parameterized, parameterized_class +from abc import ABC +from typing import List +from parameterized import parameterized_class from pyspark.sql import Column, WindowSpec -from pyspark.sql import functions as sfn from pyspark.sql.types import ( StructField, StructType, @@ -12,7 +12,6 @@ DoubleType, IntegerType, DateType, - NumericType, ) from tempo.tsschema import ( @@ -105,7 +104,7 @@ def _test_index(self, ts_idx: TSIndex): ParsedTimestampIndex, {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, StandardTimeUnits.SECONDS, - "Column<'ts_idx.parsed_ts'>", + "[Column<'ts_idx.parsed_ts'>]", "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", ), ( @@ -121,7 +120,7 @@ def _test_index(self, ts_idx: TSIndex): ParsedDateIndex, {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, StandardTimeUnits.DAYS, - "Column<'ts_idx.parsed_date'>", + "[Column<'ts_idx.parsed_date'>]", "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", ), ( @@ -141,8 +140,8 @@ def _test_index(self, ts_idx: TSIndex): "secondary_parsed_ts_field": "parsed_ts", "src_str_field": "src_str", }, - StandardTimeUnits.NANOSECONDS, - "Column<'ts_idx.double_ts'>", + StandardTimeUnits.SECONDS, + "[Column<'ts_idx.double_ts'>]", "Column<'ts_idx.double_ts'>", ), ], @@ -166,7 +165,7 @@ def test_comparable_expression(self): compbl_expr = ts_idx.comparableExpr() # validate the expression self.assertIsNotNone(compbl_expr) - self.assertIsInstance(compbl_expr, Column) + self.assertIsInstance(compbl_expr, (Column, List)) self.assertEqual(repr(compbl_expr), self.expected_comp_expr) def test_orderby_expression(self): @@ -176,7 +175,7 @@ def test_orderby_expression(self): orderby_expr = ts_idx.orderByExpr() # validate the expression self.assertIsNotNone(orderby_expr) - self.assertIsInstance(orderby_expr, Column) + self.assertIsInstance(orderby_expr, (Column, List)) self.assertEqual(repr(orderby_expr), self.expected_comp_expr) def test_range_expression(self): @@ -405,7 +404,7 @@ def test_range_expression(self): "series_ids": ["symbol"], }, SubMicrosecondPrecisionTimestampIndex, - StandardTimeUnits.NANOSECONDS, + StandardTimeUnits.SECONDS, "ts_idx", ["symbol"], ["ts_idx", "symbol"], diff --git a/python/tests/unit_test_data/tsdf_tests.json b/python/tests/unit_test_data/tsdf_tests.json index b7f14a33..4944417d 100644 --- a/python/tests/unit_test_data/tsdf_tests.json +++ b/python/tests/unit_test_data/tsdf_tests.json @@ -38,6 +38,26 @@ ] } }, + "simple_date_idx": { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"] + }, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65] + ] + } + }, "ordinal_double_index": { "ts_idx": { "ts_col": "event_ts_dbl", @@ -75,5 +95,47 @@ ["S2", 100, 762.33] ] } + }, + "parsed_ts_idx": { + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS" + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33] + ] + } + }, + "parsed_date_idx": { + "ts_idx": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd" + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65] + ] + } } } \ No newline at end of file