Skip to content

Commit

Permalink
solving some type-check issues
Browse files: browse the repository at this point in the history
  • Loading branch information
tnixon committed Jan 2, 2024
1 parent 88bb03b commit 1c18b01
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
5 changes: 3 additions & 2 deletions python/tempo/tsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
if ts_schema:
self.ts_schema = ts_schema
else:
assert ts_col is not None
self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids)
# validate that this schema works for this DataFrame
self.ts_schema.validate(df.schema)
Expand Down Expand Up @@ -114,7 +115,7 @@ def fromSubsequenceCol(
df: DataFrame,
ts_col: str,
subsequence_col: str,
series_ids: Collection[str] = None,
series_ids: Optional[Collection[str]] = None,
) -> "TSDF":
# construct a struct with the ts_col and subsequence_col
struct_col_name = cls.__DEFAULT_TS_IDX_COL
Expand All @@ -132,7 +133,7 @@ def fromTimestampString(
cls,
df: DataFrame,
ts_col: str,
series_ids: Collection[str] = None,
series_ids: Optional[Collection[str]] = None,
ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]",
) -> "TSDF":
pass
Expand Down
31 changes: 14 additions & 17 deletions python/tempo/tsschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
"""

@abstractmethod
def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
"""
Gets an expression appropriate for performing range operations on the :class:`TSDF` records.
Expand Down Expand Up @@ -176,7 +176,7 @@ def renamed(self, new_name: str) -> "TSIndex":
self.__name = new_name
return self

def orderByExpr(self, reverse: bool = False) -> Column:
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
expr = sfn.col(self.colname)
return self._reverseOrNot(expr, reverse)

Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(self, ts_idx: StructField) -> None:
def unit(self) -> Optional[TimeUnits]:
return None

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
return self.orderByExpr(reverse)


Expand All @@ -231,7 +231,7 @@ def __init__(self, ts_idx: StructField) -> None:
def unit(self) -> Optional[TimeUnits]:
return TimeUnits.SECONDS

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
# cast timestamp to double (fractional seconds since epoch)
expr = sfn.col(self.colname).cast("double")
return self._reverseOrNot(expr, reverse)
Expand All @@ -253,7 +253,7 @@ def __init__(self, ts_idx: StructField) -> None:
def unit(self) -> Optional[TimeUnits]:
return TimeUnits.DAYS

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
# convert date to number of days since the epoch
expr = sfn.datediff(sfn.col(self.colname), sfn.lit("1970-01-01").cast("date"))
return self._reverseOrNot(expr, reverse)
Expand Down Expand Up @@ -350,12 +350,12 @@ def ts_component(self, component_index: int) -> str:
"""
return self.component(self.ts_components[component_index].colname)

def orderByExpr(self, reverse: bool = False) -> Column:
def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
# build an expression for each TS component, in order
exprs = [sfn.col(self.component(comp.colname)) for comp in self.ts_components]
return self._reverseOrNot(exprs, reverse)

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
return self.primary_ts_idx.rangeExpr(reverse)


Expand All @@ -366,7 +366,7 @@ class ParsedTSIndex(CompositeTSIndex, ABC):
"""

def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None:
super().__init__(ts_idx, primary_ts_col=parsed_col)
super().__init__(ts_idx, parsed_col)
src_str_field = self.struct[src_str_col]
if not isinstance(src_str_field.dataType, StringType):
raise TypeError(
Expand All @@ -390,9 +390,8 @@ def validate(self, df_schema: StructType) -> None:
composite_idx_type: StructType = cast(
StructType, df_schema[self.colname].dataType
)
assert (
self.__src_str_col in composite_idx_type
), f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
assert (self.__src_str_col in composite_idx_type.fieldNames()), \
f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}"
# make sure it's StringType
src_str_field_type = composite_idx_type[self.__src_str_col].dataType
assert isinstance(
Expand All @@ -412,7 +411,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No
f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
)

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
# cast timestamp to double (fractional seconds since epoch)
expr = sfn.col(self.primary_ts_col).cast("double")
return self._reverseOrNot(expr, reverse)
Expand All @@ -430,7 +429,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No
f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}"
)

def rangeExpr(self, reverse: bool = False) -> Column:
def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]:
# convert date to number of days since the epoch
expr = sfn.datediff(
sfn.col(self.primary_ts_col), sfn.lit("1970-01-01").cast("date")
Expand Down Expand Up @@ -522,7 +521,7 @@ class TSSchema(WindowBuilder):
Schema type for a :class:`TSDF` class.
"""

def __init__(self, ts_idx: TSIndex, series_ids: Collection[str] = None) -> None:
def __init__(self, ts_idx: TSIndex, series_ids: Optional[Collection[str]]) -> None:
self.__ts_idx = ts_idx
if series_ids:
self.__series_ids = list(series_ids)
Expand Down Expand Up @@ -558,9 +557,7 @@ def __str__(self) -> str:
Series IDs: {self.series_ids}"""

@classmethod
def fromDFSchema(
cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
) -> "TSSchema":
def fromDFSchema(cls, df_schema: StructType, ts_col: str, series_ids: Optional[Collection[str]]) -> "TSSchema":
# construct a TSIndex for the given ts_col
ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col])
return cls(ts_idx, series_ids)
Expand Down

0 comments on commit 1c18b01

Please sign in to comment.