From bd8435cd46f13b1a5b922052521e6bb25ba5f9a6 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Thu, 7 Dec 2023 20:12:41 +0000 Subject: [PATCH] feat: define predict method for BaseDecisionTree in sklearn frontend based on https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_classes.py#L505 --- ivy/functional/frontends/sklearn/base.py | 3 --- .../frontends/sklearn/tree/_classes.py | 18 +++++++++++++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/sklearn/base.py b/ivy/functional/frontends/sklearn/base.py index 8fa319eb2b587..e8dad2105936a 100644 --- a/ivy/functional/frontends/sklearn/base.py +++ b/ivy/functional/frontends/sklearn/base.py @@ -13,9 +13,6 @@ def score(self, X, y, sample_weight=None): def fit(self, X, y, **kwargs): raise NotImplementedError - def predict(self, X): - raise NotImplementedError - class TransformerMixin: def fit_transform(self, X, y=None, **fit_params): diff --git a/ivy/functional/frontends/sklearn/tree/_classes.py b/ivy/functional/frontends/sklearn/tree/_classes.py index 4fcef96e18f8a..109887524d374 100644 --- a/ivy/functional/frontends/sklearn/tree/_classes.py +++ b/ivy/functional/frontends/sklearn/tree/_classes.py @@ -137,7 +137,23 @@ def _fit( return self def predict(self, X, check_input=True): - raise NotImplementedError + proba = self.tree_.predict(X) + n_samples = X.shape[0] + + # Classification + + if self.n_outputs_ == 1: + return ivy.gather(self.classes_, ivy.argmax(proba, axis=1), axis=0) + + else: + class_type = self.classes_[0].dtype + predictions = ivy.zeros((n_samples, self.n_outputs_), dtype=class_type) + for k in range(self.n_outputs_): + predictions[:, k] = ivy.gather( + self.classes_[k], ivy.argmax(proba[:, k], axis=1), axis=0 + ) + + return predictions def apply(self, X, check_input=True): raise NotImplementedError