Skip to content

Commit

Permalink
Merge branch 'main' into pre-comp-signature-passing
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Oct 12, 2023
2 parents 1470dca + c4c2c68 commit 1d2ad4c
Showing 1 changed file with 45 additions and 16 deletions.
61 changes: 45 additions & 16 deletions src/signature_mahalanobis_knn/mahal_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,41 @@


class Mahalanobis:
def __init__(self):
def __init__(
self,
subspace_thres: float = 1e-3,
svd_thres: float = 1e-12,
zero_thres: float = 1e-15,
):
"""
After fit is called, the instance becomes callable and is intended
to be used as a distance function in sklearn nearest neighbour.
Parameters
----------
subspace_thres : float, optional
Threshold to decide whether a point is in the data subspace,
by default 1e-3.
svd_thres : float, optional
Threshold to decide numerical rank of the data matrix,
by default 1e-12.
zero_thres : float, optional
Threshold to decide whether the distance is zero,
by default 1e-15.
"""
self.Vt: np.ndarray = np.empty(
0
) # Truncated right singular matrix transposed of the corpus
self.mu: np.ndarray = np.empty(0) # Mean of the corpus
self.S: np.ndarray = np.empty(0) # Truncated singular values of the corpus
self.subspace_thres: float = (
1e-3 # Threshold to decide whether a point is in the data subspace
)
self.svd_thres: float = (
1e-12 # Threshold to decide numerical rank of the data matrix
)
self.numerical_rank: int = -1 # Numerical rank
self.default_dtype = np.float64
self.subspace_thres: float = subspace_thres
self.svd_thres: float = svd_thres
self.zero_thres: float = zero_thres

# set the following after fit() is called - None means not fitted yet
# truncated right singular matrix transposed of the corpus
self.Vt: np.ndarray | None = None
# mean of the corpus
self.mu: np.ndarray | None = None
# truncated singular values of the corpus
self.S: np.ndarray | None = None
# numerical rank of the corpus
self.numerical_rank: int | None = None

def fit(self, X: np.ndarray, y: None = None, **kwargs) -> None: # noqa: ARG002
"""
Expand Down Expand Up @@ -53,6 +70,7 @@ def calc_distance(
Vt: np.ndarray,
S: np.ndarray,
subspace_thres: float,
zero_thres: float,
) -> float:
"""
Compute the variance norm between x1 and x2 using the precomputed SVD.
Expand All @@ -77,7 +95,7 @@ def calc_distance(
"""
x = x1 - x2
norm_x = np.linalg.norm(x)
if norm_x < 1e-15:
if norm_x < zero_thres:
return 0.0

# quantifies the amount that x is outside the row-subspace
Expand All @@ -103,8 +121,19 @@ def distance(self, x1: np.ndarray, x2: np.ndarray) -> float:
float
Value representing the distance between x1 and x2.
"""
if self.numerical_rank is None:
msg = "Mahalanobis distance is not fitted yet."
raise ValueError(msg)

# ensure inputs are the right data type
x1 = x1.astype(self.default_dtype)
x2 = x2.astype(self.default_dtype)

return self.calc_distance(x1, x2, self.Vt, self.S, self.subspace_thres)
return self.calc_distance(
x1=x1,
x2=x2,
Vt=self.Vt,
S=self.S,
subspace_thres=self.subspace_thres,
zero_thres=self.zero_thres,
)

0 comments on commit 1d2ad4c

Please sign in to comment.