From 0c0a66386d33dd4a83471ff6ccd7201b00be640d Mon Sep 17 00:00:00 2001
From: rchan
Date: Thu, 12 Oct 2023 11:43:37 +0100
Subject: [PATCH] add option to pass in args to mahal dist class

---
 .../mahal_distance.py | 46 +++++++++++--------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/src/signature_mahalanobis_knn/mahal_distance.py b/src/signature_mahalanobis_knn/mahal_distance.py
index 4791667..7e89ecc 100644
--- a/src/signature_mahalanobis_knn/mahal_distance.py
+++ b/src/signature_mahalanobis_knn/mahal_distance.py
@@ -5,24 +5,31 @@
 
 
 class Mahalanobis:
-    """
-    After fit is called, becomes callable and intended to be used
-    as a distance function in sklearn nearest neighbour.
-    """
-
-    def __init__(self):
-        self.Vt: np.ndarray = np.empty(
-            0
-        )  # Truncated right singular matrix transposed of the corpus
-        self.mu: np.ndarray = np.empty(0)  # Mean of the corpus
-        self.S: np.ndarray = np.empty(0)  # Truncated singular values of the corpus
-        self.subspace_thres: float = (
-            1e-3  # Threshold to decide whether a point is in the data subspace
-        )
-        self.svd_thres: float = (
-            1e-12  # Threshold to decide numerical rank of the data matrix
-        )
-        self.numerical_rank: int = -1  # Numerical rank
+    def __init__(self, subspace_thres: float = 1e-3, svd_thres: float = 1e-12):
+        """
+        After fit is called, becomes callable and intended to be used
+        as a distance function in sklearn nearest neighbour.
+
+        Parameters
+        ----------
+        subspace_thres : float, optional
+            Threshold to decide whether a point is in the data subspace,
+            by default 1e-3.
+        svd_thres : float, optional
+            Threshold to decide numerical rank of the data matrix,
+            by default 1e-12.
+        """
+        self.subspace_thres: float = subspace_thres
+        self.svd_thres: float = svd_thres
+
+        # truncated right singular matrix transposed of the corpus
+        self.Vt: np.ndarray | None = None
+        # mean of the corpus
+        self.mu: np.ndarray | None = None
+        # truncated singular values of the corpus
+        self.S: np.ndarray | None = None
+        # numerical rank of the corpus, None means not fitted
+        self.numerical_rank: int | None = None
 
     def fit(self, X: np.ndarray, **kwargs) -> None:
         """
@@ -83,5 +90,8 @@ def distance(self, x1: np.ndarray, x2: np.ndarray) -> float:
 
         :return: a value representing distance between x, y
         """
+        if self.numerical_rank is None:
+            msg = "Mahalanobis distance is not fitted yet."
+            raise ValueError(msg)
         return self.calc_distance(x1, x2, self.Vt, self.S, self.subspace_thres)
 