From 2127f3bbad52b5fc687653deb673d1bc7f399a24 Mon Sep 17 00:00:00 2001
From: Tristan Nixon <tristan.nixon@databricks.com>
Date: Mon, 15 Jan 2024 15:39:27 -0800
Subject: [PATCH] checkpoint save - more advanced schema testing

---
 python/tempo/tsschema.py                      | 142 ++++---
 python/tests/base.py                          |  58 ++-
 python/tests/tsschema_tests.py                | 361 +++++++++++++++---
 .../tests/unit_test_data/tsschema_tests.json  |  37 +-
 4 files changed, 440 insertions(+), 158 deletions(-)

diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py
index 48cb1129..165cefe7 100644
--- a/python/tempo/tsschema.py
+++ b/python/tempo/tsschema.py
@@ -545,39 +545,39 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
     """
 
     def __init__(
-        self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str
+        self, ts_struct: StructField, parsed_ts_field: str, src_str_field: str
     ) -> None:
-        super().__init__(ts_struct, parsed_ts_col)
+        super().__init__(ts_struct, parsed_ts_field)
         # validate the source string column
-        src_str_field = self.schema[src_str_col]
-        if not isinstance(src_str_field.dataType, StringType):
+        src_str_type = self.schema[src_str_field].dataType
+        if not isinstance(src_str_type, StringType):
             raise TypeError(
                 "Source string column must be of StringType, "
-                f"but given column {src_str_field.name} "
-                f"is of type {src_str_field.dataType}"
+                f"but given column {src_str_field} "
+                f"is of type {src_str_type}"
             )
-        self._src_str_col = src_str_col
+        self._src_str_field = src_str_field
         # validate the parsed column
-        assert parsed_ts_col in self.schema.fieldNames(), (
-            f"The parsed timestamp index field {parsed_ts_col} does not exist in the "
+        assert parsed_ts_field in self.schema.fieldNames(), (
+            f"The parsed timestamp index field {parsed_ts_field} does not exist in the "
             f"MultiPart TSIndex schema {self.schema}"
         )
-        self._parsed_ts_col = parsed_ts_col
+        self._parsed_ts_field = parsed_ts_field
 
     @property
-    def src_str_col(self):
-        return self.fieldPath(self._src_str_col)
+    def src_str_field(self):
+        return self.fieldPath(self._src_str_field)
 
     @property
-    def parsed_ts_col(self):
-        return self.fieldPath(self._parsed_ts_col)
+    def parsed_ts_field(self):
+        return self.fieldPath(self._parsed_ts_field)
 
     def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
-        expr = sfn.col(self.parsed_ts_col)
+        expr = sfn.col(self.parsed_ts_field)
         return _reverse_or_not(expr, reverse)
 
     def comparableExpr(self) -> Column:
-        return sfn.col(self.parsed_ts_col)
+        return sfn.col(self.parsed_ts_field)
 
     @classmethod
     def fromParsedTimestamp(
@@ -643,7 +643,7 @@ def unit(self) -> Optional[TimeUnit]:
 
     def rangeExpr(self, reverse: bool = False) -> Column:
         # cast timestamp to double (fractional seconds since epoch)
-        expr = sfn.col(self.parsed_ts_col).cast("double")
+        expr = sfn.col(self.parsed_ts_field).cast("double")
         return _reverse_or_not(expr, reverse)
 
 
@@ -659,7 +659,7 @@ def unit(self) -> Optional[TimeUnit]:
     def rangeExpr(self, reverse: bool = False) -> Column:
         # convert date to number of days since the epoch
         expr = sfn.datediff(
-            sfn.col(self.parsed_ts_col),
+            sfn.col(self.parsed_ts_field),
             sfn.lit(EPOCH_START_DATE).cast("date"),
         )
         return _reverse_or_not(expr, reverse)
@@ -676,31 +676,31 @@ class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex):
     def __init__(
         self,
         ts_struct: StructField,
-        double_ts_col: str,
-        parsed_ts_col: str,
-        src_str_col: str,
+        double_ts_field: str,
+        secondary_parsed_ts_field: str,
+        src_str_field: str,
         num_precision_digits: int = 9,
     ) -> None:
         """
         :param ts_struct: The StructField for the TSIndex column
