Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Unify API #1023

Open
wants to merge 85 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
72771c0
draft
elephaint May 31, 2024
dd9f26e
next_iteration
elephaint Jun 4, 2024
e0ee8d1
next_iter
elephaint Jun 9, 2024
e7bbf30
next_iteration
elephaint Jun 11, 2024
3419432
next_iter
elephaint Jun 13, 2024
75bea55
draft
elephaint May 31, 2024
ef019d1
next_iteration
elephaint Jun 4, 2024
4313c13
next_iter
elephaint Jun 9, 2024
8101656
next_iteration
elephaint Jun 11, 2024
14fbf32
next_iter
elephaint Jun 13, 2024
ae6d73c
merge_main
elephaint Jun 14, 2024
302489e
fix_iql_and_isqf
elephaint Jun 14, 2024
0dcb6a2
fix_mixture_losses
elephaint Jun 14, 2024
9160647
add_quantile_to_distributionloss_predict
elephaint Jun 15, 2024
b73c097
add_quantile_to_mixture_loss_predict
elephaint Jun 16, 2024
20c18c5
bugfixes
elephaint Jun 18, 2024
a26ac29
fix_bugs
elephaint Jul 11, 2024
b20fe3f
fix_multivariate_bugs
elephaint Jul 12, 2024
1070f1d
Merge branch 'main' into fix/docs_and_refactoring
elephaint Jul 15, 2024
452388f
fix_json
elephaint Jul 15, 2024
2d3762f
Merge branch 'main' into fix/docs_and_refactoring
elephaint Jul 22, 2024
2419eb5
Merge branch 'main' into fix/docs_and_refactoring
elephaint Jul 26, 2024
bffa8d1
merge_main
elephaint Jul 26, 2024
f02b50f
merge_main
elephaint Sep 24, 2024
f80c59b
fix_examples_and_mixture_loss_bug
elephaint Sep 24, 2024
a60498b
add_exceptions_and_add_dev_dep_for_ci
elephaint Sep 24, 2024
b5ba554
fix_failing_polars_test
elephaint Sep 24, 2024
f4de0ff
fix_tests
elephaint Sep 25, 2024
a4ec70d
fix_tests
elephaint Sep 25, 2024
706ef74
fix_tests
elephaint Sep 25, 2024
829fc17
fix_docs_multivariate
elephaint Sep 25, 2024
efe2e76
fix_tests_in_models
elephaint Sep 25, 2024
47c36f7
reduce_multivariate_test_time
elephaint Sep 25, 2024
998e813
remove_stemgnn_from_test_and_add_contiguous
elephaint Sep 25, 2024
99c4b14
remove_contiguous_static
elephaint Sep 25, 2024
a4e4ee7
change_contiguous_windows
elephaint Sep 25, 2024
ff89950
improve_speed
elephaint Sep 26, 2024
b3fafc3
reduce_default_windows_batch_size_multivariate
elephaint Sep 26, 2024
87af3ac
fix_rnn_models
elephaint Sep 27, 2024
af070a9
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 8, 2024
ec32f28
improve_dilated_rnn
elephaint Oct 8, 2024
2801f19
fix_scalar_autodilatedrnn
elephaint Oct 8, 2024
6c3b2af
improve_speed_dilatedrnn_test
elephaint Oct 8, 2024
ccf8b2d
improve_speed_tests
elephaint Oct 8, 2024
8cba223
fix_loss_detach
elephaint Oct 8, 2024
e35f5e1
improve_speed_of_tests
elephaint Oct 8, 2024
5fc0437
fix_contiguous_multivariate
elephaint Oct 8, 2024
9c52adb
maybe_improve_drnn_speed
elephaint Oct 8, 2024
9d5a2bc
test_move_contiguous_for_better_performance
elephaint Oct 9, 2024
97507f0
improve_speed
elephaint Oct 9, 2024
5494554
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 9, 2024
e9bc822
try_fix_slow_test
elephaint Oct 9, 2024
baf7014
improve_speed_recurrent_models
elephaint Oct 10, 2024
fffbda3
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 10, 2024
9c727cc
improve_speed_tcn
elephaint Oct 10, 2024
1a0ba55
Merge branch 'fix/docs_and_refactoring' of https://github.com/Nixtla/…
elephaint Oct 10, 2024
6bb64be
windows_without_contiguous
elephaint Oct 10, 2024
a8a9362
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 10, 2024
d6e24de
merge_main
elephaint Oct 10, 2024
430732f
try_improve_nhits_bitcn_speed
elephaint Oct 11, 2024
6a472dc
reduce_test_time_models
elephaint Oct 11, 2024
ae49324
improve_losses
elephaint Oct 11, 2024
d681fdf
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 11, 2024
abe522b
change_forward_to_call_losses
elephaint Oct 11, 2024
932fd55
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 15, 2024
1f52b8e
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 15, 2024
0b980c0
fix_linting
elephaint Oct 15, 2024
63984e6
unify_quantile_and_level_in_predict
elephaint Oct 17, 2024
bbea7ba
Merge branch 'main' into fix/docs_and_refactoring
elephaint Oct 17, 2024
6f2272c
fix_parameter_errors
elephaint Oct 17, 2024
a4c8b54
rework_conformal
elephaint Oct 17, 2024
8ee4592
quantile_maybe_used
elephaint Oct 17, 2024
96ab536
fix_non_monotonic_iq_loss_and_redundant_cv_conformal
elephaint Oct 17, 2024
030dabe
Merge branch 'main' into fix/docs_and_refactoring
elephaint Nov 4, 2024
ddc617f
fix_batch_size_max_multivariate
elephaint Nov 19, 2024
c529ced
merge_main
elephaint Nov 19, 2024
b2c7691
fix_base_model
elephaint Nov 19, 2024
42d1abe
Merge branch 'main' into fix/docs_and_refactoring
elephaint Nov 26, 2024
862ec70
Merge branch 'main' into fix/docs_and_refactoring
elephaint Jan 31, 2025
88ae192
minor_merge_to_main_fixes
elephaint Jan 31, 2025
a42848d
fix_horizon_weight_distribution_loss
elephaint Jan 31, 2025
512e071
fix_horizon_weights
elephaint Feb 2, 2025
eddf9b3
clean_up_eval
elephaint Feb 3, 2025
99e4631
merge_main
elephaint Feb 13, 2025
98d7770
clean_up
elephaint Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix_non_monotonic_iq_loss_and_redundant_cv_conformal
elephaint committed Oct 17, 2024
commit 96ab536406ee81b83e027be094d528de63a98241
46 changes: 21 additions & 25 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
@@ -338,6 +338,7 @@
" # Flags and attributes\n",
" self._fitted = False\n",
" self._reset_models()\n",
" self._add_level = False\n",
"\n",
" def _scalers_fit_transform(self, dataset: TimeSeriesDataset) -> None:\n",
" self.scalers_ = {} \n",
@@ -1030,27 +1031,6 @@
" _warn_id_as_idx()\n",
" fcsts_df = fcsts_df.set_index(self.id_col)\n",
"\n",
" # # add prediction intervals or quantiles to models trained with point loss functions via level argument\n",
" # if level is not None or quantiles is not None:\n",
" # model_names = self._get_model_names(add_level=True)\n",
" # if model_names:\n",
" # if self.prediction_intervals is None:\n",
" # raise AttributeError(\n",
" # \"You have trained one or more models with a point loss function (e.g. MAE, MSE). \"\n",
" # \"You then must set `prediction_intervals` during fit to use level or quantiles during predict.\") \n",
" # prediction_interval_method = get_prediction_interval_method(self.prediction_intervals.method)\n",
"\n",
" # fcsts_df = prediction_interval_method(\n",
" # fcsts_df,\n",
" # self._cs_df,\n",
" # model_names=list(model_names),\n",
" # level=level_ if level is not None else None,\n",
" # cs_n_windows=self.prediction_intervals.n_windows,\n",
" # n_series=len(uids),\n",
" # horizon=self.h,\n",
" # quantiles=quantiles_ if quantiles is not None else None,\n",
" # ) \n",
"\n",
" return fcsts_df\n",
"\n",
" def _reset_models(self):\n",
@@ -1111,6 +1091,9 @@
"\n",
" fcsts_list: List = []\n",
" for model in self.models:\n",
" if self._add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)):\n",
" continue\n",
"\n",
" model.fit(dataset=self.dataset,\n",
" val_size=val_size, \n",
" test_size=test_size)\n",
@@ -1147,7 +1130,7 @@
" self._fitted = True\n",
"\n",
" # Add predictions to forecasts DataFrame\n",
" cols = self._get_model_names()\n",
" cols = self._get_model_names(add_level=self._add_level)\n",
" if isinstance(self.uids, pl_Series):\n",
" fcsts = pl_DataFrame(dict(zip(cols, fcsts.T)))\n",
" else:\n",
@@ -1678,6 +1661,7 @@
" \"Please reduce the number of windows, horizon or remove those series.\"\n",
" )\n",
" \n",
" self._add_level = True\n",
" cv_results = self.cross_validation(\n",
" df=df,\n",
" static_df=static_df,\n",
@@ -1686,7 +1670,8 @@
" time_col=time_col,\n",
" target_col=target_col,\n",
" )\n",
" \n",
" self._add_level = False\n",
"\n",
" kept = [time_col, id_col, 'cutoff']\n",
" # conformity score for each model\n",
" for model in self._get_model_names(add_level=True):\n",
@@ -1730,10 +1715,21 @@
" cols.extend(col_names)\n",
" # case 2: IQLoss\n",
" elif quantiles_ is not None and isinstance(model.loss, IQLoss):\n",
" # IQLoss does not give monotonically increasing quantiles, so we apply a hack: compute all quantiles, and take the quantile over the quantiles\n",
" quantiles_iqloss = np.linspace(0.01, 0.99, 20)\n",
" fcsts_list_iqloss = []\n",
" for i, quantile in enumerate(quantiles_iqloss):\n",
" model_fcsts = model.predict(dataset=dataset, quantiles = [quantile], **data_kwargs) \n",
" fcsts_list_iqloss.append(model_fcsts) \n",
" fcsts_iqloss = np.concatenate(fcsts_list_iqloss, axis=-1)\n",
"\n",
" # Get the actual requested quantiles\n",
" model_fcsts = np.quantile(fcsts_iqloss, quantiles_, axis=-1).T\n",
" fcsts_list.append(model_fcsts) \n",
"\n",
" # Get the right column names\n",
" col_names = []\n",
" for i, quantile in enumerate(quantiles_):\n",
" model_fcsts = model.predict(dataset=dataset, quantiles = [quantile], **data_kwargs)\n",
" fcsts_list.append(model_fcsts) \n",
" col_name = self._get_column_name(model_name, quantile, has_level)\n",
" col_names.extend([col_name]) \n",
" cols.extend(col_names)\n",
48 changes: 23 additions & 25 deletions neuralforecast/core.py
Original file line number Diff line number Diff line change
@@ -270,6 +270,7 @@ def __init__(
# Flags and attributes
self._fitted = False
self._reset_models()
self._add_level = False

def _scalers_fit_transform(self, dataset: TimeSeriesDataset) -> None:
self.scalers_ = {}
@@ -998,27 +999,6 @@ def predict(
_warn_id_as_idx()
fcsts_df = fcsts_df.set_index(self.id_col)

# # add prediction intervals or quantiles to models trained with point loss functions via level argument
# if level is not None or quantiles is not None:
# model_names = self._get_model_names(add_level=True)
# if model_names:
# if self.prediction_intervals is None:
# raise AttributeError(
# "You have trained one or more models with a point loss function (e.g. MAE, MSE). "
# "You then must set `prediction_intervals` during fit to use level or quantiles during predict.")
# prediction_interval_method = get_prediction_interval_method(self.prediction_intervals.method)

# fcsts_df = prediction_interval_method(
# fcsts_df,
# self._cs_df,
# model_names=list(model_names),
# level=level_ if level is not None else None,
# cs_n_windows=self.prediction_intervals.n_windows,
# n_series=len(uids),
# horizon=self.h,
# quantiles=quantiles_ if quantiles is not None else None,
# )

return fcsts_df

def _reset_models(self):
@@ -1082,6 +1062,11 @@ def _no_refit_cross_validation(

fcsts_list: List = []
for model in self.models:
if self._add_level and (
model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)
):
continue

model.fit(dataset=self.dataset, val_size=val_size, test_size=test_size)
model_fcsts = model.predict(
self.dataset, step_size=step_size, **data_kwargs
@@ -1118,7 +1103,7 @@ def _no_refit_cross_validation(
self._fitted = True

# Add predictions to forecasts DataFrame
cols = self._get_model_names()
cols = self._get_model_names(add_level=self._add_level)
if isinstance(self.uids, pl_Series):
fcsts = pl_DataFrame(dict(zip(cols, fcsts.T)))
else:
@@ -1678,6 +1663,7 @@ def _conformity_scores(
"Please reduce the number of windows, horizon or remove those series."
)

self._add_level = True
cv_results = self.cross_validation(
df=df,
static_df=static_df,
@@ -1686,6 +1672,7 @@ def _conformity_scores(
time_col=time_col,
target_col=target_col,
)
self._add_level = False

kept = [time_col, id_col, "cutoff"]
# conformity score for each model
@@ -1751,12 +1738,23 @@ def _generate_forecasts(
cols.extend(col_names)
# case 2: IQLoss
elif quantiles_ is not None and isinstance(model.loss, IQLoss):
col_names = []
for i, quantile in enumerate(quantiles_):
# IQLoss does not give monotonically increasing quantiles, so we apply a hack: compute all quantiles, and take the quantile over the quantiles
quantiles_iqloss = np.linspace(0.01, 0.99, 20)
fcsts_list_iqloss = []
for i, quantile in enumerate(quantiles_iqloss):
model_fcsts = model.predict(
dataset=dataset, quantiles=[quantile], **data_kwargs
)
fcsts_list.append(model_fcsts)
fcsts_list_iqloss.append(model_fcsts)
fcsts_iqloss = np.concatenate(fcsts_list_iqloss, axis=-1)

# Get the actual requested quantiles
model_fcsts = np.quantile(fcsts_iqloss, quantiles_, axis=-1).T
fcsts_list.append(model_fcsts)

# Get the right column names
col_names = []
for i, quantile in enumerate(quantiles_):
col_name = self._get_column_name(model_name, quantile, has_level)
col_names.extend([col_name])
cols.extend(col_names)