From 2edb894147b290a92738410944747a1e568d4fa8 Mon Sep 17 00:00:00 2001 From: ejolly Date: Fri, 10 Jun 2022 16:55:24 -0400 Subject: [PATCH] Add support for loading PLS model correctly on python 3.7 or sklearn < 1.1 --- feat/plotting.py | 2 +- feat/utils.py | 15 +++++++++++---- requirements.txt | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/feat/plotting.py b/feat/plotting.py index 1b684164..5b971d7a 100644 --- a/feat/plotting.py +++ b/feat/plotting.py @@ -991,7 +991,7 @@ def predict(au, model=None, feature_range=None): if len(au) != model.n_components: print(au) print(model.n_components) - raise ValueError("au vector must be len(", model.n_components, ").") + raise ValueError(f"au vector must be length {model.n_components}.") if len(au.shape) == 1: au = np.reshape(au, (1, -1)) diff --git a/feat/utils.py b/feat/utils.py index 0f7b9ff0..dade13ad 100644 --- a/feat/utils.py +++ b/feat/utils.py @@ -262,10 +262,17 @@ def load_viz_model( hf = h5py.File(h5_path, mode="r") x_weights = np.array(hf.get("x_weights")) model = PLSRegression(n_components=x_weights.shape[1]) - # PLSRegression in < 1.3 stores coefs ax features x samples unlike other - # estimators - model.__dict__["coef_"] = np.array(hf.get("coef")).T - model.__dict__["_coef_"] = np.array(hf.get("coef")).T + # PLSRegression in sklearn < 1.1 storex coefs as samples x features, but + # recent versions transpose this. Check if the user is on Python 3.7 (which + # only supports sklearn 1.0.x) or < sklearn 1.1.x + if (my_pymajor == 3 and my_pyminor == 7) or ( + my_skmajor == 1 and my_skminor != 1 + ): + model.__dict__["coef_"] = np.array(hf.get("coef")) + model.__dict__["_coef_"] = np.array(hf.get("coef")) + else: + model.__dict__["coef_"] = np.array(hf.get("coef")).T + model.__dict__["_coef_"] = np.array(hf.get("coef")).T model.__dict__["x_weights_"] = np.array(hf.get("x_weights")) model.__dict__["y_weights_"] = np.array(hf.get("y_weights")) model.__dict__["x_loadings"] = np.array(hf.get("x_loadings")) diff --git a/requirements.txt b/requirements.txt index f948a0fe..e4a0591b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ numpy>=1.9 seaborn>=0.7.0 matplotlib>=2.1 nltools>=0.3.6 -scikit-learn>=1.1 +scikit-learn>=1.0 pywavelets>=0.3.0 h5py>=2.7.0 Pillow>=6.0.0