Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Added ROCKAD anomaly detector to aeon #2376

Merged
merged 20 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
303f1e2
Added ROCKAD anomaly detector to aeon
pattplatt Nov 21, 2024
bb43951
Added ROCKAD to anomaly_detection.rst
pattplatt Nov 21, 2024
6384627
Empty commit for CI
MatthewMiddlehurst Dec 2, 2024
1382a58
Merge remote-tracking branch 'upstream/main' into feature/rockad
pattplatt Dec 2, 2024
a6b867b
Merge branch 'feature/rockad' of https://github.com/pattplatt/aeon_ts…
pattplatt Dec 2, 2024
48724d6
Automatic `pre-commit` fixes
pattplatt Dec 2, 2024
ae7bf5b
Fix newline at end of file
pattplatt Dec 2, 2024
9853a7b
Merge branch 'feature/rockad' of https://github.com/pattplatt/aeon_ts…
pattplatt Dec 2, 2024
db89e0f
adopted code to fit refactored Rocket arguments
pattplatt Dec 2, 2024
a9bec44
added "capability:multithreading": True to _tags
pattplatt Dec 4, 2024
d60d9c2
Catch power transform and disable it if it results in error
pattplatt Dec 6, 2024
dcacaf3
Added fallback when power transform fails, added ValueError if number…
pattplatt Dec 9, 2024
48304c7
Added private attribute for power_transform activation/deactivation
pattplatt Dec 9, 2024
5b9906b
added tests for kneighbors check, adapted univariate and multivariate…
pattplatt Dec 10, 2024
a93d47e
Removed pandas, set rocket normalise default to False, use transposed…
pattplatt Dec 10, 2024
9d85e5f
Added test for power transform failure
pattplatt Dec 12, 2024
a953da5
removed transpose, added back user warning when power transform fails…
pattplatt Dec 12, 2024
78a2c9d
removed noop
pattplatt Dec 13, 2024
18d62d1
removed inf_columns_ check
pattplatt Dec 19, 2024
7cda1df
moved parent class init and n_jobs to have consistent structure
pattplatt Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aeon/anomaly_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"LOF",
"MERLIN",
"OneClassSVM",
"ROCKAD",
"PyODAdapter",
"STOMP",
"STRAY",
Expand All @@ -25,5 +26,6 @@
from aeon.anomaly_detection._merlin import MERLIN
from aeon.anomaly_detection._one_class_svm import OneClassSVM
from aeon.anomaly_detection._pyodadapter import PyODAdapter
from aeon.anomaly_detection._rockad import ROCKAD
from aeon.anomaly_detection._stomp import STOMP
from aeon.anomaly_detection._stray import STRAY
262 changes: 262 additions & 0 deletions aeon/anomaly_detection/_rockad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
"""ROCKAD anomaly detector."""

__all__ = ["ROCKAD"]

import warnings
from typing import Optional

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import PowerTransformer
from sklearn.utils import resample

from aeon.anomaly_detection.base import BaseAnomalyDetector
from aeon.transformations.collection.convolution_based import Rocket
from aeon.utils.windowing import reverse_windowing, sliding_windows


class ROCKAD(BaseAnomalyDetector):
"""
ROCKET-based Anomaly Detector (ROCKAD).

ROCKAD leverages the ROCKET transformation for feature extraction from
time series data and applies the scikit-learn k-nearest neighbors (k-NN)
approach with bootstrap aggregation for robust anomaly detection.
After windowing, the data gets transformed into the ROCKET feature space.
Then the windows are compared based on the feature space by
finding the nearest neighbours.

This class supports both univariate and multivariate time series and
provides options for normalizing features, applying power transformations,
and customizing the distance metric.

Parameters
----------
n_estimators : int, default=10
    Number of k-NN estimators to use in the bootstrap aggregation.
n_kernels : int, default=100
    Number of kernels to use in the ROCKET transformation.
normalise : bool, default=False
    Whether to normalize the ROCKET-transformed features.
n_neighbors : int, default=5
    Number of neighbors to use for the k-NN algorithm.
metric : str, default="euclidean"
    Distance metric to use for the k-NN algorithm.
power_transform : bool, default=True
    Whether to apply a power transformation (Yeo-Johnson) to the features.
    If fitting the power transformation fails, a ``UserWarning`` is issued
    and the untransformed features are used instead.
window_size : int, default=10
    Size of the sliding window for segmenting input time series data.
    Must be at least 1 and at most the length of the time series.
stride : int, default=1
    Step size for moving the sliding window over the time series data.
    Must be at least 1 and at most ``window_size``.
n_jobs : int, default=1
    Number of parallel jobs to use for the k-NN algorithm and ROCKET
    transformation.
random_state : int, default=42
    Random seed for reproducibility.

Attributes
----------
rocket_transformer_ : Optional[Rocket]
    Instance of the ROCKET transformer used to extract features, set after
    fitting (``None`` before).
list_baggers_ : Optional[list[NearestNeighbors]]
    List containing k-NN estimators used for anomaly scoring, set after
    fitting (``None`` before).
power_transformer_ : Optional[PowerTransformer]
    Transformer used to apply power transformation to the features.
    ``None`` before fitting, and also after fitting when fitting the power
    transformation failed.

Notes
-----
Fitting raises a ``ValueError`` when ``window_size`` or ``stride`` are out
of range, or when the number of windows is smaller than ``n_neighbors``.
The returned point-wise anomaly scores are min-max scaled per call.
"""

