Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to pass in pre-computed signatures #6

Merged
merged 8 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"sktime@git+https://github.com/sz85512678/sktime",
"numpy",
"scikit-learn",
"matplotlib",
"numba",
]

Expand Down
58 changes: 37 additions & 21 deletions src/signature_mahalanobis_knn/mahal_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@


class Mahalanobis:
"""
After fit is called, becomes callable and intended to be used
as a distance function in sklearn nearest neighbour.
"""

def __init__(self):
"""
After fit is called, becomes callable and intended to be used
as a distance function in sklearn nearest neighbour.
"""
self.Vt: np.ndarray = np.empty(
0
) # Truncated right singular matrix transposed of the corpus
Expand All @@ -28,10 +27,10 @@ def fit(self, X: np.ndarray, **kwargs) -> None:
"""
Fit the object to a corpus X.

:param X: numpy array, panel data representing the corpus, each row is a data point
:param y: No use, here for interface consistency

:return: None
Parameters
----------
X : np.ndarray
Panel data representing the corpus, each row is a data point.
"""
# mean centering
self.mu = np.mean(X, axis=0)
Expand All @@ -55,13 +54,23 @@ def calc_distance(
"""
Compute the variance norm between x1 and x2 using the precomputed SVD.

:param x1: 1D array, row vector
:param x2: 1D array, row vector
:param Vt: 2D array, truncated right singular matrix transposed of the corpus
:param S: 1D array, truncated singular values of the corpus
:subspace_thres: float, threshold to decide whether a point is in the data subspace

:return: a value representing distance between x, y
Parameters
----------
x1 : np.ndarray
One-dimensional array.
x2 : np.ndarray
One-dimensional array.
Vt : np.ndarray
Two-dimensional arrat, truncated right singular matrix transposed of the corpus.
S : np.ndarray
One-dimensional array, truncated singular values of the corpus.
subspace_thres : float
Threshold to decide whether a point is in the data subspace.

Returns
-------
float
Value representing distance between x, y.
"""
x = x1 - x2
# quantifies the amount that x is outside the row-subspace
Expand All @@ -76,12 +85,19 @@ def calc_distance(

def distance(self, x1: np.ndarray, x2: np.ndarray) -> float:
"""
Compute the variance norm between x1 and x2.

:param x1: 1D array, row vector
:param x2: 1D array, row vector
Compute the variance norm between x1 and x2 using the precomputed SVD.

:return: a value representing distance between x, y
Parameters
----------
x1 : np.ndarray
One-dimensional array.
x2 : np.ndarray
One-dimensional array.

Returns
-------
float
Value representing distance between x, y.
"""

return self.calc_distance(x1, x2, self.Vt, self.S, self.subspace_thres)
Loading