diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index af4c577c..cd5c21a6 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -193,7 +193,7 @@ def fit( The fitted model. """ if self._is_classifier(): - self._classes_ = np.unique(y) + self._classes_ = np.unique(list(y)) return super().fit(X, y) def predict(