-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Implement Proximity Forest classifier #1729
base: main
Are you sure you want to change the base?
Conversation
Thank you for contributing to
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work, we'll have to give it a run through on the UCR archive datasets as we discussed. Next steps are up to you really, we can discuss on Slack.
This needs to be included in the API documentation also. |
def _fit_tree(self, X, y): | ||
clf = ProximityTree( | ||
n_splitters=self.n_splitters, | ||
max_depth=self.max_depth, | ||
min_samples_split=self.min_samples_split, | ||
random_state=self.random_state, | ||
n_jobs=self.n_jobs, | ||
) | ||
clf.fit(X, y) | ||
return clf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar comment for predict, but I think it might be better to define the function you parallelize with joblib outside of the object you call them from. Something to do with the fact that joblib pickling the objects you parallelize, if I remember right ? This might mean that you create a copy of the ProximityForest
object every time you call _fit_tree
.
To avoid that, you would define _fit_tree
as a function outside ProximityForest
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this true? I think we have functions elsewhere that do this. Interesting to see if that needs to be changed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oher than this testing issue, the rest LGTM !
aeon/classification/distance_based/tests/test_proximity_forest.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should investigate how we use joblib in other estimators in the classification module. Typically we use "threads" as a default backend, with a parameter to change that,
The docstring n_jobs
needs updating.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only this small parameter missing and it should be good to go !
) | ||
|
||
def _predict_proba(self, X): | ||
output_probas = Parallel(n_jobs=self._n_jobs, prefer="threads")( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed the need of a parameter for the joblib backend (i.e. threads vs processes), you should add a class parameter that default to threads.
It might be better to use the backend
parameter instead of the prefer
(see docs) to have a more fine grained control over the chosen backend.
Reference Issues/PRs
Closes #159
What does this implement/fix? Explain your changes.
Implementation of Proximity Forest Algorithm using the Proximity Trees.