Skip to content

Commit

Permalink
Replace len(X) with X.shape[0] to improve support for sparse matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkobunse committed Mar 15, 2024
1 parent 96099af commit f3a0dc0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion qunfold/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def fit(self, X, y):
def predict_proba(self, X):
if not hasattr(self, "classes_"):
raise NotFittedError()
y_pred = np.zeros((len(self.estimators_), len(X), len(self.classes_)))
y_pred = np.zeros((len(self.estimators_), X.shape[0], len(self.classes_)))
for i, (estimator, i_classes) in enumerate(zip(self.estimators_, self.i_classes_)):
y_pred[i, :, i_classes] = estimator.predict_proba(X).T
return np.mean(y_pred, axis=0) # shape (n_samples, n_classes)
Expand Down
4 changes: 2 additions & 2 deletions qunfold/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def fit_transform(self, X, y, average=True, n_classes=None):
n_classes = len(self.p_trn) # not None anymore
if self.fit_classifier:
self.classifier.fit(X, y)
fX = np.zeros((len(X), n_classes))
fX = np.zeros((X.shape[0], n_classes))
fX[:, self.classifier.classes_] = self.classifier.oob_decision_function_
is_finite = np.all(np.isfinite(fX), axis=1)
fX = fX[is_finite,:] # drop instances that never became OOB
Expand All @@ -105,7 +105,7 @@ def fit_transform(self, X, y, average=True, n_classes=None):
return fX, y
def transform(self, X, average=True):
n_classes = len(self.p_trn)
fX = np.zeros((len(X), n_classes))
fX = np.zeros((X.shape[0], n_classes))
fX[:, self.classifier.classes_] = self.classifier.predict_proba(X)
if not self.is_probabilistic:
fX = _onehot_encoding(np.argmax(fX, axis=1), n_classes)
Expand Down

0 comments on commit f3a0dc0

Please sign in to comment.