Skip to content

Commit

Permalink
ARIMA: Change armav to int32 and add gt>0 asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
filipcacky committed Nov 20, 2024
1 parent b5bd7e3 commit bdef8f5
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 17 deletions.
2 changes: 1 addition & 1 deletion nbs/src/arima.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@
" order[1],\n",
" seasonal['order'][1],\n",
" ],\n",
" dtype=np.uint32,\n",
" dtype=np.int32,\n",
" )\n",
" narma = arma[:4].sum().item()\n",
" \n",
Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/adapters/prophet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""In 2017, Facebook open-sourced [Prophet](https://peerj.com/preprints/3190.pdf), with the promise of providing experts and non-experts the possibility of producing high-quality predictions. The forecasting community heavily adopted the solution, reaching millions of accumulated downloads. It became evident that its [quality is shadowed](https://www.reddit.com/r/MachineLearning/comments/wqrw8x/d_fool_me_once_shame_on_you_fool_me_twice_shame/) by simpler well-proven methods. This effort aims to provide an alternative to overcome the Prophet's memory.<br><br><div align="center">"It is important to note that false prophets sometimes prophesied accurately, ... " <br>(Deuteronomy 13:2,5) </div>"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/src/adapters.prophet.ipynb.

# %% auto 0
Expand Down
2 changes: 1 addition & 1 deletion python/statsforecast/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def maInvert(ma):
order[1],
seasonal["order"][1],
],
dtype=np.uint32,
dtype=np.int32,
)
narma = arma[:4].sum().item()

Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Methods for Fit, Predict, Forecast (fast), Cross Validation and plotting"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/src/core/core.ipynb.

# %% auto 0
Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/distributed/multiprocess.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The computational efficiency of `StatsForecast` can be tracked to its two core components:<br>1. Its `models` written in NumBa that optimizes Python code to reach C speeds.<br>2. Its `core.StatsForecast` class that enables distributed computing.<br>This is a low-level class enabling other distribution methods.<br><br>"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/src/distributed.multiprocess.ipynb.

# %% auto 0
Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/feature_engineering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Generate features for downstream models"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/src/feature_engineering.ipynb.

# %% auto 0
Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Models currently supported by StatsForecast"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/src/core/models.ipynb.

# %% auto 0
Expand Down
2 changes: 2 additions & 0 deletions python/statsforecast/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The `core.StatsForecast` class allows you to efficiently fit multiple `StatsForecast` models for large sets of time series. It operates with pandas DataFrame `df` that identifies individual series and datestamps with the `unique_id` and `ds` columns, and the `y` column denotes the target time series variable. To assist development, we declare useful datasets that we use throughout all `StatsForecast`'s unit tests."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/src/utils.ipynb.

# %% auto 0
Expand Down
46 changes: 31 additions & 15 deletions src/arima.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,22 @@ void partrans(const uint32_t p, const std::span<const double> rawv,

std::tuple<py::array_t<double>, py::array_t<double>>
arima_transpar(const py::array_t<double> params_inv,
const py::array_t<uint32_t> armav, bool trans) {
const py::array_t<int32_t> armav, bool trans) {
assert(params_inv.ndim() == 1);
assert(armav.ndim() == 1);

const auto arma = make_cspan(armav);
const auto params_in = make_cspan(params_inv);

const uint32_t mp = arma[0];
const uint32_t mq = arma[1];
const uint32_t msp = arma[2];
const uint32_t msq = arma[3];
const uint32_t ns = arma[4];
assert(arma.size() == 7);
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const int32_t mp = arma[0];
const int32_t mq = arma[1];
const int32_t msp = arma[2];
const int32_t msq = arma[3];
const int32_t ns = arma[4];
const uint32_t p = mp + ns * msp;
const uint32_t q = mq + ns * msq;

Expand Down Expand Up @@ -102,7 +106,7 @@ arima_transpar(const py::array_t<double> params_inv,
}

std::tuple<double, py::array_t<double>>
arima_css(const py::array_t<double> yv, const py::array_t<uint32_t> armav,
arima_css(const py::array_t<double> yv, const py::array_t<int32_t> armav,
const py::array_t<double> phiv, const py::array_t<double> thetav) {
assert(yv.ndim() == 1);
assert(armav.ndim() == 1);
Expand All @@ -119,6 +123,10 @@ arima_css(const py::array_t<double> yv, const py::array_t<uint32_t> armav,
const auto phi = make_cspan(phiv);
const auto theta = make_cspan(thetav);

assert(arma.size() == 7);
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const uint32_t ncond = arma[0] + arma[5] + arma[4] * (arma[2] + arma[6]);
uint32_t nu = 0;
double ssq = 0.0;
Expand Down Expand Up @@ -601,7 +609,7 @@ void getQ0(const py::array_t<double> phiv, const py::array_t<double> thetav,
}

py::array_t<double> arima_gradtrans(const py::array_t<double> xv,
const py::array_t<uint32_t> armav) {
const py::array_t<int32_t> armav) {
assert(xv.ndim() == 1);
assert(armav.ndim() == 1);

Expand All @@ -610,9 +618,13 @@ py::array_t<double> arima_gradtrans(const py::array_t<double> xv,
const auto x = make_cspan(xv);
const size_t n = x.size();

const uint32_t mp = arma[0];
const uint32_t mq = arma[1];
const uint32_t msp = arma[2];
assert(arma.size() == 7);
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const int32_t mp = arma[0];
const int32_t mq = arma[1];
const int32_t msp = arma[2];

std::array<double, 100> w1;
std::array<double, 100> w2;
Expand Down Expand Up @@ -662,16 +674,20 @@ py::array_t<double> arima_gradtrans(const py::array_t<double> xv,
}

py::array_t<double> arima_undopars(const py::array_t<double> xv,
const py::array_t<uint32_t> armav) {
const py::array_t<int32_t> armav) {
assert(xv.ndim() == 1);
assert(armav.ndim() == 1);

const auto x = make_cspan(xv);
const auto arma = make_cspan(armav);

const uint32_t mp = arma[0];
const uint32_t mq = arma[1];
const uint32_t msp = arma[2];
assert(arma.size() == 7);
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const int32_t mp = arma[0];
const int32_t mq = arma[1];
const int32_t msp = arma[2];

py::array_t<double> outv(xv.size());
const auto out = make_span(outv);
Expand Down

0 comments on commit bdef8f5

Please sign in to comment.