diff --git a/.gitignore b/.gitignore
index 4328f217e..6d07f4bc3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,8 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
+Pipfile
+Pipfile.lock
 
 # PyInstaller
 # Usually these files are written by a python script from a template
diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py
index 8cb505f50..d1b0069b7 100644
--- a/imblearn/utils/_validation.py
+++ b/imblearn/utils/_validation.py
@@ -12,6 +12,7 @@ from sklearn.base import clone
 from sklearn.neighbors._base import KNeighborsMixin
 from sklearn.neighbors import NearestNeighbors
+from sklearn.utils import column_or_1d
 from sklearn.utils.multiclass import type_of_target
 
 from ..exceptions import raise_isinstance_error
 
@@ -96,6 +97,8 @@ def check_target_type(y, indicate_one_vs_all=False):
                 "multioutput targets are not supported."
             )
         y = y.argmax(axis=1)
+    else:
+        y = column_or_1d(y)
 
     return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y
 
diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index 43f117ba3..51a039f85 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -44,6 +44,7 @@ def _yield_sampler_checks(name, Estimator):
     yield check_samplers_multiclass_ova
     yield check_samplers_preserve_dtype
     yield check_samplers_sample_indices
+    yield check_samplers_2d_target
 
 
 def _yield_classifier_checks(name, Estimator):
@@ -283,6 +284,20 @@ def check_samplers_multiclass_ova(name, Sampler):
     assert_allclose(y_res, y_res_ova.argmax(axis=1))
 
 
+def check_samplers_2d_target(name, Sampler):
+    X, y = make_classification(
+        n_samples=100,
+        n_classes=3,
+        n_informative=4,
+        weights=[0.2, 0.3, 0.5],
+        random_state=0,
+    )
+
+    y = y.reshape(-1, 1)  # Make the target 2d
+    sampler = Sampler()
+    sampler.fit_resample(X, y)
+
+
 def check_samplers_preserve_dtype(name, Sampler):
     X, y = make_classification(
         n_samples=1000,