-        :param double_ts_col: The name of the double-precision timestamp column
-        :param parsed_ts_col: The name of the parsed timestamp column
-        :param src_str_col: The name of the source string column
+        :param double_ts_field: The name of the double-precision timestamp column
+        :param secondary_parsed_ts_field: The name of the parsed timestamp column
+        :param src_str_field: The name of the source string column
         :param num_precision_digits: The number of digits that make up the precision of
         the timestamp. Ie. 9 for nanoseconds (default), 12 for picoseconds, etc.
         You will receive a warning if this value is 6 or less, as this is the precision
         of the standard timestamp type.
         """
-        super().__init__(ts_struct, double_ts_col)
+        super().__init__(ts_struct, double_ts_field)
         # validate the double timestamp column
-        double_ts_field = self.schema[double_ts_col]
-        if not isinstance(double_ts_field.dataType, DoubleType):
+        double_ts_type = self.schema[double_ts_field].dataType
+        if not isinstance(double_ts_type, DoubleType):
             raise TypeError(
                 "The double_ts_col must be of DoubleType, "
-                f"but the given double_ts_col {double_ts_col} "
-                f"has type {double_ts_field.dataType}"
+                f"but the given double_ts_col {double_ts_field} "
+                f"has type {double_ts_type}"
             )
-        self.double_ts_col = double_ts_col
+        self.double_ts_field = double_ts_field
         # validate the number of precision digits
         if num_precision_digits <= 6:
             warnings.warn(
@@ -715,30 +715,30 @@ def __init__(
             num_precision_digits,
         )
         # validate the parsed column as a timestamp column
-        parsed_ts_field = self.schema[parsed_ts_col]
-        if not isinstance(parsed_ts_field.dataType, TimestampType):
+        parsed_ts_type = self.schema[secondary_parsed_ts_field].dataType
+        if not isinstance(parsed_ts_type, TimestampType):
             raise TypeError(
                 "parsed_ts_col field must be of TimestampType, "
-                f"but the given parsed_ts_col {parsed_ts_col} "
-                f"has type {parsed_ts_field.dataType}"
+                f"but the given parsed_ts_col {secondary_parsed_ts_field} "
+                f"has type {parsed_ts_type}"
             )
-        self.parsed_ts_col = parsed_ts_col
+        self.parsed_ts_field = secondary_parsed_ts_field
         # validate the source column as a string column
-        src_str_field = self.schema[src_str_col]
-        if not isinstance(src_str_field.dataType, StringType):
+        src_str_type = self.schema[src_str_field].dataType
+        if not isinstance(src_str_type, StringType):
             raise TypeError(
                 "src_str_col field must be of StringType, "
-                f"but the given src_str_col {src_str_col} "
-                f"has type {src_str_field.dataType}"
+                f"but the given src_str_col {src_str_field} "
+                f"has type {src_str_type}"
             )
-        self.src_str_col = src_str_col
+        self.src_str_field = src_str_field
 
     @property
     def unit(self) -> Optional[TimeUnit]:
         return self.__unit
 
     def comparableExpr(self) -> Column:
-        return sfn.col(self.fieldPath(self.double_ts_col))
+        return sfn.col(self.fieldPath(self.double_ts_field))
 
     def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
         return _reverse_or_not(self.comparableExpr(), reverse)
@@ -851,8 +851,10 @@ def __eq__(self, o: object) -> bool:
         # must be of TSSchema type
         if not isinstance(o, TSSchema):
             return False
-        # must have same TSIndex
-        if self.ts_idx != o.ts_idx:
+        # must have TSIndices with the same unit
+        if not ((self.ts_idx.has_unit == o.ts_idx.has_unit)
+                and
+                (self.ts_idx.unit == o.ts_idx.unit)):
             return False
         # must have the same series IDs
         if self.series_ids != o.series_ids:
@@ -860,21 +862,57 @@ def __eq__(self, o: object) -> bool:
         return True
 
     def __repr__(self) -> str:
-        return self.__str__()
-
-    def __str__(self) -> str:
-        return f"""TSSchema({id(self)})
-    TSIndex: {self.ts_idx}
-    Series IDs: {self.series_ids}"""
+        return f"{self.__class__.__name__}(ts_idx={self.ts_idx}, series_ids={self.series_ids})"
 
     @classmethod
-    def fromDFSchema(
-        cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
-    ) -> "TSSchema":
+    def fromDFSchema(cls,
+                     df_schema: StructType,
+                     ts_col: str,
+                     series_ids: Optional[Collection[str]] = None) -> "TSSchema":
         # construct a TSIndex for the given ts_col
         ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col])
         return cls(ts_idx, series_ids)
 
+    @classmethod
+    def fromParsedTSIndex(cls,
+                          df_schema: StructType,
+                          ts_col: str,
+                          parsed_field: str,
+                          src_str_field: str,
+                          series_ids: Optional[Collection[str]] = None,
+                          secondary_parsed_field: Optional[str] = None) -> "TSSchema":
+        ts_idx_schema = df_schema[ts_col].dataType
+        assert isinstance(ts_idx_schema, StructType), \
+            f"Expected a StructType for ts_col {ts_col}, but got {ts_idx_schema}"
+        # construct the TSIndex
+        parsed_type = ts_idx_schema[parsed_field].dataType
+        if isinstance(parsed_type, DoubleType):
+            ts_idx = SubMicrosecondPrecisionTimestampIndex(
+                df_schema[ts_col],
+                parsed_field,
+                secondary_parsed_field,
+                src_str_field,
+            )
+        elif isinstance(parsed_type, TimestampType):
+            ts_idx = ParsedTimestampIndex(
+                df_schema[ts_col],
+                parsed_field,
+                src_str_field,
+            )
+        elif isinstance(parsed_type, DateType):
+            ts_idx = ParsedDateIndex(
+                df_schema[ts_col],
+                parsed_field,
+                src_str_field,
+            )
+        else:
+            raise TypeError(
+                f"Expected a DoubleType, TimestampType or DateType for parsed_field {parsed_field}, "
+                f"but got {parsed_type}"
+            )
+        # construct the TSSchema
+        return cls(ts_idx, series_ids)
+
     @property
     def structural_columns(self) -> list[str]:
         """