# Estimator capability tags read by the base class machinery.
_tags = {
"capability:univariate": True,
"capability:multivariate": True,
"capability:missing_values": False,
"capability:multithreading": True,
"fit_is_empty": False,
}
baraline marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
    self,
    n_estimators=10,
    n_kernels=100,
    normalise=False,
    n_neighbors=5,
    metric="euclidean",
    power_transform=True,
    window_size: int = 10,
    stride: int = 1,
    n_jobs=1,
    random_state=42,
):
    # Sliding-window segmentation of the input series.
    self.window_size = window_size
    self.stride = stride

    # ROCKET feature extraction and optional power transformation.
    self.n_kernels = n_kernels
    self.normalise = normalise
    self.power_transform = power_transform

    # Bagged k-NN scoring.
    self.n_estimators = n_estimators
    self.n_neighbors = n_neighbors
    self.metric = metric

    # Settings shared by the ROCKET transformer and the k-NN estimators.
    self.n_jobs = n_jobs
    self.random_state = random_state

    # Fitted state: stays None until ``_inner_fit`` populates it.
    self.power_transformer_: Optional[PowerTransformer] = None
    self.list_baggers_: Optional[list[NearestNeighbors]] = None
    self.rocket_transformer_: Optional[Rocket] = None

    # The base class is initialised last, with axis=0.
    super().__init__(axis=0)

def _fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> "ROCKAD":
    """Fit the detector on sliding windows of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Input series; windows are cut along axis 0.
    y : np.ndarray, optional
        Not used by this method.

    Returns
    -------
    ROCKAD
        The fitted detector (``self``).

    Raises
    ------
    ValueError
        If ``_check_params`` rejects the window configuration for ``X``.
    """
    self._check_params(X)

    # The padding returned alongside the windows is only needed when scores
    # are mapped back onto time points, so it is discarded here.
    windows, _unused_padding = sliding_windows(
        X, window_size=self.window_size, stride=self.stride, axis=0
    )
    self._inner_fit(windows)

    return self

def _check_params(self, X: np.ndarray) -> None:
if self.window_size < 1 or self.window_size > X.shape[0]:
raise ValueError(
"The window size must be at least 1 and at most the length of the "
"time series."
)

if self.stride < 1 or self.stride > self.window_size:
raise ValueError(
"The stride must be at least 1 and at most the window size."
)

if int((X.shape[0] - self.window_size) / self.stride + 1) < self.n_neighbors:
raise ValueError(
f"Window count ({int((X.shape[0]-self.window_size)/self.stride+1)}) "
f"has to be larger than n_neighbors ({self.n_neighbors})."
"Please choose a smaller n_neighbors value or increase window count "
"by choosing a smaller window size or larger stride."
)

