From 1c18b013a2499fad584328897c948aaa3b1ecfca Mon Sep 17 00:00:00 2001 From: Tristan Nixon Date: Tue, 2 Jan 2024 14:09:41 -0800 Subject: [PATCH] solving some type-check issues --- python/tempo/tsdf.py | 5 +++-- python/tempo/tsschema.py | 31 ++++++++++++++----------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index 68e4a3ab..02358e6c 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -47,6 +47,7 @@ def __init__( if ts_schema: self.ts_schema = ts_schema else: + assert ts_col is not None self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids) # validate that this schema works for this DataFrame self.ts_schema.validate(df.schema) @@ -114,7 +115,7 @@ def fromSubsequenceCol( df: DataFrame, ts_col: str, subsequence_col: str, - series_ids: Collection[str] = None, + series_ids: Optional[Collection[str]] = None, ) -> "TSDF": # construct a struct with the ts_col and subsequence_col struct_col_name = cls.__DEFAULT_TS_IDX_COL @@ -132,7 +133,7 @@ def fromTimestampString( cls, df: DataFrame, ts_col: str, - series_ids: Collection[str] = None, + series_ids: Optional[Collection[str]] = None, ts_fmt: str = "YYYY-MM-DDThh:mm:ss[.SSSSSS]", ) -> "TSDF": pass diff --git a/python/tempo/tsschema.py b/python/tempo/tsschema.py index d04ca0d0..95811f1d 100644 --- a/python/tempo/tsschema.py +++ b/python/tempo/tsschema.py @@ -123,7 +123,7 @@ def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: """ @abstractmethod - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: """ Gets an expression appropriate for performing range operations on the :class:`TSDF` records. @@ -176,7 +176,7 @@ def renamed(self, new_name: str) -> "TSIndex": self.__name = new_name return self - def orderByExpr(self, reverse: bool = False) -> Column: + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: expr = sfn.col(self.colname) return self._reverseOrNot(expr, reverse) @@ -211,7 +211,7 @@ def __init__(self, ts_idx: StructField) -> None: def unit(self) -> Optional[TimeUnits]: return None - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: return self.orderByExpr(reverse) @@ -231,7 +231,7 @@ def __init__(self, ts_idx: StructField) -> None: def unit(self) -> Optional[TimeUnits]: return TimeUnits.SECONDS - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # cast timestamp to double (fractional seconds since epoch) expr = sfn.col(self.colname).cast("double") return self._reverseOrNot(expr, reverse) @@ -253,7 +253,7 @@ def __init__(self, ts_idx: StructField) -> None: def unit(self) -> Optional[TimeUnits]: return TimeUnits.DAYS - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # convert date to number of days since the epoch expr = sfn.datediff(sfn.col(self.colname), sfn.lit("1970-01-01").cast("date")) return self._reverseOrNot(expr, reverse) @@ -350,12 +350,12 @@ def ts_component(self, component_index: int) -> str: """ return self.component(self.ts_components[component_index].colname) - def orderByExpr(self, reverse: bool = False) -> Column: + def orderByExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # build an expression for each TS component, in order exprs = [sfn.col(self.component(comp.colname)) for comp in self.ts_components] return self._reverseOrNot(exprs, reverse) - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: return self.primary_ts_idx.rangeExpr(reverse) @@ -366,7 +366,7 @@ class ParsedTSIndex(CompositeTSIndex, ABC): """ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> None: - super().__init__(ts_idx, primary_ts_col=parsed_col) + super().__init__(ts_idx, parsed_col) src_str_field = self.struct[src_str_col] if not isinstance(src_str_field.dataType, StringType): raise TypeError( @@ -390,9 +390,8 @@ def validate(self, df_schema: StructType) -> None: composite_idx_type: StructType = cast( StructType, df_schema[self.colname].dataType ) - assert ( - self.__src_str_col in composite_idx_type - ), f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}" + assert (self.__src_str_col in composite_idx_type.fieldNames()), \ + f"The src_str_col column {self.src_str_col} does not exist in the composite field {composite_idx_type}" # make sure it's StringType src_str_field_type = composite_idx_type[self.__src_str_col].dataType assert isinstance( @@ -412,7 +411,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No f"ParsedTimestampIndex must be of TimestampType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}" ) - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # cast timestamp to double (fractional seconds since epoch) expr = sfn.col(self.primary_ts_col).cast("double") return self._reverseOrNot(expr, reverse) @@ -430,7 +429,7 @@ def __init__(self, ts_idx: StructField, src_str_col: str, parsed_col: str) -> No f"ParsedDateIndex must be of DateType, but given ts_col {self.primary_ts_idx.colname} has type {self.primary_ts_idx.dataType}" ) - def rangeExpr(self, reverse: bool = False) -> Column: + def rangeExpr(self, reverse: bool = False) -> Union[Column, List[Column]]: # convert date to number of days since the epoch expr = sfn.datediff( sfn.col(self.primary_ts_col), sfn.lit("1970-01-01").cast("date") @@ -522,7 +521,7 @@ class TSSchema(WindowBuilder): Schema type for a :class:`TSDF` class. """ - def __init__(self, ts_idx: TSIndex, series_ids: Collection[str] = None) -> None: + def __init__(self, ts_idx: TSIndex, series_ids: Optional[Collection[str]]) -> None: self.__ts_idx = ts_idx if series_ids: self.__series_ids = list(series_ids) @@ -558,9 +557,7 @@ def __str__(self) -> str: Series IDs: {self.series_ids}""" @classmethod - def fromDFSchema( - cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None - ) -> "TSSchema": + def fromDFSchema(cls, df_schema: StructType, ts_col: str, series_ids: Optional[Collection[str]]) -> "TSSchema": # construct a TSIndex for the given ts_col ts_idx = SimpleTSIndex.fromTSCol(df_schema[ts_col]) return cls(ts_idx, series_ids)