[ENH] Implement Proximity Forest classifier #1729

Merged 34 commits, Aug 2, 2024

Commits (34)
dc8bd1d  Proximity Forest draft (itsdivya1309, Jun 26, 2024)
d4ce3e3  Merge branch 'aeon-toolkit:main' into proximityForest (itsdivya1309, Jun 27, 2024)
7406dff  Update init (itsdivya1309, Jun 28, 2024)
7ed886f  Merge branch 'aeon-toolkit:main' into proximityForest (itsdivya1309, Jul 1, 2024)
d063b99  Tests for forest (itsdivya1309, Jul 2, 2024)
81f86ee  Docstring (itsdivya1309, Jul 2, 2024)
1ab94aa  Fix initialization error (itsdivya1309, Jul 2, 2024)
3359fe3  Merge branch 'main' into proximityForest (itsdivya1309, Jul 2, 2024)
80f1ca8  Update tags (itsdivya1309, Jul 2, 2024)
db136c6  Fix tests (itsdivya1309, Jul 3, 2024)
4631e8d  Review comments resolved (itsdivya1309, Jul 5, 2024)
b7d0461  Review comments resolved (itsdivya1309, Jul 5, 2024)
59a8175  Merge branch 'main' into proximityForest (itsdivya1309, Jul 5, 2024)
8905461  Parallelization using joblib (itsdivya1309, Jul 5, 2024)
0294311  Merge branch 'aeon-toolkit:main' into proximityForest (itsdivya1309, Jul 8, 2024)
2d74d4d  pickling objects (itsdivya1309, Jul 8, 2024)
7953c00  Merge branch 'aeon-toolkit:main' into proximityForest (itsdivya1309, Jul 11, 2024)
b7505ad  Parallel threading (itsdivya1309, Jul 11, 2024)
c853e55  Using unit test dataset (itsdivya1309, Jul 11, 2024)
8959efa  Merge branch 'main' into proximityForest (itsdivya1309, Jul 12, 2024)
1cb74a4  Merge branch 'main' into proximityForest (MatthewMiddlehurst, Jul 13, 2024)
e5a095f  parallel_backend parameter (itsdivya1309, Jul 15, 2024)
3d449e8  classes_ (itsdivya1309, Jul 15, 2024)
28bbad5  First check pure nodes (itsdivya1309, Jul 15, 2024)
7790ba9  No overwriting of base class attributes (itsdivya1309, Jul 16, 2024)
cb9b361  Remove n_jobs for tree (itsdivya1309, Jul 16, 2024)
9db1645  Majority Voting (itsdivya1309, Jul 21, 2024)
ff4c83a  Merge branch 'main' into proximityForest (itsdivya1309, Jul 21, 2024)
f3c54f8  Merge branch 'main' into proximityForest (itsdivya1309, Jul 22, 2024)
0a998bb  threading (itsdivya1309, Jul 25, 2024)
98cf52b  More randomness (itsdivya1309, Jul 26, 2024)
a6a9821  Revert "More randomness" (itsdivya1309, Jul 26, 2024)
966c6dd  Revert "threading" (itsdivya1309, Jul 26, 2024)
57b3e2d  Merge branch 'aeon-toolkit:main' into proximityForest (itsdivya1309, Jul 29, 2024)

Changes from all commits

8 changes: 7 additions & 1 deletion aeon/classification/distance_based/__init__.py
@@ -1,8 +1,14 @@
"""Distance based time series classifiers."""

__all__ = ["ElasticEnsemble", "KNeighborsTimeSeriesClassifier", "ProximityTree"]
__all__ = [
"ElasticEnsemble",
"KNeighborsTimeSeriesClassifier",
"ProximityTree",
"ProximityForest",
]

from aeon.classification.distance_based._elastic_ensemble import ElasticEnsemble
from aeon.classification.distance_based._proximity_forest import ProximityForest
from aeon.classification.distance_based._proximity_tree import ProximityTree
from aeon.classification.distance_based._time_series_neighbors import (
KNeighborsTimeSeriesClassifier,
156 changes: 156 additions & 0 deletions aeon/classification/distance_based/_proximity_forest.py
@@ -0,0 +1,156 @@
"""Proximity Forest Classifier.

The Proximity Forest is an ensemble of Proximity Trees.
"""

__all__ = ["ProximityForest"]

from typing import Type, Union

import numpy as np
from joblib import Parallel, delayed
from sklearn.utils import check_random_state

from aeon.classification.base import BaseClassifier
from aeon.classification.distance_based._proximity_tree import ProximityTree


class ProximityForest(BaseClassifier):
"""Proximity Forest Classifier.

The Proximity Forest is a distance-based classifier that creates an
ensemble of decision trees, where the splits are based on the
similarity between time series measured using various parameterised
distance measures.

Parameters
----------
n_trees: int, default = 100
The number of trees in the ensemble.
n_splitters: int, default = 5
The number of candidate splitters to be evaluated at each node.
max_depth: int, default = None
The maximum depth of each tree. If None, nodes are expanded until all
leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split: int, default = 2
The minimum number of samples required to split an internal node.
random_state : int, RandomState instance or None, default=None
If `int`, random_state is the seed used by the random number generator;
If `RandomState` instance, random_state is the random number generator;
If `None`, the random number generator is the `RandomState` instance used
by `np.random`.
n_jobs : int, default = 1
The number of parallel jobs to run when fitting the trees and when
computing predictions. ``None`` means 1 unless in a
:obj:`joblib.parallel_backend` context. ``-1`` means using all processors.
See :term:`Glossary <n_jobs>` for more details.
parallel_backend : str, ParallelBackendBase instance or None, default=None
Specify the parallelisation backend implementation in joblib; if None, a
'prefer' value of "threads" is used by default.
Valid options are "loky", "multiprocessing", "threading" or a custom backend.
See the joblib Parallel documentation for more details.

Notes
-----
For the Java version, see
`ProximityForest
<https://github.com/fpetitjean/ProximityForest>`_.

References
----------
.. [1] Lucas, B., Shifaz, A., Pelletier, C., O’Neill, L., Zaidi, N., Goethals, B.,
Petitjean, F. and Webb, G.I., 2019. Proximity forest: an effective and scalable
distance-based classifier for time series. Data Mining and Knowledge Discovery,
33(3), pp.607-635.

Examples
--------
>>> from aeon.datasets import load_unit_test
>>> from aeon.classification.distance_based import ProximityForest
>>> X_train, y_train = load_unit_test(split="train")
>>> X_test, y_test = load_unit_test(split="test")
>>> classifier = ProximityForest(n_trees=10, n_splitters=3)
>>> classifier.fit(X_train, y_train)
ProximityForest(...)
>>> y_pred = classifier.predict(X_test)
"""

_tags = {
"capability:multivariate": False,
"capability:unequal_length": False,
"capability:multithreading": True,
"algorithm_type": "distance",
"X_inner_type": ["numpy2D"],
}

def __init__(
self,
n_trees=100,
n_splitters: int = 5,
max_depth: int = None,
min_samples_split: int = 2,
random_state: Union[int, Type[np.random.RandomState], None] = None,
n_jobs: int = 1,
parallel_backend=None,
):
self.n_trees = n_trees
self.n_splitters = n_splitters
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.random_state = random_state
self.n_jobs = n_jobs
self.parallel_backend = parallel_backend
super().__init__()

def _fit(self, X, y):
rng = check_random_state(self.random_state)
self.trees_ = Parallel(
n_jobs=self._n_jobs, backend=self.parallel_backend, prefer="threads"
)(
delayed(_fit_tree)(
X,
y,
self.n_splitters,
self.max_depth,
self.min_samples_split,
check_random_state(rng.randint(np.iinfo(np.int32).max)),
)
for _ in range(self.n_trees)
)

def _predict_proba(self, X):
classes = list(self.classes_)
preds = Parallel(
n_jobs=self._n_jobs, backend=self.parallel_backend, prefer="threads"
)(delayed(_predict_tree)(tree, X) for tree in self.trees_)
n_cases = X.shape[0]
votes = np.zeros((n_cases, self.n_classes_))
for i in range(len(preds)):
predictions = np.array(
[classes.index(class_label) for class_label in preds[i]]
)
for j in range(n_cases):
votes[j, predictions[j]] += 1
output_probas = votes / self.n_trees
return output_probas

def _predict(self, X):
probas = self._predict_proba(X)
idx = np.argmax(probas, axis=1)
preds = np.asarray([self.classes_[x] for x in idx])
return preds


def _fit_tree(X, y, n_splitters, max_depth, min_samples_split, random_state):
clf = ProximityTree(
n_splitters=n_splitters,
max_depth=max_depth,
min_samples_split=min_samples_split,
random_state=random_state,
)
clf.fit(X, y)
return clf


def _predict_tree(tree, X):
return tree.predict(X)
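
For anyone trying the branch locally, a minimal usage sketch of the new classifier, following the docstring example above (the parameter values here are illustrative, not the defaults):

from aeon.classification.distance_based import ProximityForest
from aeon.datasets import load_unit_test

# Small univariate dataset used throughout this PR's examples and tests.
X_train, y_train = load_unit_test(split="train")
X_test, y_test = load_unit_test(split="test")

# Illustrative settings: a small ensemble with few candidate splitters per node.
clf = ProximityForest(n_trees=10, n_splitters=3, random_state=0)
clf.fit(X_train, y_train)

# Probabilities are vote fractions: each tree casts one vote per case and
# the vote counts are divided by n_trees in _predict_proba.
proba = clf.predict_proba(X_test)
y_pred = clf.predict(X_test)
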
39 changes: 8 additions & 31 deletions aeon/classification/distance_based/_proximity_tree.py
@@ -82,11 +82,6 @@ class ProximityTree(BaseClassifier):
If `RandomState` instance, random_state is the random number generator;
If `None`, the random number generator is the `RandomState` instance used
by `np.random`.
n_jobs : int, default = 1
The number of parallel jobs to run for neighbors search.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details. Parameter for compatibility purposes, still unimplemented.

Notes
-----
@@ -117,7 +112,7 @@ class ProximityTree(BaseClassifier):
"capability:multivariate": False,
"capability:unequal_length": False,
"algorithm_type": "distance",
"X_inner_type": ["numpy2D", "numpy3D"],
"X_inner_type": ["numpy2D"],
}

def __init__(
@@ -126,13 +121,11 @@ def __init__(
max_depth: int = None,
min_samples_split: int = 2,
random_state: Union[int, Type[np.random.RandomState], None] = None,
n_jobs: int = 1,
) -> None:
self.n_splitters = n_splitters
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.random_state = random_state
self.n_jobs = n_jobs
super().__init__()

def _get_parameter_value(self, X):
@@ -276,8 +269,8 @@ def _build_tree(self, X, y, depth, node_id, parent_target_value=None):
for label, count in zip(*np.unique(y, return_counts=True))
}

# If min sample splits is reached
if self.min_samples_split >= len(X):
# Pure node
if len(np.unique(y)) == 1:
leaf_label = target_value
leaf = _Node(
node_id=node_id,
@@ -287,8 +280,8 @@ def _build_tree(self, X, y, depth, node_id, parent_target_value=None):
)
return leaf

# If max depth is reached
if (self.max_depth is not None) and (depth >= self.max_depth):
# If min sample splits is reached
if self.min_samples_split >= len(X):
leaf_label = target_value
leaf = _Node(
node_id=node_id,
@@ -298,8 +291,8 @@ def _build_tree(self, X, y, depth, node_id, parent_target_value=None):
)
return leaf

# Pure node
if len(np.unique(y)) == 1:
# If max depth is reached
if (self.max_depth is not None) and (depth >= self.max_depth):
leaf_label = target_value
leaf = _Node(
node_id=node_id,
@@ -371,16 +364,6 @@ def _find_target_value(y):
return mode_value

def _fit(self, X, y):
# Check dimension of X
if X.ndim == 3:
if X.shape[1] == 1:
X = np.squeeze(X, axis=1)
else:
raise ValueError("X should be univariate.")

# Set the unique class labels
self.classes_ = list(np.unique(y))

self.root = self._build_tree(
X, y, depth=0, node_id="0", parent_target_value=None
)
@@ -391,14 +374,8 @@ def _predict(self, X):
return np.array([self.classes_[pred] for pred in predictions])

def _predict_proba(self, X):
# Check dimension of X
if X.ndim == 3:
if X.shape[1] == 1:
X = np.squeeze(X, axis=1)
else:
raise ValueError("X should be univariate.")
# Get the unique class labels
classes = self.classes_
classes = list(self.classes_)
class_count = len(classes)
probas = []

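
The reordered stopping conditions in _build_tree above (commit "First check pure nodes") test node purity before the minimum-split size and the depth limit. A rough sketch of that control flow, written only to summarise the order of the checks shown in the diff (this is not the actual aeon helper):

import numpy as np


def should_make_leaf(y, n_cases, depth, min_samples_split, max_depth):
    # 1. Pure node: all labels identical, so no split can improve it.
    if len(np.unique(y)) == 1:
        return True
    # 2. Too few cases left to split.
    if min_samples_split >= n_cases:
        return True
    # 3. Depth limit reached (max_depth=None means unbounded).
    if max_depth is not None and depth >= max_depth:
        return True
    return False
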
27 changes: 27 additions & 0 deletions aeon/classification/distance_based/tests/test_proximity_forest.py
@@ -0,0 +1,27 @@
"""Test for Proximity Forest."""

import pytest
from sklearn.metrics import accuracy_score

from aeon.classification.distance_based import ProximityForest
from aeon.datasets import load_unit_test


def test_univariate():
"""Test that the function gives appropriate error message."""
X, y = load_unit_test()
X_multivariate = X.reshape((-1, 2, 12))
clf = ProximityForest(n_trees=5, random_state=42, n_jobs=-1)
with pytest.raises(ValueError):
clf.fit(X_multivariate, y)


def test_proximity_forest():
"""Test the fit method of ProximityTree."""
X_train, y_train = load_unit_test()
X_test, y_test = load_unit_test(split="test")
clf = ProximityForest(n_trees=5, n_splitters=3, max_depth=4)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
score = accuracy_score(y_test, y_pred)
assert score >= 0.9
26 changes: 8 additions & 18 deletions aeon/classification/distance_based/tests/test_proximity_tree.py
@@ -6,17 +6,7 @@

from aeon.classification.distance_based import ProximityTree
from aeon.classification.distance_based._proximity_tree import gini, gini_gain
from aeon.testing.data_generation import make_example_3d_numpy


@pytest.fixture
def time_series_dataset():
"""Generate time series dataset for testing."""
n_samples = 100 # Total number of samples (should be even)
n_timepoints = 24 # Length of each time series
n_channels = 1
data, labels = make_example_3d_numpy(n_samples, n_channels, n_timepoints)
return data, labels
from aeon.datasets import load_unit_test


def test_gini():
@@ -110,9 +100,9 @@ def test_get_parameter_value():
assert measure_params["c"] in [10**i for i in range(-2, 3)]


def test_get_cadidate_splitter(time_series_dataset):
def test_get_cadidate_splitter():
"""Test the method to generate candidate splitters."""
X, y = time_series_dataset
X, y = load_unit_test()
clf = ProximityTree()
splitter = clf._get_candidate_splitter(X, y)
assert len(splitter) == 2
@@ -132,9 +122,9 @@ def test_get_cadidate_splitter(time_series_dataset):
assert measure in expected_measures


def test_get_best_splitter(time_series_dataset):
def test_get_best_splitter():
"""Test the method to get optimum splitter of a node."""
X, y = time_series_dataset
X, y = load_unit_test()
clf = ProximityTree(n_splitters=3)

splitter = clf._get_best_splitter(X, y)
@@ -146,12 +136,12 @@ def test_get_best_splitter(time_series_dataset):
assert len(splitter) == 2


def test_proximity_tree(time_series_dataset):
def test_proximity_tree():
"""Test the fit method of ProximityTree."""
X, y = time_series_dataset
X, y = load_unit_test()
clf = ProximityTree(n_splitters=3, max_depth=4)
clf.fit(X, y)
X_test, y_test = time_series_dataset
X_test, y_test = load_unit_test(split="train")
y_pred = clf.predict(X_test)
score = accuracy_score(y_test, y_pred)
assert score >= 0.9
1 change: 1 addition & 0 deletions docs/api_reference/classification.rst
@@ -76,6 +76,7 @@ Distance-based

ElasticEnsemble
KNeighborsTimeSeriesClassifier
ProximityForest
ProximityTree

Feature-based