@@ -898,7 +936,7 @@ def find_observational_columns(self, df_schema: StructType) -> list[str]:
         return list(set(df_schema.fieldNames()) - set(self.structural_columns))
 
     @classmethod
-    def is_metric_col(cls, col: StructField) -> bool:
+    def __is_metric_col_type(cls, col: StructField) -> bool:
         return isinstance(col.dataType, NumericType) or isinstance(
             col.dataType, BooleanType
         )
@@ -907,7 +945,7 @@ def find_metric_columns(self, df_schema: StructType) -> list[str]:
         return [
             col.name
             for col in df_schema.fields
-            if self.is_metric_col(col)
+            if self.__is_metric_col_type(col)
             and (col.name in self.find_observational_columns(df_schema))
         ]
 
diff --git a/python/tests/base.py b/python/tests/base.py
index a0c37d33..e6aaac2c 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -58,10 +58,8 @@ def tearDownClass(cls) -> None:
         cls.spark.stop()
 
     def setUp(self) -> None:
-        self.test_data = self.__loadTestData(self.id())
-
-    def tearDown(self) -> None:
-        del self.test_data
+        if getattr(self, "test_data", None) is None:
+            self.test_data = self.__loadTestData(self.id())
 
     #
     # Utility Functions
@@ -73,16 +71,15 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
         if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])):
             ts_cols = [td["ts_col"]] if "ts_col" in td else []
             ts_cols.extend(td.get("other_ts_cols", []))
-        return self.buildTestDF(td["schema"], td["data"], ts_cols)
+        return self.buildTestDF(td["df"])
 
     def get_data_as_tsdf(self, name: str, convert_ts_col=True):
         df = self.get_data_as_sdf(name, convert_ts_col)
         td = self.test_data[name]
         if "sequence_col" in td:
-            tsdf = TSDF.fromSubsequenceCol(df,
-                                           td["ts_col"],
-                                           td["sequence_col"],
-                                           td.get("series_ids", None))
+            tsdf = TSDF.fromSubsequenceCol(
+                df, td["ts_col"], td["sequence_col"], td.get("series_ids", None)
+            )
         else:
             tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
         return tsdf
@@ -112,7 +109,8 @@ def __getTestDataFilePath(self, test_file_name: str) -> str:
             dir_path = "./tests"
         elif cwd != "tests":
             raise RuntimeError(
-                f"Cannot locate test data file {test_file_name}, running from dir {os.getcwd()}"
+                f"Cannot locate test data file {test_file_name}, running from dir"
+                f" {os.getcwd()}"
             )
 
         # return appropriate path
