diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py index c8e9ce30..6cd6ab54 100644 --- a/python/tempo/tsdf.py +++ b/python/tempo/tsdf.py @@ -110,7 +110,7 @@ def __validate_ts_string(ts_text: str) -> None: @staticmethod def __validated_column(df: DataFrame, colname: str) -> str: - if colname is not str: + if type(colname) != str: raise TypeError( f"Column names must be of type str; found {type(colname)} instead!" ) @@ -122,20 +122,19 @@ def __validated_columns( self, df: DataFrame, colnames: Optional[Union[str, List[str]]] ) -> List[str]: # if provided a string, treat it as a single column - valid_colnames: List[str] = [] - if colnames is str: - valid_colnames = [str(colnames)] + if type(colnames) == str: + colnames = [colnames] # otherwise we really should have a list or None elif colnames is None: - valid_colnames = [] - elif colnames is not list: + colnames = [] + elif type(colnames) != list: raise TypeError( f"Columns must be of type list, str, or None; found {type(colnames)} instead!" ) # validate each column - for col in valid_colnames: + for col in colnames: self.__validated_column(df, col) - return valid_colnames + return colnames def __checkPartitionCols(self, tsdf_right: "TSDF") -> None: for left_col, right_col in zip(self.partitionCols, tsdf_right.partitionCols):