From 2127f3bbad52b5fc687653deb673d1bc7f399a24 Mon Sep 17 00:00:00 2001 From: Tristan Nixon <tristan.nixon@databricks.com> Date: Mon, 15 Jan 2024 15:39:27 -0800 Subject: [PATCH] checkpoint save - more advanced schema testing --- python/tempo/tsschema.py | 142 ++++--- python/tests/base.py | 58 ++- python/tests/tsschema_tests.py | 361 +++++++++++++++--- .../tests/unit_test_data/tsschema_tests.json | 37 +- 4 files changed, 440 insertions(+), 158 deletions(-) diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index 48cb1129..165cefe7 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -545,39 +545,39 @@ class ParsedTSIndex(CompositeTSIndex, ABC): """ def __init__( - self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str + self, ts_struct: StructField, parsed_ts_field: str, src_str_field: str ) -> None: - super().__init__(ts_struct, parsed_ts_col) + super().__init__(ts_struct, parsed_ts_field) # validate the source string column - src_str_field = self.schema[src_str_col] - if not isinstance(src_str_field.dataType, StringType): + src_str_type = self.schema[src_str_field].dataType + if not isinstance(src_str_type, StringType): raise TypeError( "Source string column must be of StringType, " - f"but given column {src_str_field.name} " - f"is of type {src_str_field.dataType}" + f"but given column {src_str_field} " + f"is of type {src_str_type}" ) - self._src_str_col = src_str_col + self._src_str_field = src_str_field # validate the parsed column - assert parsed_ts_col in self.schema.fieldNames(), ( - f"The parsed timestamp index field {parsed_ts_col} does not exist in the " + assert parsed_ts_field in self.schema.fieldNames(), ( + f"The parsed timestamp index field {parsed_ts_field} does not exist in the " f"MultiPart TSIndex schema {self.schema}" ) - self._parsed_ts_col = parsed_ts_col + self._parsed_ts_field = parsed_ts_field @property - def src_str_col(self): - return self.fieldPath(self._src_str_col) + def src_str_field(self): + return self.fieldPath(self._src_str_field) @property - def parsed_ts_col(self): - return self.fieldPath(self._parsed_ts_col) + def parsed_ts_field(self): + return self.fieldPath(self._parsed_ts_field) def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: - expr = sfn.col(self.parsed_ts_col) + expr = sfn.col(self.parsed_ts_field) return _reverse_or_not(expr, reverse) def comparableExpr(self) -> Column: - return sfn.col(self.parsed_ts_col) + return sfn.col(self.parsed_ts_field) @classmethod def fromParsedTimestamp( @@ -643,7 +643,7 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # cast timestamp to double (fractional seconds since epoch) - expr = sfn.col(self.parsed_ts_col).cast("double") + expr = sfn.col(self.parsed_ts_field).cast("double") return _reverse_or_not(expr, reverse) @@ -659,7 +659,7 @@ def unit(self) -> Optional[TimeUnit]: def rangeExpr(self, reverse: bool = False) -> Column: # convert date to number of days since the epoch expr = sfn.datediff( - sfn.col(self.parsed_ts_col), + sfn.col(self.parsed_ts_field), sfn.lit(EPOCH_START_DATE).cast("date"), ) return _reverse_or_not(expr, reverse) @@ -676,31 +676,31 @@ class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex): def __init__( self, ts_struct: StructField, - double_ts_col: str, - parsed_ts_col: str, - src_str_col: str, + double_ts_field: str, + secondary_parsed_ts_field: str, + src_str_field: str, num_precision_digits: int = 9, ) -> None: """ :param ts_struct: The StructField for the TSIndex column - :param double_ts_col: The name of the double-precision timestamp column - :param parsed_ts_col: The name of the parsed timestamp column - :param src_str_col: The name of the source string column + :param double_ts_field: The name of the double-precision timestamp column + :param secondary_parsed_ts_field: The name of the parsed timestamp column + :param src_str_field: The name of the source string column :param num_precision_digits: The number of digits that make up the precision of the timestamp. Ie. 9 for nanoseconds (default), 12 for picoseconds, etc. You will receive a warning if this value is 6 or less, as this is the precision of the standard timestamp type. """ - super().__init__(ts_struct, double_ts_col) + super().__init__(ts_struct, double_ts_field) # validate the double timestamp column - double_ts_field = self.schema[double_ts_col] - if not isinstance(double_ts_field.dataType, DoubleType): + double_ts_type = self.schema[double_ts_field].dataType + if not isinstance(double_ts_type, DoubleType): raise TypeError( "The double_ts_col must be of DoubleType, " - f"but the given double_ts_col {double_ts_col} " - f"has type {double_ts_field.dataType}" + f"but the given double_ts_col {double_ts_field} " + f"has type {double_ts_type}" ) - self.double_ts_col = double_ts_col + self.double_ts_field = double_ts_field # validate the number of precision digits if num_precision_digits <= 6: warnings.warn( @@ -715,30 +715,30 @@ def __init__( num_precision_digits, ) # validate the parsed column as a timestamp column - parsed_ts_field = self.schema[parsed_ts_col] - if not isinstance(parsed_ts_field.dataType, TimestampType): + parsed_ts_type = self.schema[secondary_parsed_ts_field].dataType + if not isinstance(parsed_ts_type, TimestampType): raise TypeError( "parsed_ts_col field must be of TimestampType, " - f"but the given parsed_ts_col {parsed_ts_col} " - f"has type {parsed_ts_field.dataType}" + f"but the given parsed_ts_col {secondary_parsed_ts_field} " + f"has type {parsed_ts_type}" ) - self.parsed_ts_col = parsed_ts_col + self.parsed_ts_field = secondary_parsed_ts_field # validate the source column as a string column - src_str_field = self.schema[src_str_col] + src_str_field = self.schema[src_str_field] if not isinstance(src_str_field.dataType, StringType): raise TypeError( "src_str_col field must be of StringType, " - f"but the given src_str_col {src_str_col} " + f"but the given src_str_col {src_str_field} " f"has type {src_str_field.dataType}" ) - self.src_str_col = src_str_col + self.src_str_col = src_str_field @property def unit(self) -> Optional[TimeUnit]: return self.__unit def comparableExpr(self) -> Column: - return sfn.col(self.fieldPath(self.double_ts_col)) + return sfn.col(self.fieldPath(self.double_ts_field)) def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: return _reverse_or_not(self.comparableExpr(), reverse) @@ -851,8 +851,10 @@ def __eq__(self, o: object) -> bool: # must be of TSSchema type if not isinstance(o, TSSchema): return False - # must have same TSIndex - if self.ts_idx != o.ts_idx: + # must have TSIndices with the same unit + if not ((self.ts_idx.has_unit == o.ts_idx.has_unit) + and + (self.ts_idx.unit == o.ts_idx.unit)): return False # must have the same series IDs if self.series_ids != o.series_ids: @@ -860,21 +862,57 @@ def __eq__(self, o: object) -> bool: return True def __repr__(self) -> str: - return self.__str__() - - def __str__(self) -> str: - return f"""TSSchema({id(self)}) - TSIndex: {self.ts_idx} - Series IDs: {self.series_ids}""" + return f"{self.__class__.__name__}(ts_idx={self.ts_idx}, series_ids={self.series_ids})" @classmethod - def fromDFSchema( - cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None - ) -> "TSSchema": + def fromDFSchema(cls, + df_schema: StructType, + ts_col: str, + series_ids: Optional[Collection[str]] = None) -> "TSSchema": # construct a TSIndex for the given ts_col ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col]) return cls(ts_idx, series_ids) + @classmethod + def fromParsedTSIndex(cls, + df_schema: StructType, + ts_col: str, + parsed_field: str, + src_str_field: str, + series_ids: Optional[Collection[str]] = None, + secondary_parsed_field: Optional[str] = None) -> "TSSchema": + ts_idx_schema = df_schema[ts_col].dataType + assert isinstance(ts_idx_schema, StructType), \ + f"Expected a StructType for ts_col {ts_col}, but got {ts_idx_schema}" + # construct the TSIndex + parsed_type = ts_idx_schema[parsed_field].dataType + if isinstance(parsed_type, DoubleType): + ts_idx = SubMicrosecondPrecisionTimestampIndex( + df_schema[ts_col], + parsed_field, + secondary_parsed_field, + src_str_field, + ) + elif isinstance(parsed_type, TimestampType): + ts_idx = ParsedTimestampIndex( + df_schema[ts_col], + parsed_field, + src_str_field, + ) + elif isinstance(parsed_type, DateType): + ts_idx = ParsedDateIndex( + df_schema[ts_col], + parsed_field, + src_str_field, + ) + else: + raise TypeError( + f"Expected a DoubleType, TimestampType or DateType for parsed_field {parsed_field}, " + f"but got {parsed_type}" + ) + # construct the TSSchema + return cls(ts_idx, series_ids) + @property def structural_columns(self) -> list[str]: """ @@ -898,7 +936,7 @@ def find_observational_columns(self, df_schema: StructType) -> list[str]: return list(set(df_schema.fieldNames()) - set(self.structural_columns)) @classmethod - def is_metric_col(cls, col: StructField) -> bool: + def __is_metric_col_type(cls, col: StructField) -> bool: return isinstance(col.dataType, NumericType) or isinstance( col.dataType, BooleanType ) @@ -907,7 +945,7 @@ def find_metric_columns(self, df_schema: StructType) -> list[str]: return [ col.name for col in df_schema.fields - if self.is_metric_col(col) + if self.__is_metric_col_type(col) and (col.name in self.find_observational_columns(df_schema)) ] diff --git a/python/tests/base.py b/python/tests/base.py index a0c37d33..e6aaac2c 100644 --- a/python/tests/base.py +++ b/python/tests/base.py @@ -58,10 +58,8 @@ def tearDownClass(cls) -> None: cls.spark.stop() def setUp(self) -> None: - self.test_data = self.__loadTestData(self.id()) - - def tearDown(self) -> None: - del self.test_data + if self.test_data is None: + self.test_data = self.__loadTestData(self.id()) # # Utility Functions @@ -73,16 +71,15 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True): if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])): ts_cols = [td["ts_col"]] if "ts_col" in td else [] ts_cols.extend(td.get("other_ts_cols", [])) - return self.buildTestDF(td["schema"], td["data"], ts_cols) + return self.buildTestDF(td["df"]) def get_data_as_tsdf(self, name: str, convert_ts_col=True): df = self.get_data_as_sdf(name, convert_ts_col) td = self.test_data[name] if "sequence_col" in td: - tsdf = TSDF.fromSubsequenceCol(df, - td["ts_col"], - td["sequence_col"], - td.get("series_ids", None)) + tsdf = TSDF.fromSubsequenceCol( + df, td["ts_col"], td["sequence_col"], td.get("series_ids", None) + ) else: tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None)) return tsdf @@ -112,7 +109,8 @@ def __getTestDataFilePath(self, test_file_name: str) -> str: dir_path = "./tests" elif cwd != "tests": raise RuntimeError( - f"Cannot locate test data file {test_file_name}, running from dir {os.getcwd()}" + f"Cannot locate test data file {test_file_name}, running from dir" + f" {os.getcwd()}" ) # return appropriate path @@ -140,35 +138,27 @@ def __loadTestData(self, test_case_path: str) -> dict: if class_name not in data_metadata_from_json: warnings.warn(f"Could not load test data for {file_name}.{class_name}") return {} - if func_name not in data_metadata_from_json[class_name]: - warnings.warn( - f"Could not load test data for {file_name}.{class_name}.{func_name}" - ) - return {} - return data_metadata_from_json[class_name][func_name] - - def buildTestDF(self, schema, data, ts_cols=["event_ts"]): + # if func_name not in data_metadata_from_json[class_name]: + # warnings.warn( + # f"Could not load test data for {file_name}.{class_name}.{func_name}" + # ) + # return {} + # return data_metadata_from_json[class_name][func_name] + return data_metadata_from_json[class_name] + + def buildTestDF(self, df_spec): """ Constructs a Spark Dataframe from the given components - :param schema: the schema to use for the Dataframe - :param data: values to use for the Dataframe - :param ts_cols: list of column names to be converted to Timestamp values + :param df_spec: a dictionary containing the following keys: schema, data, ts_convert :return: a Spark Dataframe, constructed from the given schema and values """ # build dataframe - df = self.spark.createDataFrame(data, schema) - - # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds - for tsc in ts_cols: - ts_value = str(df.select(ts_cols).limit(1).collect()[0][0]) - ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$" - decimal_pattern = r"[.]\d+" - if re.match(ts_pattern, str(ts_value)) is not None: - if ( - re.search(decimal_pattern, ts_value) is None - or len(re.search(decimal_pattern, ts_value)[0]) <= 4 - ): - df = df.withColumn(tsc, sfn.to_timestamp(sfn.col(tsc))) + df = self.spark.createDataFrame(df_spec['data'], df_spec['schema']) + + # convert timestamp columns + if 'ts_convert' in df_spec: + for ts_col in df_spec['ts_convert']: + df = df.withColumn(ts_col, sfn.to_timestamp(ts_col)) return df # diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py index e040d62a..0a5b0f4a 100644 --- a/python/tests/tsschema_tests.py +++ b/python/tests/tsschema_tests.py @@ -1,7 +1,8 @@ import unittest -from parameterized import parameterized_class +from abc import ABC, abstractmethod +from parameterized import parameterized, parameterized_class -from pyspark.sql import Column +from pyspark.sql import Column, WindowSpec from pyspark.sql import functions as sfn from pyspark.sql.types import ( StructField, @@ -11,6 +12,7 @@ DoubleType, IntegerType, DateType, + NumericType, ) from tempo.tsschema import ( @@ -20,16 +22,39 @@ SimpleDateIndex, StandardTimeUnits, ParsedTimestampIndex, - ParsedDateIndex + ParsedDateIndex, + SubMicrosecondPrecisionTimestampIndex, + TSSchema, ) from tests.base import SparkTest +class TSIndexTests(SparkTest, ABC): + @abstractmethod + def _create_index(self) -> TSIndex: + pass + + def _test_index(self, ts_idx: TSIndex): + # must be a valid TSIndex object + self.assertIsNotNone(ts_idx) + self.assertIsInstance(ts_idx, self.idx_class) + # must have the correct field name and type + self.assertEqual(ts_idx.colname, self.ts_field.name) + self.assertEqual(ts_idx.dataType, self.ts_field.dataType) + # validate the unit + if self.ts_unit is None: + self.assertFalse(ts_idx.has_unit) + else: + self.assertTrue(ts_idx.has_unit) + self.assertEqual(ts_idx.unit, self.ts_unit) + + @parameterized_class( ( "name", "ts_field", "idx_class", + "extra_constr_args", "ts_unit", "expected_comp_expr", "expected_range_expr", @@ -39,6 +64,7 @@ "simple_timestamp_index", StructField("event_ts", TimestampType()), SimpleTimestampIndex, + None, StandardTimeUnits.SECONDS, "Column<'event_ts'>", "Column<'CAST(event_ts AS DOUBLE)'>", @@ -48,6 +74,7 @@ StructField("event_ts_dbl", DoubleType()), OrdinalTSIndex, None, + None, "Column<'event_ts_dbl'>", None, ), @@ -56,6 +83,7 @@ StructField("order", IntegerType()), OrdinalTSIndex, None, + None, "Column<'order'>", None, ), @@ -63,32 +91,81 @@ "simple_date_index", StructField("date", DateType()), SimpleDateIndex, + None, StandardTimeUnits.DAYS, "Column<'date'>", "Column<'datediff(date, CAST(1970-01-01 AS DATE))'>", ), + ( + "parsed_timestamp_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + ParsedTimestampIndex, + {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, + StandardTimeUnits.SECONDS, + "Column<'ts_idx.parsed_ts'>", + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", + ), + ( + "parsed_date_index", + StructField( + "ts_idx", + StructType([ + StructField("parsed_date", DateType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + ParsedDateIndex, + {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, + StandardTimeUnits.DAYS, + "Column<'ts_idx.parsed_date'>", + "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", + ), + ( + "sub_ms_index", + StructField( + "ts_idx", + StructType([ + StructField("double_ts", DoubleType(), True), + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ]), + True, + ), + SubMicrosecondPrecisionTimestampIndex, + { + "double_ts_field": "double_ts", + "secondary_parsed_ts_field": "parsed_ts", + "src_str_field": "src_str", + }, + StandardTimeUnits.NANOSECONDS, + "Column<'ts_idx.double_ts'>", + "Column<'ts_idx.double_ts'>", + ), ], ) -class SimpleTSIndexTests(SparkTest): - def test_constructor(self): +class SimpleTSIndexTests(TSIndexTests): + def _create_index(self) -> TSIndex: + if self.extra_constr_args: + return self.idx_class(self.ts_field, **self.extra_constr_args) + return self.idx_class(self.ts_field) + + def test_index(self): # create a timestamp index - ts_idx = self.idx_class(self.ts_field) - # must be a valid TSIndex object - self.assertIsNotNone(ts_idx) - self.assertIsInstance(ts_idx, self.idx_class) - # must have the correct field name and type - self.assertEqual(ts_idx.colname, self.ts_field.name) - self.assertEqual(ts_idx.dataType, self.ts_field.dataType) - # validate the unit - if self.ts_unit is None: - self.assertFalse(ts_idx.has_unit) - else: - self.assertTrue(ts_idx.has_unit) - self.assertEqual(ts_idx.unit, self.ts_unit) + ts_idx = self._create_index() + # test the index + self._test_index(ts_idx) def test_comparable_expression(self): # create a timestamp index - ts_idx: TSIndex = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions compbl_expr = ts_idx.comparableExpr() # validate the expression @@ -98,7 +175,7 @@ def test_comparable_expression(self): def test_orderby_expression(self): # create a timestamp index - ts_idx: TSIndex = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions orderby_expr = ts_idx.orderByExpr() # validate the expression @@ -108,7 +185,7 @@ def test_orderby_expression(self): def test_range_expression(self): # create a timestamp index - ts_idx = self.idx_class(self.ts_field) + ts_idx = self._create_index() # get the expressions if isinstance(ts_idx, OrdinalTSIndex): self.assertRaises(NotImplementedError, ts_idx.rangeExpr) @@ -128,7 +205,7 @@ def test_range_expression(self): "idx_class", "ts_unit", "expected_comp_expr", - "expected_range_expr" + "expected_range_expr", ), [ ( @@ -141,11 +218,11 @@ def test_range_expression(self): ]), True, ), - {"parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, + {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"}, ParsedTimestampIndex, StandardTimeUnits.SECONDS, "Column<'ts_idx.parsed_ts'>", - "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>", ), ( "parsed_date_index", @@ -157,47 +234,48 @@ def test_range_expression(self): ]), True, ), - {"parsed_ts_col": "parsed_date", "src_str_col": "src_str"}, + {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"}, ParsedDateIndex, StandardTimeUnits.DAYS, "Column<'ts_idx.parsed_date'>", - "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>" + "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>", ), ( "sub_ms_index", StructField( "ts_idx", StructType([ - StructField("double_ts", TimestampType(), True), + StructField("double_ts", DoubleType(), True), StructField("parsed_ts", TimestampType(), True), StructField("src_str", StringType(), True), ]), True, ), - {"double_ts_col": "double_ts", "parsed_ts_col": "parsed_ts", "src_str_col": "src_str"}, - ParsedTimestampIndex, - StandardTimeUnits.SECONDS, - "Column<'ts_idx.parsed_ts'>", - "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>" + { + "double_ts_field": "double_ts", + "secondary_parsed_ts_field": "parsed_ts", + "src_str_field": "src_str", + }, + SubMicrosecondPrecisionTimestampIndex, + StandardTimeUnits.NANOSECONDS, + "Column<'ts_idx.double_ts'>", + "Column<'ts_idx.double_ts'>", ), - ]) -class ParsedTSIndexTests(SparkTest): - def test_constructor(self): + ], +) +class ParsedTSIndexTests(TSIndexTests): + def _create_index(self): + return self.idx_class(ts_struct=self.ts_field, **self.constr_args) + + def test_index(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) - # must be a valid TSIndex object - self.assertIsNotNone(ts_idx) - self.assertIsInstance(ts_idx, self.idx_class) - # must have the correct field name and type - self.assertEqual(ts_idx.colname, self.ts_field.name) - self.assertEqual(ts_idx.dataType, self.ts_field.dataType) - # validate the unit - self.assertTrue(ts_idx.has_unit) - self.assertEqual(ts_idx.unit, self.ts_unit) + ts_idx = self._create_index() + # test the index + self._test_index(ts_idx) def test_comparable_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions compbl_expr = ts_idx.comparableExpr() # validate the expression @@ -207,7 +285,7 @@ def test_comparable_expression(self): def test_orderby_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions orderby_expr = ts_idx.orderByExpr() # validate the expression @@ -217,7 +295,7 @@ def test_orderby_expression(self): def test_range_expression(self): # create a timestamp index - ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args) + ts_idx = self._create_index() # get the expressions range_expr = ts_idx.rangeExpr() # validate the expression @@ -226,10 +304,183 @@ def test_range_expression(self): self.assertEqual(repr(range_expr), self.expected_range_expr) -# class TSSchemaTests(SparkTest): -# def test_simple_tsIndex(self): -# schema_str = "event_ts timestamp, symbol string, trade_pr double" -# schema = _parse_datatype_string(schema_str) -# ts_idx = TSSchema.fromDFSchema(schema, "event_ts", ["symbol"]) -# -# print(ts_idx) +class TSSchemaTests(TSIndexTests, ABC): + @abstractmethod + def _create_ts_schema(self) -> TSSchema: + pass + + +@parameterized_class( + ("name", "df_schema", "ts_col", "series_ids", "idx_class", "ts_unit"), + [ + ( + "simple_timestamp_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("event_ts", TimestampType(), True), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "event_ts", + ["symbol"], + SimpleTimestampIndex, + StandardTimeUnits.SECONDS, + ), + ( + "simple_ts_no_series", + StructType([ + StructField("event_ts", TimestampType(), True), + StructField("trade_pr", DoubleType(), True), + StructField("trade_vol", IntegerType(), True), + ]), + "event_ts", + [], + SimpleTimestampIndex, + StandardTimeUnits.SECONDS, + ), + ( + "ordinal_double_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("event_ts_dbl", DoubleType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "event_ts_dbl", + ["symbol"], + OrdinalTSIndex, + None, + ), + ( + "ordinal_int_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("order", IntegerType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "order", + ["symbol"], + OrdinalTSIndex, + None, + ), + ( + "simple_date_index", + StructType([ + StructField("symbol", StringType(), True), + StructField("date", DateType(), True), + StructField("trade_pr", DoubleType(), True), + ]), + "date", + ["symbol"], + SimpleDateIndex, + StandardTimeUnits.DAYS, + ), + ], +) +class SimpleIndexTSSchemaTests(TSSchemaTests): + def _create_index(self) -> TSIndex: + pass + + def _create_ts_schema(self) -> TSSchema: + return TSSchema.fromDFSchema(self.df_schema, self.ts_col, self.series_ids) + + def setUp(self) -> None: + super().setUp() + self.ts_field = self.df_schema[self.ts_col] + + def test_schema(self): + print(self.id()) + # create a TSSchema + ts_schema = self._create_ts_schema() + # make sure it's a valid TSSchema instance + self.assertIsNotNone(ts_schema) + self.assertIsInstance(ts_schema, TSSchema) + # test the index + self._test_index(ts_schema.ts_idx) + # test the series ids + self.assertEqual(ts_schema.series_ids, self.series_ids) + # validate the index + ts_schema.validate(self.df_schema) + + def test_structural_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the structural columns + struct_cols = [self.ts_col] + self.series_ids + self.assertEqual(set(ts_schema.structural_columns), set(struct_cols)) + + def test_observational_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the structural columns + struct_cols = [self.ts_col] + self.series_ids + obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) + self.assertEqual( + ts_schema.find_observational_columns(self.df_schema), list(obs_cols) + ) + + def test_metric_cols(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the metric columns + struct_cols = [self.ts_col] + self.series_ids + obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols) + metric_cols = { + mc + for mc in obs_cols + if isinstance(self.df_schema[mc].dataType, NumericType) + } + self.assertEqual( + set(ts_schema.find_metric_columns(self.df_schema)), metric_cols + ) + + def test_base_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + bw = ts_schema.baseWindow() + self.assertIsNotNone(bw) + self.assertIsInstance(bw, WindowSpec) + # test it in reverse + bw_rev = ts_schema.baseWindow(reverse=True) + self.assertIsNotNone(bw_rev) + self.assertIsInstance(bw_rev, WindowSpec) + + def test_rows_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + rows_win = ts_schema.rowsBetweenWindow(0, 10) + self.assertIsNotNone(rows_win) + self.assertIsInstance(rows_win, WindowSpec) + # test it in reverse + rows_win_rev = ts_schema.rowsBetweenWindow(0, 10, reverse=True) + self.assertIsNotNone(rows_win_rev) + self.assertIsInstance(rows_win_rev, WindowSpec) + + def test_range_window(self): + # create a TSSchema + ts_schema = self._create_ts_schema() + # test the base window + if ts_schema.ts_idx.has_unit: + range_win = ts_schema.rangeBetweenWindow(0, 10) + self.assertIsNotNone(range_win) + self.assertIsInstance(range_win, WindowSpec) + else: + self.assertRaises(NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10) + # test it in reverse + if ts_schema.ts_idx.has_unit: + range_win_rev = ts_schema.rangeBetweenWindow(0, 10, reverse=True) + self.assertIsNotNone(range_win_rev) + self.assertIsInstance(range_win_rev, WindowSpec) + else: + self.assertRaises( + NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10, reverse=True + ) + + +class ParsedIndexTSSchemaTests(TSSchemaTests): + def _create_index(self) -> TSIndex: + pass + + def _create_ts_schema(self) -> TSSchema: + pass diff --git a/python/tests/unit_test_data/tsschema_tests.json b/python/tests/unit_test_data/tsschema_tests.json index bc820168..3e78def3 100644 --- a/python/tests/unit_test_data/tsschema_tests.json +++ b/python/tests/unit_test_data/tsschema_tests.json @@ -1,26 +1,29 @@ { "__SharedData": { "simple_ts_idx": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_col": "event_ts", - "series_ids": ["symbol"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-09-01 00:20:42", 762.33] - ] + "ts_idx": { + "ts_col": "event_ts", + "series_ids": ["symbol"] + }, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33] + ] + } } }, "TSSchemaTests": { - "test_simple_tsIndex": { - "simple_ts_idx": { - "$ref": "#/__SharedData/simple_ts_idx" - } + "simple_ts_idx": { + "$ref": "#/__SharedData/simple_ts_idx" } } } \ No newline at end of file