@@ -140,35 +138,27 @@ def __loadTestData(self, test_case_path: str) -> dict:
             if class_name not in data_metadata_from_json:
                 warnings.warn(f"Could not load test data for {file_name}.{class_name}")
                 return {}
-            if func_name not in data_metadata_from_json[class_name]:
-                warnings.warn(
-                    f"Could not load test data for {file_name}.{class_name}.{func_name}"
-                )
-                return {}
-            return data_metadata_from_json[class_name][func_name]
-
-    def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
+            # NOTE: test data is now loaded once per test class rather than
+            # per test function -- the JSON is keyed by class name only, and
+            # every test method in the class shares the same data dictionary
+            # (cached in setUp). A missing per-test key now surfaces as a
+            # KeyError in the test that requests it, instead of the old
+            # per-function warning-and-empty-dict behaviour.
+            return data_metadata_from_json[class_name]
+
+    def buildTestDF(self, df_spec):
         """
         Constructs a Spark Dataframe from the given components
-        :param schema: the schema to use for the Dataframe
-        :param data: values to use for the Dataframe
-        :param ts_cols: list of column names to be converted to Timestamp values
+        :param df_spec: a dictionary containing the following keys: schema, data, ts_convert
         :return: a Spark Dataframe, constructed from the given schema and values
         """
         # build dataframe
-        df = self.spark.createDataFrame(data, schema)
-
-        # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
-        for tsc in ts_cols:
-            ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
-            ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
-            decimal_pattern = r"[.]\d+"
-            if re.match(ts_pattern, str(ts_value)) is not None:
-                if (
-                    re.search(decimal_pattern, ts_value) is None
-                    or len(re.search(decimal_pattern, ts_value)[0]) <= 4
-                ):
-                    df = df.withColumn(tsc, sfn.to_timestamp(sfn.col(tsc)))
+        df = self.spark.createDataFrame(df_spec['data'], df_spec['schema'])
+
+        # convert timestamp columns
+        if 'ts_convert' in df_spec:
+            for ts_col in df_spec['ts_convert']:
+                df = df.withColumn(ts_col, sfn.to_timestamp(ts_col))
         return df
 
     #
diff --git a/python/tests/tsschema_tests.py b/python/tests/tsschema_tests.py
index e040d62a..0a5b0f4a 100644
--- a/python/tests/tsschema_tests.py
+++ b/python/tests/tsschema_tests.py
@@ -1,7 +1,8 @@
 import unittest
-from parameterized import parameterized_class
+from abc import ABC, abstractmethod
+from parameterized import parameterized, parameterized_class
 
-from pyspark.sql import Column
+from pyspark.sql import Column, WindowSpec
 from pyspark.sql import functions as sfn
 from pyspark.sql.types import (
     StructField,
@@ -11,6 +12,7 @@
     DoubleType,
     IntegerType,
     DateType,
+    NumericType,
 )
 
 from tempo.tsschema import (
@@ -20,16 +22,39 @@
     SimpleDateIndex,
     StandardTimeUnits,
     ParsedTimestampIndex,
-    ParsedDateIndex
+    ParsedDateIndex,
+    SubMicrosecondPrecisionTimestampIndex,
+    TSSchema,
 )
 from tests.base import SparkTest
 
 
+class TSIndexTests(SparkTest, ABC):
+    @abstractmethod
+    def _create_index(self) -> TSIndex:
+        pass
+
+    def _test_index(self, ts_idx: TSIndex):
+        # must be a valid TSIndex object
+        self.assertIsNotNone(ts_idx)
+        self.assertIsInstance(ts_idx, self.idx_class)
+        # must have the correct field name and type
+        self.assertEqual(ts_idx.colname, self.ts_field.name)
+        self.assertEqual(ts_idx.dataType, self.ts_field.dataType)
+        # validate the unit
+        if self.ts_unit is None:
+            self.assertFalse(ts_idx.has_unit)
+        else:
+            self.assertTrue(ts_idx.has_unit)
+            self.assertEqual(ts_idx.unit, self.ts_unit)
+
+
 @parameterized_class(
     (
         "name",
         "ts_field",
         "idx_class",
+        "extra_constr_args",
         "ts_unit",
         "expected_comp_expr",
         "expected_range_expr",
@@ -39,6 +64,7 @@
             "simple_timestamp_index",
             StructField("event_ts", TimestampType()),
             SimpleTimestampIndex,
+            None,
             StandardTimeUnits.SECONDS,
             "Column<'event_ts'>",
             "Column<'CAST(event_ts AS DOUBLE)'>",
@@ -48,6 +74,7 @@
             StructField("event_ts_dbl", DoubleType()),
             OrdinalTSIndex,
             None,
+            None,
             "Column<'event_ts_dbl'>",
             None,
         ),
@@ -56,6 +83,7 @@
             StructField("order", IntegerType()),
             OrdinalTSIndex,
             None,
+            None,
             "Column<'order'>",
             None,
         ),