def _inner_fit(self, X: np.ndarray) -> None:
    """Fit the ROCKET transformer, the optional power transformer and the k-NN bag.

    Sets ``rocket_transformer_``, ``power_transformer_`` and ``list_baggers_``.

    Parameters
    ----------
    X : np.ndarray
        Windowed input as produced by ``sliding_windows`` (one row per
        window -- presumably ``(n_windows, window_size)`` for univariate
        input; confirm against ``sliding_windows`` for multivariate input).

    Warns
    -----
    UserWarning
        If fitting the power transformer fails; the raw ROCKET features are
        used instead and ``power_transformer_`` is left as ``None``.
    """
    # Function-scope import so the module's import block stays untouched.
    # VALID_METRICS maps each neighbour-search algorithm to the metrics it
    # supports.
    from sklearn.neighbors import VALID_METRICS

    self.rocket_transformer_ = Rocket(
        n_kernels=self.n_kernels,
        normalise=self.normalise,
        n_jobs=self.n_jobs,
        random_state=self.random_state,
    )
    Xt = self.rocket_transformer_.fit_transform(X).astype(np.float64)

    # Start from "no power transform" so that a previously fitted transformer
    # can never leak into prediction when the transform is disabled or fails.
    Xtp = Xt
    self.power_transformer_ = None
    if self.power_transform:
        power_transformer = PowerTransformer()
        try:
            Xtp = power_transformer.fit_transform(Xt)
        except Exception as err:
            # Deliberate best-effort fallback: the transform is optional, so
            # any failure downgrades to the raw features. The cause is
            # reported instead of being silently dropped.
            warnings.warn(
                "Power Transform failed and thus has been disabled. "
                f"Try increasing the window size. (cause: {err!r})",
                UserWarning,
                stacklevel=2,
            )
        else:
            self.power_transformer_ = power_transformer

    # A KD-tree only supports a subset of metrics. Keep it where it is valid
    # and let scikit-learn pick a suitable algorithm for every other metric
    # instead of failing at fit time.
    if self.metric in VALID_METRICS["kd_tree"]:
        algorithm = "kd_tree"
    else:
        algorithm = "auto"

    self.list_baggers_ = []
    for idx_estimator in range(self.n_estimators):
        # Integer seeds are offset per estimator so every bootstrap sample
        # differs but stays reproducible. ``None`` or a RandomState instance
        # is forwarded unchanged (adding an int to either raises TypeError).
        if isinstance(self.random_state, (int, np.integer)):
            seed = self.random_state + idx_estimator
        else:
            seed = self.random_state

        # Bootstrap aggregation: sample with replacement, same size as Xtp.
        bootstrap_sample = resample(
            Xtp,
            replace=True,
            n_samples=None,
            random_state=seed,
            stratify=None,
        )

        estimator = NearestNeighbors(
            n_neighbors=self.n_neighbors,
            n_jobs=self.n_jobs,
            metric=self.metric,
            algorithm=algorithm,
        )
        estimator.fit(bootstrap_sample)
        self.list_baggers_.append(estimator)

def _predict(self, X) -> np.ndarray:
    """Score every time point of ``X`` with the already fitted detector.

    Parameters
    ----------
    X : np.ndarray
        Input series; windows are cut along axis 0.

    Returns
    -------
    np.ndarray
        Point-wise anomaly scores as returned by ``_inner_predict``.
    """
    windows, padding = sliding_windows(
        X, window_size=self.window_size, stride=self.stride, axis=0
    )
    return self._inner_predict(windows, padding)

