diff --git a/qunfold/sklearn.py b/qunfold/sklearn.py
index b31206d..7191d2a 100644
--- a/qunfold/sklearn.py
+++ b/qunfold/sklearn.py
@@ -46,7 +46,7 @@ def fit(self, X, y):
   def predict_proba(self, X):
     if not hasattr(self, "classes_"):
       raise NotFittedError()
-    y_pred = np.zeros((len(self.estimators_), len(X), len(self.classes_)))
+    y_pred = np.zeros((len(self.estimators_), X.shape[0], len(self.classes_)))
     for i, (estimator, i_classes) in enumerate(zip(self.estimators_, self.i_classes_)):
       y_pred[i, :, i_classes] = estimator.predict_proba(X).T
     return np.mean(y_pred, axis=0) # shape (n_samples, n_classes)
diff --git a/qunfold/transformers.py b/qunfold/transformers.py
index ac1fafd..1dbeb46 100644
--- a/qunfold/transformers.py
+++ b/qunfold/transformers.py
@@ -89,7 +89,7 @@ def fit_transform(self, X, y, average=True, n_classes=None):
       n_classes = len(self.p_trn) # not None anymore
     if self.fit_classifier:
       self.classifier.fit(X, y)
-    fX = np.zeros((len(X), n_classes))
+    fX = np.zeros((X.shape[0], n_classes))
     fX[:, self.classifier.classes_] = self.classifier.oob_decision_function_
     is_finite = np.all(np.isfinite(fX), axis=1)
     fX = fX[is_finite,:] # drop instances that never became OOB
@@ -105,7 +105,7 @@ def fit_transform(self, X, y, average=True, n_classes=None):
     return fX, y
   def transform(self, X, average=True):
     n_classes = len(self.p_trn)
-    fX = np.zeros((len(X), n_classes))
+    fX = np.zeros((X.shape[0], n_classes))
     fX[:, self.classifier.classes_] = self.classifier.predict_proba(X)
     if not self.is_probabilistic:
       fX = _onehot_encoding(np.argmax(fX, axis=1), n_classes)
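
A short standalone sketch of the likely motivation for replacing len(X) with X.shape[0] (an assumption; the patch itself states no rationale): X.shape[0] reports the number of rows for both dense NumPy arrays and scipy.sparse matrices, whereas calling len() on a classic sparse matrix raises a TypeError, so the patched lines keep working when X is sparse.

# Sketch only; the variable names here are illustrative, not from qunfold.
import numpy as np
from scipy.sparse import csr_matrix

X_dense = np.zeros((5, 3))
X_sparse = csr_matrix(X_dense)

# Row count via .shape[0] works for both representations.
print(X_dense.shape[0], X_sparse.shape[0])  # 5 5

# len() works for the dense array but not for the sparse matrix.
print(len(X_dense))  # 5
try:
    len(X_sparse)
except TypeError as e:
    print(e)  # "sparse matrix length is ambiguous; use getnnz() or shape[0]"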