@@ -63,32 +91,81 @@
             "simple_date_index",
             StructField("date", DateType()),
             SimpleDateIndex,
+            None,
             StandardTimeUnits.DAYS,
             "Column<'date'>",
             "Column<'datediff(date, CAST(1970-01-01 AS DATE))'>",
         ),
+        (
+            "parsed_timestamp_index",
+            StructField(
+                "ts_idx",
+                StructType([
+                    StructField("parsed_ts", TimestampType(), True),
+                    StructField("src_str", StringType(), True),
+                ]),
+                True,
+            ),
+            ParsedTimestampIndex,
+            {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"},
+            StandardTimeUnits.SECONDS,
+            "Column<'ts_idx.parsed_ts'>",
+            "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>",
+        ),
+        (
+            "parsed_date_index",
+            StructField(
+                "ts_idx",
+                StructType([
+                    StructField("parsed_date", DateType(), True),
+                    StructField("src_str", StringType(), True),
+                ]),
+                True,
+            ),
+            ParsedDateIndex,
+            {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"},
+            StandardTimeUnits.DAYS,
+            "Column<'ts_idx.parsed_date'>",
+            "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>",
+        ),
+        (
+            "sub_ms_index",
+            StructField(
+                "ts_idx",
+                StructType([
+                    StructField("double_ts", DoubleType(), True),
+                    StructField("parsed_ts", TimestampType(), True),
+                    StructField("src_str", StringType(), True),
+                ]),
+                True,
+            ),
+            SubMicrosecondPrecisionTimestampIndex,
+            {
+                "double_ts_field": "double_ts",
+                "secondary_parsed_ts_field": "parsed_ts",
+                "src_str_field": "src_str",
+            },
+            StandardTimeUnits.NANOSECONDS,
+            "Column<'ts_idx.double_ts'>",
+            "Column<'ts_idx.double_ts'>",
+        ),
     ],
 )
-class SimpleTSIndexTests(SparkTest):
-    def test_constructor(self):
+class SimpleTSIndexTests(TSIndexTests):
+    def _create_index(self) -> TSIndex:
+        if self.extra_constr_args:
+            return self.idx_class(self.ts_field, **self.extra_constr_args)
+        return self.idx_class(self.ts_field)
+
+    def test_index(self):
         # create a timestamp index
-        ts_idx = self.idx_class(self.ts_field)
-        # must be a valid TSIndex object
-        self.assertIsNotNone(ts_idx)
-        self.assertIsInstance(ts_idx, self.idx_class)
-        # must have the correct field name and type
-        self.assertEqual(ts_idx.colname, self.ts_field.name)
-        self.assertEqual(ts_idx.dataType, self.ts_field.dataType)
-        # validate the unit
-        if self.ts_unit is None:
-            self.assertFalse(ts_idx.has_unit)
-        else:
-            self.assertTrue(ts_idx.has_unit)
-            self.assertEqual(ts_idx.unit, self.ts_unit)
+        ts_idx = self._create_index()
+        # test the index
+        self._test_index(ts_idx)
 
     def test_comparable_expression(self):
         # create a timestamp index
-        ts_idx: TSIndex = self.idx_class(self.ts_field)
+        ts_idx = self._create_index()
         # get the expressions
         compbl_expr = ts_idx.comparableExpr()
         # validate the expression
@@ -98,7 +175,7 @@ def test_comparable_expression(self):
 
     def test_orderby_expression(self):
         # create a timestamp index
-        ts_idx: TSIndex = self.idx_class(self.ts_field)
+        ts_idx = self._create_index()
         # get the expressions
         orderby_expr = ts_idx.orderByExpr()
         # validate the expression
@@ -108,7 +185,7 @@ def test_orderby_expression(self):
 
     def test_range_expression(self):
         # create a timestamp index