def _fit_predict(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
    """Fit on ``X`` and score the same series, windowing it only once.

    Parameters
    ----------
    X : np.ndarray
        Input series; windows are cut along axis 0.
    y : np.ndarray, optional
        Not used by this method.

    Returns
    -------
    np.ndarray
        Point-wise anomaly scores as returned by ``_inner_predict``.

    Raises
    ------
    ValueError
        If ``_check_params`` rejects the window configuration for ``X``.
    """
    self._check_params(X)

    windows, padding = sliding_windows(
        X, window_size=self.window_size, stride=self.stride, axis=0
    )

    # The same windows are used for fitting and for scoring.
    self._inner_fit(windows)
    return self._inner_predict(windows, padding)

def _inner_predict(self, X: np.ndarray, padding: int) -> np.ndarray:
    """Turn window-level scores into min-max scaled point-level scores.

    Parameters
    ----------
    X : np.ndarray
        Windowed input as produced by ``sliding_windows``.
    padding : int
        Padding value returned by ``sliding_windows`` for the same input; it
        is passed through to ``reverse_windowing``.

    Returns
    -------
    np.ndarray
        Point-wise anomaly scores scaled to ``[0, 1]``. If every point has
        the same score the result is all zeros.
    """
    window_scores = self._predict_proba(X)

    # Map the per-window scores back onto time points; overlapping windows
    # are combined with ``np.nanmean``.
    point_scores = reverse_windowing(
        window_scores, self.window_size, np.nanmean, self.stride, padding
    )

    lowest = point_scores.min()
    score_range = point_scores.max() - lowest

    # Identical scores would make the scaling divide by zero and return NaN
    # for every point; report "no point stands out" as zeros instead.
    if score_range == 0:
        return np.zeros_like(point_scores, dtype=np.float64)

    return (point_scores - lowest) / score_range

def _predict_proba(self, X):
"""
Predicts the probability of anomalies for the input data.

Parameters
----------
X (array-like): The input data.

Returns
-------
np.ndarray: The predicted probabilities.

"""
y_scores = np.zeros((len(X), self.n_estimators))
# Transform into rocket feature space
Xt = self.rocket_transformer_.transform(X)

Xt = Xt.astype(np.float64)

if self.power_transformer_ is not None:
# Power Transform using yeo-johnson
Xtp = self.power_transformer_.transform(Xt)
pattplatt marked this conversation as resolved.
Show resolved Hide resolved

else:
Xtp = Xt

for idx, bagger in enumerate(self.list_baggers_):
# Get scores from each estimator
distances, _ = bagger.kneighbors(Xtp)

# Compute mean distance of nearest points in window
scores = distances.mean(axis=1).reshape(-1, 1)
scores = scores.squeeze()

y_scores[:, idx] = scores

# Average the scores to get the final score for each time series
y_scores = y_scores.mean(axis=1)

return y_scores
75 changes: 75 additions & 0 deletions aeon/anomaly_detection/tests/test_rockad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Tests for the ROCKAD anomaly detector."""

import numpy as np
import pytest
from sklearn.utils import check_random_state

from aeon.anomaly_detection import ROCKAD


def test_rockad_univariate():
    """Check shape, dtype and peak location of scores for a univariate series."""
    n_timepoints = 100

    # Gaussian noise with a shifted segment acting as the anomaly.
    data = check_random_state(seed=2).normal(size=(n_timepoints,))
    data[50:58] -= 5

    detector = ROCKAD(
        n_estimators=100,
        n_kernels=10,
        n_neighbors=9,
        power_transform=True,
        window_size=20,
        stride=1,
    )
    scores = detector.fit_predict(data, axis=0)

    assert scores.shape == (n_timepoints,)
    assert scores.dtype == np.float64
    # The highest score must fall on the injected anomaly.
    assert 50 <= np.argmax(scores) <= 58


def test_rockad_multivariate():
    """Check shape, dtype and peak location of scores for a 3-channel series."""
    n_timepoints, n_channels = 100, 3

    data = check_random_state(seed=2).normal(size=(n_timepoints, n_channels))
    # Strong shift in channel 0, plus a small perturbation in channel 1.
    data[50:58, 0] -= 5
    data[87:90, 1] += 0.1

    detector = ROCKAD(
        n_estimators=1000,
        n_kernels=100,
        n_neighbors=20,
        power_transform=True,
        window_size=10,
        stride=1,
    )
    scores = detector.fit_predict(data, axis=0)

    assert scores.shape == (n_timepoints,)
    assert scores.dtype == np.float64
    # The highest score must fall on the strong anomaly in channel 0.
    assert 50 <= np.argmax(scores) <= 58


def test_rockad_incorrect_input():
    """Check the errors and the warning raised for unsuitable parameters."""
    data = check_random_state(seed=2).normal(size=(100,))

    # window_size below 1 is rejected.
    zero_window = ROCKAD(window_size=0)
    with pytest.raises(ValueError, match="The window size must be at least 1"):
        zero_window.fit_predict(data)

    # stride below 1 is rejected.
    zero_stride = ROCKAD(stride=0)
    with pytest.raises(ValueError, match="The stride must be at least 1"):
        zero_stride.fit_predict(data)

    # A window as long as the series yields a single window, which is fewer
    # than the default n_neighbors.
    single_window = ROCKAD(stride=1, window_size=100)
    with pytest.raises(
        ValueError, match=r"Window count .* has to be larger than n_neighbors .*"
    ):
        single_window.fit_predict(data)

    # With this small window the power transform is expected to fail and to
    # be disabled with a warning.
    small_window = ROCKAD(stride=1, window_size=5)
    with pytest.warns(
        UserWarning, match=r"Power Transform failed and thus has been disabled."
    ):
        small_window.fit_predict(data)
1 change: 1 addition & 0 deletions docs/api_reference/anomaly_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Detectors
MERLIN
OneClassSVM
PyODAdapter
ROCKAD
STOMP
STRAY

Expand Down