diff --git a/src/eddymotion/model/dipy.py b/src/eddymotion/model/dipy.py index 63e2150f..1907363c 100644 --- a/src/eddymotion/model/dipy.py +++ b/src/eddymotion/model/dipy.py @@ -64,11 +64,11 @@ def gp_prediction( # Check it's fitted as they do in sklearn internally # https://github.com/scikit-learn/scikit-learn/blob/972e17fe1aa12d481b120ad4a3dc076bae736931/\ # sklearn/gaussian_process/_gpr.py#L410C9-L410C42 - if not hasattr(model._gpr, "X_train_"): + if not hasattr(model, "X_train_"): raise RuntimeError("Model is not yet fitted.") # Extract orientations from gtab, and highly likely, the b-value too. - return model._gpr.predict(gtab, return_std=False) + return model.predict(gtab, return_std=False) def get_kernel( @@ -175,9 +175,11 @@ def fit( data[mask[..., None]] if mask is not None else np.reshape(data, (-1, data.shape[-1])) ) - if data.shape[-1] != len(gtab): + signal_dirs = data.shape[-1] + grad_dirs = gtab.gradients.shape[0] + if signal_dirs != grad_dirs: raise ValueError( - f"Mismatched data {data.shape[-1]} and gradient table {len(gtab)} sizes." + f"Mismatched data {signal_dirs} and gradient table {grad_dirs} sizes." ) gpr = GaussianProcessRegressor( @@ -185,8 +187,8 @@ def fit( random_state=random_state, ) self._modelfit = GPFit( - gpr.fit(gtab.gradients, data), gtab=gtab, + model=gpr.fit(gtab.gradients, data[0]), mask=mask, ) return self._modelfit diff --git a/test/test_model.py b/test/test_model.py index 274fabf8..89f5235e 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -24,7 +24,8 @@ import numpy as np import pytest -from sklearn.datasets import make_friedman2 +from dipy.core.gradients import gradient_table +from sklearn.datasets import make_regression from eddymotion import model from eddymotion.data.dmri import DWI @@ -108,14 +109,17 @@ def test_average_model(): def test_gp_model(): - gp = GaussianProcessModel(kernel="default") + gp = GaussianProcessModel("test", kernel="default") assert isinstance(gp, model.dipy.GaussianProcessModel) - X, y = make_friedman2(n_samples=500, noise=0, random_state=0) - gp.fit(X, y) - X_qry = X[:2, :] - prediction, _ = gp.predict(X_qry, return_std=True) + X, y = make_regression(n_samples=100, n_features=3, noise=0, random_state=0) + + bvecs = X.T / np.linalg.norm(X.T, axis=0) + gtab = gradient_table([1000] * bvecs.shape[-1], bvecs) + gp.fit(y, gtab) + X_qry = bvecs[:, :2].T + prediction = gp.predict(X_qry, return_std=True) assert prediction.shape == (X_qry.shape[0],)