-        ts_idx = self.idx_class(self.ts_field)
+        ts_idx = self._create_index()
         # get the expressions
         if isinstance(ts_idx, OrdinalTSIndex):
             self.assertRaises(NotImplementedError, ts_idx.rangeExpr)
@@ -128,7 +205,7 @@ def test_range_expression(self):
         "idx_class",
         "ts_unit",
         "expected_comp_expr",
-        "expected_range_expr"
+        "expected_range_expr",
     ),
     [
         (
@@ -141,11 +218,11 @@ def test_range_expression(self):
                 ]),
                 True,
             ),
-            {"parsed_ts_col": "parsed_ts", "src_str_col": "src_str"},
+            {"parsed_ts_field": "parsed_ts", "src_str_field": "src_str"},
             ParsedTimestampIndex,
             StandardTimeUnits.SECONDS,
             "Column<'ts_idx.parsed_ts'>",
-            "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>"
+            "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>",
         ),
         (
             "parsed_date_index",
@@ -157,47 +234,48 @@ def test_range_expression(self):
                 ]),
                 True,
             ),
-            {"parsed_ts_col": "parsed_date", "src_str_col": "src_str"},
+            {"parsed_ts_field": "parsed_date", "src_str_field": "src_str"},
             ParsedDateIndex,
             StandardTimeUnits.DAYS,
             "Column<'ts_idx.parsed_date'>",
-            "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>"
+            "Column<'datediff(ts_idx.parsed_date, CAST(1970-01-01 AS DATE))'>",
         ),
         (
             "sub_ms_index",
             StructField(
                 "ts_idx",
                 StructType([
-                    StructField("double_ts", TimestampType(), True),
+                    StructField("double_ts", DoubleType(), True),
                     StructField("parsed_ts", TimestampType(), True),
                     StructField("src_str", StringType(), True),
                 ]),
                 True,
             ),
-            {"double_ts_col": "double_ts", "parsed_ts_col": "parsed_ts", "src_str_col": "src_str"},
-            ParsedTimestampIndex,
-            StandardTimeUnits.SECONDS,
-            "Column<'ts_idx.parsed_ts'>",
-            "Column<'CAST(ts_idx.parsed_ts AS DOUBLE)'>"
+            {
+                "double_ts_field": "double_ts",
+                "secondary_parsed_ts_field": "parsed_ts",
+                "src_str_field": "src_str",
+            },
+            SubMicrosecondPrecisionTimestampIndex,
+            StandardTimeUnits.NANOSECONDS,
+            "Column<'ts_idx.double_ts'>",
+            "Column<'ts_idx.double_ts'>",
         ),
-    ])
-class ParsedTSIndexTests(SparkTest):
-    def test_constructor(self):
+    ],
+)
+class ParsedTSIndexTests(TSIndexTests):
+    def _create_index(self):
+        return self.idx_class(ts_struct=self.ts_field, **self.constr_args)
+
+    def test_index(self):
         # create a timestamp index
-        ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args)
-        # must be a valid TSIndex object
-        self.assertIsNotNone(ts_idx)
-        self.assertIsInstance(ts_idx, self.idx_class)
-        # must have the correct field name and type
-        self.assertEqual(ts_idx.colname, self.ts_field.name)
-        self.assertEqual(ts_idx.dataType, self.ts_field.dataType)
-        # validate the unit
-        self.assertTrue(ts_idx.has_unit)
-        self.assertEqual(ts_idx.unit, self.ts_unit)
+        ts_idx = self._create_index()
+        # test the index
+        self._test_index(ts_idx)
 
     def test_comparable_expression(self):
         # create a timestamp index
-        ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args)
+        ts_idx = self._create_index()
         # get the expressions
         compbl_expr = ts_idx.comparableExpr()
         # validate the expression
@@ -207,7 +285,7 @@ def test_comparable_expression(self):
 
     def test_orderby_expression(self):
         # create a timestamp index
-        ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args)
+        ts_idx = self._create_index()
         # get the expressions
         orderby_expr = ts_idx.orderByExpr()
         # validate the expression
@@ -217,7 +295,7 @@ def test_orderby_expression(self):
 
     def test_range_expression(self):
         # create a timestamp index
-        ts_idx = self.idx_class(ts_struct=self.ts_field, **self.constr_args)
+        ts_idx = self._create_index()
         # get the expressions
         range_expr = ts_idx.rangeExpr()
         # validate the expression
