
Commit

checkpoint save - more advanced schema testing
tnixon committed Jan 15, 2024
1 parent d488d7c commit 2127f3b
Showing 4 changed files with 440 additions and 158 deletions.
142 changes: 90 additions & 52 deletions python/tempo/tsschema.py
@@ -545,39 +545,39 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
     """
 
     def __init__(
-        self, ts_struct: StructField, parsed_ts_col: str, src_str_col: str
+        self, ts_struct: StructField, parsed_ts_field: str, src_str_field: str
     ) -> None:
-        super().__init__(ts_struct, parsed_ts_col)
+        super().__init__(ts_struct, parsed_ts_field)
         # validate the source string column
-        src_str_field = self.schema[src_str_col]
-        if not isinstance(src_str_field.dataType, StringType):
+        src_str_type = self.schema[src_str_field].dataType
+        if not isinstance(src_str_type, StringType):
             raise TypeError(
                 "Source string column must be of StringType, "
-                f"but given column {src_str_field.name} "
-                f"is of type {src_str_field.dataType}"
+                f"but given column {src_str_field} "
+                f"is of type {src_str_type}"
             )
-        self._src_str_col = src_str_col
+        self._src_str_field = src_str_field
         # validate the parsed column
-        assert parsed_ts_col in self.schema.fieldNames(), (
-            f"The parsed timestamp index field {parsed_ts_col} does not exist in the "
+        assert parsed_ts_field in self.schema.fieldNames(), (
+            f"The parsed timestamp index field {parsed_ts_field} does not exist in the "
             f"MultiPart TSIndex schema {self.schema}"
         )
-        self._parsed_ts_col = parsed_ts_col
+        self._parsed_ts_field = parsed_ts_field
 
     @property
-    def src_str_col(self):
-        return self.fieldPath(self._src_str_col)
+    def src_str_field(self):
+        return self.fieldPath(self._src_str_field)
 
     @property
-    def parsed_ts_col(self):
-        return self.fieldPath(self._parsed_ts_col)
+    def parsed_ts_field(self):
+        return self.fieldPath(self._parsed_ts_field)
 
     def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
-        expr = sfn.col(self.parsed_ts_col)
+        expr = sfn.col(self.parsed_ts_field)
         return _reverse_or_not(expr, reverse)
 
     def comparableExpr(self) -> Column:
-        return sfn.col(self.parsed_ts_col)
+        return sfn.col(self.parsed_ts_field)
 
     @classmethod
     def fromParsedTimestamp(
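For illustration, a minimal sketch of how the renamed accessors would be used. The field names here are hypothetical, and it assumes fieldPath resolves a field name to its dotted path inside the struct column:

    from pyspark.sql.types import StructType, StructField, TimestampType, StringType

    # a struct ts column "ts_idx" holding the parsed timestamp and its source string
    ts_struct = StructField("ts_idx", StructType([
        StructField("parsed_ts", TimestampType()),
        StructField("src_str", StringType()),
    ]))
    idx = ParsedTimestampIndex(ts_struct, parsed_ts_field="parsed_ts", src_str_field="src_str")
    # idx.parsed_ts_field then resolves to something like "ts_idx.parsed_ts",
    # which orderByExpr() and comparableExpr() wrap in sfn.col(...)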
@@ -643,7 +643,7 @@ def unit(self) -> Optional[TimeUnit]:
 
     def rangeExpr(self, reverse: bool = False) -> Column:
         # cast timestamp to double (fractional seconds since epoch)
-        expr = sfn.col(self.parsed_ts_col).cast("double")
+        expr = sfn.col(self.parsed_ts_field).cast("double")
         return _reverse_or_not(expr, reverse)
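A quick, self-contained check of the cast this range expression relies on: Spark's double cast of a timestamp yields fractional seconds since the Unix epoch. This sketch assumes an active SparkSession named spark; the exact value depends on the session time zone:

    from pyspark.sql import functions as sfn

    df = spark.createDataFrame([("2024-01-15 00:00:00.5",)], ["ts"]) \
        .withColumn("ts", sfn.to_timestamp("ts"))
    df.select(sfn.col("ts").cast("double").alias("epoch_secs")).show()
    # epoch_secs = 1705276800.5 when the session time zone is UTC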


@@ -659,7 +659,7 @@ def unit(self) -> Optional[TimeUnit]:
     def rangeExpr(self, reverse: bool = False) -> Column:
         # convert date to number of days since the epoch
         expr = sfn.datediff(
-            sfn.col(self.parsed_ts_col),
+            sfn.col(self.parsed_ts_field),
             sfn.lit(EPOCH_START_DATE).cast("date"),
         )
         return _reverse_or_not(expr, reverse)
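The date variant counts whole days instead. A standalone sketch of the same computation, assuming EPOCH_START_DATE is "1970-01-01" as the constant name suggests:

    df = spark.createDataFrame([("2024-01-15",)], ["d"]) \
        .withColumn("d", sfn.col("d").cast("date"))
    df.select(sfn.datediff(sfn.col("d"), sfn.lit("1970-01-01").cast("date")).alias("days")).show()
    # days = 19737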
@@ -676,31 +676,31 @@ class SubMicrosecondPrecisionTimestampIndex(CompositeTSIndex):
     def __init__(
         self,
         ts_struct: StructField,
-        double_ts_col: str,
-        parsed_ts_col: str,
-        src_str_col: str,
+        double_ts_field: str,
+        secondary_parsed_ts_field: str,
+        src_str_field: str,
         num_precision_digits: int = 9,
     ) -> None:
         """
         :param ts_struct: The StructField for the TSIndex column
-        :param double_ts_col: The name of the double-precision timestamp column
-        :param parsed_ts_col: The name of the parsed timestamp column
-        :param src_str_col: The name of the source string column
+        :param double_ts_field: The name of the double-precision timestamp column
+        :param secondary_parsed_ts_field: The name of the parsed timestamp column
+        :param src_str_field: The name of the source string column
         :param num_precision_digits: The number of digits that make up the precision of
             the timestamp. Ie. 9 for nanoseconds (default), 12 for picoseconds, etc.
             You will receive a warning if this value is 6 or less, as this is the precision
             of the standard timestamp type.
         """
-        super().__init__(ts_struct, double_ts_col)
+        super().__init__(ts_struct, double_ts_field)
         # validate the double timestamp column
-        double_ts_field = self.schema[double_ts_col]
-        if not isinstance(double_ts_field.dataType, DoubleType):
+        double_ts_type = self.schema[double_ts_field].dataType
+        if not isinstance(double_ts_type, DoubleType):
             raise TypeError(
                 "The double_ts_col must be of DoubleType, "
-                f"but the given double_ts_col {double_ts_col} "
-                f"has type {double_ts_field.dataType}"
+                f"but the given double_ts_col {double_ts_field} "
+                f"has type {double_ts_type}"
             )
-        self.double_ts_col = double_ts_col
+        self.double_ts_field = double_ts_field
         # validate the number of precision digits
         if num_precision_digits <= 6:
             warnings.warn(
@@ -715,30 +715,30 @@ def __init__(
                 num_precision_digits,
             )
         # validate the parsed column as a timestamp column
-        parsed_ts_field = self.schema[parsed_ts_col]
-        if not isinstance(parsed_ts_field.dataType, TimestampType):
+        parsed_ts_type = self.schema[secondary_parsed_ts_field].dataType
+        if not isinstance(parsed_ts_type, TimestampType):
             raise TypeError(
                 "parsed_ts_col field must be of TimestampType, "
-                f"but the given parsed_ts_col {parsed_ts_col} "
-                f"has type {parsed_ts_field.dataType}"
+                f"but the given parsed_ts_col {secondary_parsed_ts_field} "
+                f"has type {parsed_ts_type}"
             )
-        self.parsed_ts_col = parsed_ts_col
+        self.parsed_ts_field = secondary_parsed_ts_field
         # validate the source column as a string column
-        src_str_field = self.schema[src_str_col]
+        src_str_field = self.schema[src_str_field]
         if not isinstance(src_str_field.dataType, StringType):
             raise TypeError(
                 "src_str_col field must be of StringType, "
-                f"but the given src_str_col {src_str_col} "
+                f"but the given src_str_col {src_str_field} "
                 f"has type {src_str_field.dataType}"
             )
-        self.src_str_col = src_str_col
+        self.src_str_col = src_str_field
 
     @property
     def unit(self) -> Optional[TimeUnit]:
         return self.__unit
 
     def comparableExpr(self) -> Column:
-        return sfn.col(self.fieldPath(self.double_ts_col))
+        return sfn.col(self.fieldPath(self.double_ts_field))
 
     def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
         return _reverse_or_not(self.comparableExpr(), reverse)
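A hedged construction sketch for this index, matching the signature above: the struct column carries a high-precision double, a microsecond-truncated timestamp, and the original source string (field names are illustrative):

    from pyspark.sql.types import (
        StructType, StructField, DoubleType, TimestampType, StringType,
    )

    ts_struct = StructField("ts_idx", StructType([
        StructField("double_ts", DoubleType()),     # nanosecond-precision epoch seconds
        StructField("parsed_ts", TimestampType()),  # truncated to microsecond precision
        StructField("src_str", StringType()),       # original timestamp string
    ]))
    idx = SubMicrosecondPrecisionTimestampIndex(
        ts_struct,
        double_ts_field="double_ts",
        secondary_parsed_ts_field="parsed_ts",
        src_str_field="src_str",
        num_precision_digits=9,  # nanoseconds; values <= 6 trigger the warning above
    )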
@@ -851,30 +851,68 @@ def __eq__(self, o: object) -> bool:
         # must be of TSSchema type
         if not isinstance(o, TSSchema):
             return False
-        # must have same TSIndex
-        if self.ts_idx != o.ts_idx:
+        # must have TSIndices with the same unit
+        if not ((self.ts_idx.has_unit == o.ts_idx.has_unit)
+                and
+                (self.ts_idx.unit == o.ts_idx.unit)):
             return False
         # must have the same series IDs
         if self.series_ids != o.series_ids:
             return False
         return True
 
     def __repr__(self) -> str:
         return self.__str__()
 
     def __str__(self) -> str:
-        return f"""TSSchema({id(self)})
-    TSIndex: {self.ts_idx}
-    Series IDs: {self.series_ids}"""
+        return f"{self.__class__.__name__}(ts_idx={self.ts_idx}, series_ids={self.series_ids})"
 
     @classmethod
-    def fromDFSchema(
-        cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
-    ) -> "TSSchema":
+    def fromDFSchema(cls,
+                     df_schema: StructType,
+                     ts_col: str,
+                     series_ids: Optional[Collection[str]] = None) -> "TSSchema":
         # construct a TSIndex for the given ts_col
         ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col])
         return cls(ts_idx, series_ids)
 
+    @classmethod
+    def fromParsedTSIndex(cls,
+                          df_schema: StructType,
+                          ts_col: str,
+                          parsed_field: str,
+                          src_str_field: str,
+                          series_ids: Optional[Collection[str]] = None,
+                          secondary_parsed_field: Optional[str] = None) -> "TSSchema":
+        ts_idx_schema = df_schema[ts_col].dataType
+        assert isinstance(ts_idx_schema, StructType), \
+            f"Expected a StructType for ts_col {ts_col}, but got {ts_idx_schema}"
+        # construct the TSIndex
+        parsed_type = ts_idx_schema[parsed_field].dataType
+        if isinstance(parsed_type, DoubleType):
+            ts_idx = SubMicrosecondPrecisionTimestampIndex(
+                df_schema[ts_col],
+                parsed_field,
+                secondary_parsed_field,
+                src_str_field,
+            )
+        elif isinstance(parsed_type, TimestampType):
+            ts_idx = ParsedTimestampIndex(
+                df_schema[ts_col],
+                parsed_field,
+                src_str_field,
+            )
+        elif isinstance(parsed_type, DateType):
+            ts_idx = ParsedDateIndex(
+                df_schema[ts_col],
+                parsed_field,
+                src_str_field,
+            )
+        else:
+            raise TypeError(
+                f"Expected a DoubleType, TimestampType or DateType for parsed_field {parsed_field}, "
+                f"but got {parsed_type}"
+            )
+        # construct the TSSchema
+        return cls(ts_idx, series_ids)
+
     @property
     def structural_columns(self) -> list[str]:
         """
@@ -898,7 +936,7 @@ def find_observational_columns(self, df_schema: StructType) -> list[str]:
         return list(set(df_schema.fieldNames()) - set(self.structural_columns))
 
     @classmethod
-    def is_metric_col(cls, col: StructField) -> bool:
+    def __is_metric_col_type(cls, col: StructField) -> bool:
         return isinstance(col.dataType, NumericType) or isinstance(
             col.dataType, BooleanType
         )
@@ -907,7 +945,7 @@ def find_metric_columns(self, df_schema: StructType) -> list[str]:
         return [
             col.name
             for col in df_schema.fields
-            if self.is_metric_col(col)
+            if self.__is_metric_col_type(col)
             and (col.name in self.find_observational_columns(df_schema))
         ]

58 changes: 24 additions & 34 deletions python/tests/base.py
@@ -58,10 +58,8 @@ def tearDownClass(cls) -> None:
         cls.spark.stop()
 
     def setUp(self) -> None:
-        self.test_data = self.__loadTestData(self.id())
-
-    def tearDown(self) -> None:
-        del self.test_data
+        if self.test_data is None:
+            self.test_data = self.__loadTestData(self.id())
 
     #
     # Utility Functions
@@ -73,16 +71,15 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
         if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])):
             ts_cols = [td["ts_col"]] if "ts_col" in td else []
             ts_cols.extend(td.get("other_ts_cols", []))
-        return self.buildTestDF(td["schema"], td["data"], ts_cols)
+        return self.buildTestDF(td["df"])
 
     def get_data_as_tsdf(self, name: str, convert_ts_col=True):
         df = self.get_data_as_sdf(name, convert_ts_col)
         td = self.test_data[name]
         if "sequence_col" in td:
-            tsdf = TSDF.fromSubsequenceCol(df,
-                                           td["ts_col"],
-                                           td["sequence_col"],
-                                           td.get("series_ids", None))
+            tsdf = TSDF.fromSubsequenceCol(
+                df, td["ts_col"], td["sequence_col"], td.get("series_ids", None)
+            )
         else:
             tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
         return tsdf
@@ -112,7 +109,8 @@ def __getTestDataFilePath(self, test_file_name: str) -> str:
             dir_path = "./tests"
         elif cwd != "tests":
             raise RuntimeError(
-                f"Cannot locate test data file {test_file_name}, running from dir {os.getcwd()}"
+                f"Cannot locate test data file {test_file_name}, running from dir"
+                f" {os.getcwd()}"
             )
 
         # return appropriate path
@@ -140,35 +138,27 @@ def __loadTestData(self, test_case_path: str) -> dict:
         if class_name not in data_metadata_from_json:
             warnings.warn(f"Could not load test data for {file_name}.{class_name}")
             return {}
-        if func_name not in data_metadata_from_json[class_name]:
-            warnings.warn(
-                f"Could not load test data for {file_name}.{class_name}.{func_name}"
-            )
-            return {}
-        return data_metadata_from_json[class_name][func_name]
-
-    def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
+        # if func_name not in data_metadata_from_json[class_name]:
+        #     warnings.warn(
+        #         f"Could not load test data for {file_name}.{class_name}.{func_name}"
+        #     )
+        #     return {}
+        # return data_metadata_from_json[class_name][func_name]
+        return data_metadata_from_json[class_name]
+
+    def buildTestDF(self, df_spec):
         """
         Constructs a Spark Dataframe from the given components
-        :param schema: the schema to use for the Dataframe
-        :param data: values to use for the Dataframe
-        :param ts_cols: list of column names to be converted to Timestamp values
+        :param df_spec: a dictionary containing the following keys: schema, data, ts_convert
        :return: a Spark Dataframe, constructed from the given schema and values
         """
         # build dataframe
-        df = self.spark.createDataFrame(data, schema)
-
-        # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
-        for tsc in ts_cols:
-            ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
-            ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
-            decimal_pattern = r"[.]\d+"
-            if re.match(ts_pattern, str(ts_value)) is not None:
-                if (
-                    re.search(decimal_pattern, ts_value) is None
-                    or len(re.search(decimal_pattern, ts_value)[0]) <= 4
-                ):
-                    df = df.withColumn(tsc, sfn.to_timestamp(sfn.col(tsc)))
+        df = self.spark.createDataFrame(df_spec['data'], df_spec['schema'])
+
+        # convert timestamp columns
+        if 'ts_convert' in df_spec:
+            for ts_col in df_spec['ts_convert']:
+                df = df.withColumn(ts_col, sfn.to_timestamp(ts_col))
         return df
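For reference, a sketch of the test-data entry shape the rewritten loader and builder appear to expect (key names inferred from the code above; values are hypothetical):

    td = {
        "ts_col": "event_ts",
        "series_ids": ["symbol"],
        "df": {
            "schema": "symbol string, event_ts string, price double",
            "data": [["AAPL", "2024-01-15 09:30:00", 185.0]],
            "ts_convert": ["event_ts"],  # columns passed through sfn.to_timestamp
        },
    }
    # get_data_as_sdf(name) now calls self.buildTestDF(td["df"]), which builds
    # the DataFrame from "data" and "schema", then to_timestamp-converts each
    # column listed under "ts_convert".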