Skip to content

Commit

Permalink
add option to pass in args to mahal dist class
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Oct 12, 2023
1 parent b8808e1 commit 0c0a663
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions src/signature_mahalanobis_knn/mahal_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,31 @@


class Mahalanobis:
"""
After fit is called, becomes callable and intended to be used
as a distance function in sklearn nearest neighbour.
"""

def __init__(self):
self.Vt: np.ndarray = np.empty(
0
) # Truncated right singular matrix transposed of the corpus
self.mu: np.ndarray = np.empty(0) # Mean of the corpus
self.S: np.ndarray = np.empty(0) # Truncated singular values of the corpus
self.subspace_thres: float = (
1e-3 # Threshold to decide whether a point is in the data subspace
)
self.svd_thres: float = (
1e-12 # Threshold to decide numerical rank of the data matrix
)
self.numerical_rank: int = -1 # Numerical rank
def __init__(self, subspace_thres: float = 1e-3, svd_thres: float = 1e-12):
"""
After fit is called, becomes callable and intended to be used
as a distance function in sklearn nearest neighbour.
Parameters
----------
subspace_thres : float, optional
Threshold to decide whether a point is in the data subspace,
by default 1e-3.
svd_thres : float, optional
Threshold to decide numerical rank of the data matrix,
by default 1e-12.
"""
self.subspace_thres: subspace_thres
self.svd_thres: svd_thres

# truncated right singular matrix transposed of the corpus
self.Vt: np.ndarray | None = None
# nean of the corpus
self.mu: np.ndarray | None = None
# truncated singular values of the corpus
self.S: np.ndarray | None = None
# numerical rank of the corpus, -1 means not fitted
self.numerical_rank: int | None = None

def fit(self, X: np.ndarray, **kwargs) -> None:
"""
Expand Down Expand Up @@ -83,5 +90,8 @@ def distance(self, x1: np.ndarray, x2: np.ndarray) -> float:
:return: a value representing distance between x, y
"""
if self.numerical_rank is None:
msg = "Mahalanobis distance is not fitted yet."
raise ValueError(msg)

return self.calc_distance(x1, x2, self.Vt, self.S, self.subspace_thres)

0 comments on commit 0c0a663

Please sign in to comment.