@@ -226,10 +304,183 @@ def test_range_expression(self):
         self.assertEqual(repr(range_expr), self.expected_range_expr)
 
 
-# class TSSchemaTests(SparkTest):
-#     def test_simple_tsIndex(self):
-#         schema_str = "event_ts timestamp, symbol string, trade_pr double"
-#         schema = _parse_datatype_string(schema_str)
-#         ts_idx = TSSchema.fromDFSchema(schema, "event_ts", ["symbol"])
-#
-#         print(ts_idx)
+class TSSchemaTests(TSIndexTests, ABC):
+    @abstractmethod
+    def _create_ts_schema(self) -> TSSchema:
+        pass
+
+
+@parameterized_class(
+    ("name", "df_schema", "ts_col", "series_ids", "idx_class", "ts_unit"),
+    [
+        (
+            "simple_timestamp_index",
+            StructType([
+                StructField("symbol", StringType(), True),
+                StructField("event_ts", TimestampType(), True),
+                StructField("trade_pr", DoubleType(), True),
+                StructField("trade_vol", IntegerType(), True),
+            ]),
+            "event_ts",
+            ["symbol"],
+            SimpleTimestampIndex,
+            StandardTimeUnits.SECONDS,
+        ),
+        (
+            "simple_ts_no_series",
+            StructType([
+                StructField("event_ts", TimestampType(), True),
+                StructField("trade_pr", DoubleType(), True),
+                StructField("trade_vol", IntegerType(), True),
+            ]),
+            "event_ts",
+            [],
+            SimpleTimestampIndex,
+            StandardTimeUnits.SECONDS,
+        ),
+        (
+            "ordinal_double_index",
+            StructType([
+                StructField("symbol", StringType(), True),
+                StructField("event_ts_dbl", DoubleType(), True),
+                StructField("trade_pr", DoubleType(), True),
+            ]),
+            "event_ts_dbl",
+            ["symbol"],
+            OrdinalTSIndex,
+            None,
+        ),
+        (
+            "ordinal_int_index",
+            StructType([
+                StructField("symbol", StringType(), True),
+                StructField("order", IntegerType(), True),
+                StructField("trade_pr", DoubleType(), True),
+            ]),
+            "order",
+            ["symbol"],
+            OrdinalTSIndex,
+            None,
+        ),
+        (
+            "simple_date_index",
+            StructType([
+                StructField("symbol", StringType(), True),
+                StructField("date", DateType(), True),
+                StructField("trade_pr", DoubleType(), True),
+            ]),
+            "date",
+            ["symbol"],
+            SimpleDateIndex,
+            StandardTimeUnits.DAYS,
+        ),
+    ],
+)
+class SimpleIndexTSSchemaTests(TSSchemaTests):
+    def _create_index(self) -> TSIndex:
+        pass
+
+    def _create_ts_schema(self) -> TSSchema:
+        return TSSchema.fromDFSchema(self.df_schema, self.ts_col, self.series_ids)
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.ts_field = self.df_schema[self.ts_col]
+
+    def test_schema(self):
+        """Validate TSSchema construction, its index, and its series ids."""
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # make sure it's a valid TSSchema instance
+        self.assertIsNotNone(ts_schema)
+        self.assertIsInstance(ts_schema, TSSchema)
+        # test the index
+        self._test_index(ts_schema.ts_idx)
+        # test the series ids
+        self.assertEqual(ts_schema.series_ids, self.series_ids)
+        # validate the index
+        ts_schema.validate(self.df_schema)
+
+    def test_structural_cols(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the structural columns
+        struct_cols = [self.ts_col] + self.series_ids
+        self.assertEqual(set(ts_schema.structural_columns), set(struct_cols))
+
+    def test_observational_cols(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the structural columns
+        struct_cols = [self.ts_col] + self.series_ids
+        obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols)
+        self.assertEqual(
+            set(ts_schema.find_observational_columns(self.df_schema)), obs_cols
+        )
+
+    def test_metric_cols(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the metric columns
+        struct_cols = [self.ts_col] + self.series_ids
+        obs_cols = set(self.df_schema.fieldNames()) - set(struct_cols)
+        metric_cols = {
+            mc
+            for mc in obs_cols
+            if isinstance(self.df_schema[mc].dataType, NumericType)
+        }
+        self.assertEqual(
+            set(ts_schema.find_metric_columns(self.df_schema)), metric_cols
+        )
+
+    def test_base_window(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the base window
+        bw = ts_schema.baseWindow()
+        self.assertIsNotNone(bw)
+        self.assertIsInstance(bw, WindowSpec)
+        # test it in reverse
+        bw_rev = ts_schema.baseWindow(reverse=True)
+        self.assertIsNotNone(bw_rev)
+        self.assertIsInstance(bw_rev, WindowSpec)
+
+    def test_rows_window(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the base window
+        rows_win = ts_schema.rowsBetweenWindow(0, 10)
+        self.assertIsNotNone(rows_win)
+        self.assertIsInstance(rows_win, WindowSpec)
+        # test it in reverse
+        rows_win_rev = ts_schema.rowsBetweenWindow(0, 10, reverse=True)
+        self.assertIsNotNone(rows_win_rev)
+        self.assertIsInstance(rows_win_rev, WindowSpec)
+
+    def test_range_window(self):
+        # create a TSSchema
+        ts_schema = self._create_ts_schema()
+        # test the base window
+        if ts_schema.ts_idx.has_unit:
+            range_win = ts_schema.rangeBetweenWindow(0, 10)
+            self.assertIsNotNone(range_win)
+            self.assertIsInstance(range_win, WindowSpec)
+        else:
+            self.assertRaises(NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10)
+        # test it in reverse
+        if ts_schema.ts_idx.has_unit:
+            range_win_rev = ts_schema.rangeBetweenWindow(0, 10, reverse=True)
+            self.assertIsNotNone(range_win_rev)
+            self.assertIsInstance(range_win_rev, WindowSpec)
+        else:
+            self.assertRaises(
+                NotImplementedError, ts_schema.rangeBetweenWindow, 0, 10, reverse=True
+            )
+
+
+class ParsedIndexTSSchemaTests(TSSchemaTests):
+    def _create_index(self) -> TSIndex:
+        pass
+
+    def _create_ts_schema(self) -> TSSchema:
+        pass
diff --git a/python/tests/unit_test_data/tsschema_tests.json b/python/tests/unit_test_data/tsschema_tests.json
index bc820168..3e78def3 100644
--- a/python/tests/unit_test_data/tsschema_tests.json
+++ b/python/tests/unit_test_data/tsschema_tests.json
@@ -1,26 +1,29 @@
 {
   "__SharedData": {
     "simple_ts_idx": {
-      "schema": "symbol string, event_ts string, trade_pr float",
-      "ts_col": "event_ts",
-      "series_ids": ["symbol"],
-      "data": [
-        ["S1", "2020-08-01 00:00:10", 349.21],
-        ["S1", "2020-08-01 00:01:12", 351.32],
-        ["S1", "2020-09-01 00:02:10", 361.1],
-        ["S1", "2020-09-01 00:19:12", 362.1],
-        ["S2", "2020-08-01 00:01:10", 743.01],
-        ["S2", "2020-08-01 00:01:24", 751.92],
-        ["S2", "2020-09-01 00:02:10", 761.10],
-        ["S2", "2020-09-01 00:20:42", 762.33]
-      ]
+      "ts_idx": {
+        "ts_col": "event_ts",
+        "series_ids": ["symbol"]
+      },
+      "df": {
+        "schema": "symbol string, event_ts string, trade_pr float",
+        "ts_convert": ["event_ts"],
+        "data": [
+          ["S1", "2020-08-01 00:00:10", 349.21],
+          ["S1", "2020-08-01 00:01:12", 351.32],
+          ["S1", "2020-09-01 00:02:10", 361.1],
+          ["S1", "2020-09-01 00:19:12", 362.1],
+          ["S2", "2020-08-01 00:01:10", 743.01],
+          ["S2", "2020-08-01 00:01:24", 751.92],
+          ["S2", "2020-09-01 00:02:10", 761.10],
+          ["S2", "2020-09-01 00:20:42", 762.33]
+        ]
+      }
     }
   },
   "TSSchemaTests": {
-    "test_simple_tsIndex": {
-      "simple_ts_idx": {
-        "$ref": "#/__SharedData/simple_ts_idx"
-      }
+    "simple_ts_idx": {
+      "$ref": "#/__SharedData/simple_ts_idx"
     }
   }
 }
\ No newline at end of file