Skip to content

Commit

Permalink
modfiy format
Browse files Browse the repository at this point in the history
  • Loading branch information
lingyiyang committed Oct 16, 2024
1 parent 56d0880 commit 270cef2
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/signature_mahalanobis_knn/sig_mahal_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def conformance(
kth: int | 1, optional
The distance to the kth nearest neighbor to be returned, 1 is nearest
n_neighbors: int | 20, optional
The neighborhood to look for the kth nearest neighbor.
The neighborhood to look for the kth nearest neighbor.
return_indices : bool, optional
Whether to return the indices of the nearest neighbors,
by default False.
Expand Down Expand Up @@ -256,11 +256,11 @@ def conformance(
)

# post-process the candidate distances (in the original space)
# create (n_test, n_neighbours) array with each column as [0, 1, ..., n_test-1]
# create (n_test, n_neighbours) array with each column as [0, 1, ..., n_test-1]
test_indices = np.tile(
np.arange(train_indices.shape[0]), (train_indices.shape[1], 1)
).T

# differences has shape (n_test x n_neighbors x sig_dim)
differences = (
self.signatures_train[train_indices] - signatures_test[test_indices]
Expand All @@ -285,6 +285,9 @@ def conformance(

# compute the kth closest point of the candidate distances for each data point
if return_indices:
return np.partition(candidate_distances, kth-1, axis=-1)[:, kth-1], train_indices
return (
np.partition(candidate_distances, kth - 1, axis=-1)[:, kth - 1],
train_indices,
)

return np.partition(candidate_distances, kth-1, axis=-1)[:, kth-1]
return np.partition(candidate_distances, kth - 1, axis=-1)[:, kth - 1]

0 comments on commit 270cef2

Please sign in to comment.