Skip to content

Commit

Permalink
Add support for loading PLS model correctly on python 3.7 or sklearn …
Browse files Browse the repository at this point in the history
…< 1.1
  • Loading branch information
ejolly committed Jun 10, 2022
1 parent 6e43865 commit 2edb894
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion feat/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def predict(au, model=None, feature_range=None):
if len(au) != model.n_components:
print(au)
print(model.n_components)
raise ValueError("au vector must be len(", model.n_components, ").")
raise ValueError(f"au vector must be length {model.n_components}.")

if len(au.shape) == 1:
au = np.reshape(au, (1, -1))
Expand Down
15 changes: 11 additions & 4 deletions feat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,17 @@ def load_viz_model(
hf = h5py.File(h5_path, mode="r")
x_weights = np.array(hf.get("x_weights"))
model = PLSRegression(n_components=x_weights.shape[1])
# PLSRegression in < 1.3 stores coefs ax features x samples unlike other
# estimators
model.__dict__["coef_"] = np.array(hf.get("coef")).T
model.__dict__["_coef_"] = np.array(hf.get("coef")).T
# PLSRegression in sklearn < 1.1 storex coefs as samples x features, but
# recent versions transpose this. Check if the user is on Python 3.7 (which
# only supports sklearn 1.0.x) or < sklearn 1.1.x
if (my_pymajor == 3 and my_pyminor == 7) or (
my_skmajor == 1 and my_skminor != 1
):
model.__dict__["coef_"] = np.array(hf.get("coef"))
model.__dict__["_coef_"] = np.array(hf.get("coef"))
else:
model.__dict__["coef_"] = np.array(hf.get("coef")).T
model.__dict__["_coef_"] = np.array(hf.get("coef")).T
model.__dict__["x_weights_"] = np.array(hf.get("x_weights"))
model.__dict__["y_weights_"] = np.array(hf.get("y_weights"))
model.__dict__["x_loadings"] = np.array(hf.get("x_loadings"))
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ numpy>=1.9
seaborn>=0.7.0
matplotlib>=2.1
nltools>=0.3.6
scikit-learn>=1.1
scikit-learn>=1.0
pywavelets>=0.3.0
h5py>=2.7.0
Pillow>=6.0.0
Expand Down

0 comments on commit 2edb894

Please sign in to comment.