Skip to content

Commit

Permalink
Merge pull request #25 from felipeangelimvieira/fix/prior_loc_scale
Browse files Browse the repository at this point in the history
Fix prior loc scale
  • Loading branch information
felipeangelimvieira authored May 10, 2024
2 parents 0566957 + 14d28b4 commit ca257f2
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 107 deletions.
81 changes: 42 additions & 39 deletions docs/examples/hierarchical.ipynb

Large diffs are not rendered by default.

277 changes: 214 additions & 63 deletions docs/examples/univariate.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/prophetverse/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def univariate_model(
noise_scale = numpyro.sample("noise_scale", dist.HalfNormal(noise_scale))

with numpyro.plate("data", len(mean), dim=-2) as time_plate:
numpyro.sample(
s = numpyro.sample(
"obs",
dist.Normal(mean.reshape((-1, 1)), noise_scale),
obs=y,
)
s
27 changes: 24 additions & 3 deletions src/prophetverse/trend/piecewise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpyro
import numpyro.distributions as dist
import pandas as pd
from sktime.transformations.series.detrend import Detrender

from prophetverse.utils.frame_to_array import series_to_tensor

Expand All @@ -21,13 +22,17 @@ def __init__(
changepoint_prior_scale: dist.Distribution,
offset_prior_scale=0.1,
squeeze_if_single_series: bool = True,
remove_seasonality_before_suggesting_initial_vals: bool = True,
**kwargs
):
self.changepoint_interval = changepoint_interval
self.changepoint_range = changepoint_range
self.changepoint_prior_scale = changepoint_prior_scale
self.offset_prior_scale = offset_prior_scale
self.squeeze_if_single_series = squeeze_if_single_series
self.remove_seasonality_before_suggesting_initial_vals = (
remove_seasonality_before_suggesting_initial_vals
)
super().__init__(**kwargs)

def initialize(self, y: pd.DataFrame):
Expand Down Expand Up @@ -158,11 +163,10 @@ def _setup_changepoints(self, t_scaled) -> None:
changepoint_range=changepoint_range,
)
)

if len(changepoint_ts[-1]) == 0:
raise ValueError(

Check warning on line 168 in src/prophetverse/trend/piecewise.py

View check run for this annotation

Codecov / codecov/patch

src/prophetverse/trend/piecewise.py#L168

Added line #L168 was not covered by tests
f"No changepoints were generated. Try increasing the changing the changepoint_range. There are {len(t_scaled)} timepoints in the series, changepoint_range is {changepoint_range} and changepoint_interval is {changepoint_interval}.")


self._changepoint_ts = changepoint_ts

Expand All @@ -176,6 +180,11 @@ def _setup_changepoint_prior_vectors(self, y: pd.DataFrame) -> None:
Returns:
None
"""

if self.remove_seasonality_before_suggesting_initial_vals:
detrender = Detrender()
y = y - detrender.fit_transform(y)

self.global_rates, self.offset_loc = self._suggest_global_trend_and_offset(y)
self._changepoint_prior_loc, self._changepoint_prior_scale = (
self._get_changepoint_prior_vectors(global_rates=self.global_rates)
Expand Down Expand Up @@ -266,8 +275,20 @@ def compute_trend(self, changepoint_matrix: jnp.ndarray) -> jnp.ndarray:
None
"""

if isinstance(self.changepoint_prior_scale, (list, tuple)):
offset_scale = [x * self.offset_prior_scale for x in self.changepoint_prior_scale]

Check warning on line 280 in src/prophetverse/trend/piecewise.py

View check run for this annotation

Codecov / codecov/patch

src/prophetverse/trend/piecewise.py#L280

Added line #L280 was not covered by tests
elif isinstance(self.changepoint_prior_scale, (int, float)):
offset_scale = self.changepoint_prior_scale * self.offset_prior_scale
else:
raise ValueError(f"Invalid type for changepoint_prior_scale {self.changepoint_prior_scale}")

Check warning on line 284 in src/prophetverse/trend/piecewise.py

View check run for this annotation

Codecov / codecov/patch

src/prophetverse/trend/piecewise.py#L284

Added line #L284 was not covered by tests

offset = numpyro.sample(
"offset", dist.Normal(self.offset_loc, self.offset_prior_scale)
"offset",
dist.Normal(
self.offset_loc,
offset_scale
),
)

changepoint_coefficients = numpyro.sample(
Expand Down
9 changes: 8 additions & 1 deletion tests/sktime/test_expand_column_per_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def test_fit_identifies_matched_columns():
assert "value2" in transformer.matched_columns_
assert "other" not in transformer.matched_columns_

X = X.loc[("series1")]

transformer = ExpandColumnPerLevel(columns_regex=["value"])
transformer.fit(X)

assert "value1" in transformer.matched_columns_
assert "value2" in transformer.matched_columns_
assert "other" not in transformer.matched_columns_

def test_transform_expands_columns():
"""
Expand All @@ -55,7 +63,6 @@ def test_transform_expands_columns():
assert all(col in X_transformed.columns for col in expected_columns)



def test_transform_preserves_original_data():
"""
Test that the transform method preserves the original data in the newly expanded columns.
Expand Down

0 comments on commit ca257f2

Please sign in to comment.