diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py
index 8258c7a69..299e1d26e 100644
--- a/pymc_marketing/clv/models/basic.py
+++ b/pymc_marketing/clv/models/basic.py
@@ -273,10 +273,6 @@ def output_var(self):
         """Output variable of the model."""
         pass

-    def _generate_and_preprocess_model_data(self, *args, **kwargs):
-        """Generate and preprocess model data."""
-        pass
-
     def _data_setter(self):
         """Set the data for the model."""
         pass
diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py
index 6f11c588f..87e835d49 100644
--- a/pymc_marketing/model_builder.py
+++ b/pymc_marketing/model_builder.py
@@ -207,8 +207,8 @@ def _validate_data(self, X, y=None):
     @abstractmethod
     def _data_setter(
         self,
-        X: np.ndarray | pd.DataFrame,
-        y: np.ndarray | pd.Series | None = None,
+        X: np.ndarray | pd.DataFrame | xr.Dataset | xr.DataArray,
+        y: np.ndarray | pd.Series | xr.DataArray | None = None,
     ) -> None:
         """Set new data in the model.
@@ -304,44 +304,11 @@ def default_sampler_config(self) -> dict:
         """

-    @abstractmethod
-    def _generate_and_preprocess_model_data(
-        self, X: pd.DataFrame | pd.Series, y: np.ndarray
-    ) -> None:
-        """Apply preprocessing to the data before fitting the model.
-
-        if validate is True, it will check if the data is valid for the model.
-        sets self.model_coords based on provided dataset
-
-        In case of optional parameters being passed into the model, this method should implement the conditional
-        logic responsible for correct handling of the optional parameters, and including them into the dataset.
-
-        Parameters
-        ----------
-        X : array, shape (n_obs, n_features)
-        y : array, shape (n_obs,)
-
-        Examples
-        --------
-        >>> @classmethod
-        >>> def _generate_and_preprocess_model_data(self, X, y):
-                coords = {
-                    'x_dim': X.dim_variable,
-                } #only include if applicable for your model
-        >>> self.X = X
-        >>> self.y = y
-
-        Returns
-        -------
-        None
-
-        """
-
     @abstractmethod
     def build_model(
         self,
-        X: pd.DataFrame,
-        y: pd.Series | np.ndarray,
+        X: pd.DataFrame | xr.Dataset | xr.DataArray,
+        y: pd.Series | np.ndarray | xr.DataArray,
         **kwargs,
     ) -> None:
         """Create an instance of `pm.Model` based on provided data and model_config.
@@ -656,10 +623,30 @@ def load(cls, fname: str):
             )
             raise DifferentModelError(error_msg) from e

+    def create_fit_data(
+        self,
+        X: pd.DataFrame | xr.Dataset | xr.DataArray,
+        y: np.ndarray | pd.Series | xr.DataArray,
+    ) -> xr.Dataset:
+        """Create the fit_data group based on the input data."""
+        if isinstance(y, np.ndarray):
+            y = pd.Series(y, index=X.index, name=self.output_var)
+
+        if y.name is None:
+            y.name = self.output_var
+
+        if isinstance(X, pd.DataFrame):
+            X = X.to_xarray()
+
+        if isinstance(y, pd.Series):
+            y = y.to_xarray()
+
+        return xr.merge([X, y])
+
     def fit(
         self,
-        X: pd.DataFrame,
-        y: pd.Series | np.ndarray | None = None,
+        X: pd.DataFrame | xr.Dataset | xr.DataArray,
+        y: pd.Series | xr.DataArray | np.ndarray | None = None,
         progressbar: bool | None = None,
         random_seed: RandomState | None = None,
         **kwargs: Any,
@@ -694,23 +681,23 @@ def fit(
         Initializing NUTS using jitter+adapt_diag...
""" - if isinstance(y, pd.Series) and not X.index.equals(y.index): + if ( + isinstance(y, pd.Series) + and isinstance(X, pd.DataFrame) + and not X.index.equals(y.index) + ): raise ValueError("Index of X and y must match.") if y is None: y = np.zeros(X.shape[0]) - y_df = pd.DataFrame({self.output_var: y}, index=X.index) - self._generate_and_preprocess_model_data(X, y_df.values.flatten()) - if self.X is None or self.y is None: - raise ValueError("X and y must be set before calling build_model!") - if self.output_var in X.columns: + if self.output_var in X: raise ValueError( f"X includes a column named '{self.output_var}', which conflicts with the target variable." ) if not hasattr(self, "model"): - self.build_model(self.X, self.y) + self.build_model(X, y) sampler_kwargs = create_sample_kwargs( self.sampler_config, @@ -727,21 +714,18 @@ def fit( else: self.idata = idata - X_df = pd.DataFrame(X, columns=X.columns) - combined_data = pd.concat([X_df, y_df], axis=1) - if not all(combined_data.columns): - raise ValueError("All columns must have non-empty names") - if "fit_data" in self.idata: del self.idata.fit_data + fit_data = self.create_fit_data(X, y) + with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=UserWarning, message="The group fit_data is not defined in the InferenceData scheme", ) - self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + self.idata.add_groups(fit_data=fit_data) self.set_idata_attrs(self.idata) return self.idata # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index 7c8943805..342785c55 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,12 +124,12 @@ def mock_sample(*args, **kwargs): """This is a mock of pm.sample that returns the prior predictive samples as the posterior.""" random_seed = kwargs.get("random_seed", None) model = kwargs.get("model", None) - samples = kwargs.get("draws", 10) + draws = kwargs.get("draws", 10) n_chains = kwargs.get("chains", 1) idata: InferenceData = pm.sample_prior_predictive( model=model, random_seed=random_seed, - samples=samples, + draws=draws, ) expanded_chains = DataArray( @@ -147,6 +147,16 @@ def mock_sample(*args, **kwargs): return idata +@pytest.fixture +def mock_pymc_sample(): + original_sample = pm.sample + pm.sample = mock_sample + + yield + + pm.sample = original_sample + + def mock_fit_MAP(self, *args, **kwargs): draws = 1 chains = 1 @@ -173,9 +183,7 @@ def fitted_bg(test_summary_data) -> BetaGeoModel: model_config=model_config, ) model.build_model() - fake_fit = pm.sample_prior_predictive( - samples=50, model=model.model, random_seed=rng - ) + fake_fit = pm.sample_prior_predictive(draws=50, model=model.model, random_seed=rng) # posterior group required to pass L80 assert check fake_fit.add_groups(posterior=fake_fit.prior) set_model_fit(model, fake_fit) @@ -205,7 +213,9 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel: # Mock an idata object for tests requiring a fitted model # TODO: This is quite slow. Check similar fixtures in the model tests to speed this up. 
     fake_fit = pm.sample_prior_predictive(
-        samples=50, model=pnbd_model.model, random_seed=rng
+        draws=50,
+        model=pnbd_model.model,
+        random_seed=rng,
     )
     # posterior group required to pass L80 assert check
     fake_fit.add_groups(posterior=fake_fit.prior)
diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py
index 3630e3709..494aa9c45 100644
--- a/tests/test_model_builder.py
+++ b/tests/test_model_builder.py
@@ -11,20 +11,6 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
-#   Copyright 2023 The PyMC Developers
-#
-#   Licensed under the Apache License, Version 2.0 (the "License");
-#   you may not use this file except in compliance with the License.
-#   You may obtain a copy of the License at
-#
-#       http://www.apache.org/licenses/LICENSE-2.0
-#
-#   Unless required by applicable law or agreed to in writing, software
-#   distributed under the License is distributed on an "AS IS" BASIS,
-#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#   See the License for the specific language governing permissions and
-#   limitations under the License.
-
 import hashlib
 import json
 import sys
@@ -106,24 +92,29 @@ def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
     _model_type = "test_model"
     version = "0.1"

-    def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
+    def build_model(self, X: pd.DataFrame, y: pd.Series):
         coords = {"numbers": np.arange(len(X))}
-        self._generate_and_preprocess_model_data(X, y)
+
+        y = y if isinstance(y, np.ndarray) else y.values
+
         with pm.Model(coords=coords) as self.model:
-            if model_config is None:
-                model_config = self.default_model_config
-            x = pm.Data("x", self.X["input"].values)
-            y_data = pm.Data("y_data", self.y)
+            x = pm.Data("x", X["input"].values)
+            y_data = pm.Data("y_data", y)

             # prior parameters
-            a_loc = model_config["a"]["loc"]
-            a_scale = model_config["a"]["scale"]
-            b_loc = model_config["b"]["loc"]
-            b_scale = model_config["b"]["scale"]
-            obs_error = model_config["obs_error"]
+            a_loc = self.model_config["a"]["loc"]
+            a_scale = self.model_config["a"]["scale"]
+            b_loc = self.model_config["b"]["loc"]
+            b_scale = self.model_config["b"]["scale"]
+            obs_error = self.model_config["obs_error"]

             # priors
-            a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
+            a = pm.Normal(
+                "a",
+                a_loc,
+                sigma=a_scale,
+                dims=self.model_config["a"]["dims"],
+            )
             b = pm.Normal("b", b_loc, sigma=b_scale)
             obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
@@ -140,7 +131,7 @@ def create_idata_attrs(self):
     def output_var(self):
         return "output"

-    def _data_setter(self, X: pd.DataFrame, y: pd.Series = None):
+    def _data_setter(self, X: pd.DataFrame, y: pd.Series | None = None):
         with self.model:
             pm.set_data({"x": X["input"].values})
             if y is not None:
@@ -151,10 +142,6 @@ def _serializable_model_config(self):
         return self.model_config

-    def _generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
-        self.X = X
-        self.y = y
-
     @property
     def default_model_config(self) -> dict:
         return {
@@ -279,9 +266,7 @@ def test_set_fit_result(toy_X, toy_y):
     model = ModelBuilderTest()
     model.build_model(X=toy_X, y=toy_y)
     model.idata = None
-    fake_fit = pm.sample_prior_predictive(
-        samples=50, model=model.model, random_seed=1234
-    )
+    fake_fit = pm.sample_prior_predictive(draws=50, model=model.model, random_seed=1234)
     fake_fit.add_groups(dict(posterior=fake_fit.prior))
     model.fit_result = fake_fit
     with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"):
@@ -656,3 +641,96 @@ def test_X_pred_prior_deprecation(fitted_model_instance, toy_X, toy_y) -> None:

     assert isinstance(fitted_model_instance.prior, xr.Dataset)
     assert isinstance(fitted_model_instance.prior_predictive, xr.Dataset)
+
+
+class XarrayModel(ModelBuilder):
+    """Multivariate regression model."""
+
+    def build_model(self, X, y, **kwargs):
+        if isinstance(X, xr.Dataset):
+            X = X["x"]
+
+        coords = {
+            "country": ["A", "B"],
+            "date": [0, 1],
+        }
+        with pm.Model(coords=coords) as self.model:
+            x = pm.Data("X", X.values, dims=("country", "date"))
+            y = pm.Data("y", y.values, dims=("country", "date"))
+
+            alpha = pm.Normal("alpha", 0, 1, dims=("country",))
+            beta = pm.Normal("beta", 0, 1, dims=("country",))
+
+            mu = alpha + beta * x
+
+            sigma = pm.HalfNormal("sigma")
+
+            pm.Normal("output", mu=mu, sigma=sigma, observed=y)
+
+    def _data_setter(self, X, y=None):
+        pass
+
+    @property
+    def _serializable_model_config(self):
+        return {}
+
+    @property
+    def output_var(self):
+        return "output"
+
+    @property
+    def default_model_config(self):
+        return {}
+
+    @property
+    def default_sampler_config(self):
+        return {}
+
+
+@pytest.fixture
+def xarray_X() -> xr.Dataset:
+    return (
+        pd.DataFrame(
+            {
+                "x": [1, 2, 3, 4],
+                "date": [0, 1, 0, 1],
+                "country": ["A", "A", "B", "B"],
+            }
+        )
+        .set_index(["country", "date"])
+        .to_xarray()
+    )
+
+
+@pytest.fixture
+def xarray_y(xarray_X) -> xr.DataArray:
+    alpha = xr.DataArray(
+        [1, 2],
+        dims=["country"],
+        coords={"country": ["A", "B"]},
+    )
+    beta = xr.DataArray([1, 2], dims=["country"], coords={"country": ["A", "B"]})
+
+    return (alpha + beta * xarray_X["x"]).rename("output")
+
+
+@pytest.mark.parametrize("X_is_dataset", [False, True], ids=["DataArray", "Dataset"])
+def test_xarray_model_builder(X_is_dataset, xarray_X, xarray_y, mock_pymc_sample) -> None:
+    model = XarrayModel()
+
+    X = xarray_X if X_is_dataset else xarray_X["x"]
+
+    model.fit(X, xarray_y)
+
+    xr.testing.assert_equal(
+        model.idata.fit_data,  # type: ignore
+        pd.DataFrame(
+            {
+                "x": [1, 2, 3, 4],
+                "output": [2, 3, 8, 10],
+            },
+            index=pd.MultiIndex.from_tuples(
+                [("A", 0), ("A", 1), ("B", 0), ("B", 1)], names=["country", "date"]
+            ),
+        ).to_xarray(),
+    )
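
# ===========================================================================
# Note (not part of the patch): a minimal sketch of what the new
# create_fit_data() does for the common pandas/ndarray input case. The
# variables below (X, y, output_var) are illustrative assumptions, not names
# taken from the diff's call sites.
import numpy as np
import pandas as pd
import xarray as xr

output_var = "output"  # assumed value of self.output_var

X = pd.DataFrame({"input": [1.0, 2.0, 3.0]})
y = np.array([2.0, 4.0, 6.0])

# A bare ndarray is first wrapped in a Series aligned to X's index and named
# after the model's output variable...
y = pd.Series(y, index=X.index, name=output_var)

# ...then both pieces are converted to xarray and merged into a single
# Dataset, which fit() attaches to idata as the "fit_data" group.
fit_data = xr.merge([X.to_xarray(), y.to_xarray()])
assert set(fit_data.data_vars) == {"input", "output"}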
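
# ===========================================================================
# Note (not part of the patch): intended usage of the new mock_pymc_sample
# fixture from tests/conftest.py. The test name below is hypothetical and
# assumes the ModelBuilderTest class and toy_X/toy_y fixtures from
# tests/test_model_builder.py are in scope.
def test_fit_without_sampling(mock_pymc_sample, toy_X, toy_y):
    # While this test runs, pm.sample is swapped for mock_sample, so fit()
    # returns prior-predictive draws relabeled as a posterior instead of
    # running NUTS; the fixture restores the real pm.sample on teardown.
    model = ModelBuilderTest()
    model.fit(toy_X, toy_y)
    assert "posterior" in model.idata.groups()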