diff --git a/src/tabmat/formula.py b/src/tabmat/formula.py index 4cea1162..a4471211 100644 --- a/src/tabmat/formula.py +++ b/src/tabmat/formula.py @@ -345,7 +345,7 @@ def to_tabmat( dtype: numpy.dtype = numpy.float64, sparse_threshold: float = 0.1, cat_threshold: int = 4, - ) -> DenseMatrix: + ) -> Union[SparseMatrix, DenseMatrix]: if (self.values != 0).mean() > sparse_threshold: return DenseMatrix(self.values) else: @@ -439,7 +439,7 @@ def to_tabmat( dtype: numpy.dtype = numpy.float64, sparse_threshold: float = 0.1, cat_threshold: int = 4, - ) -> Union[CategoricalMatrix, SplitMatrix]: + ) -> Union[DenseMatrix, CategoricalMatrix, SplitMatrix]: codes = self.codes.copy() categories = self.categories.copy() if -2 in self.codes: