-
Notifications
You must be signed in to change notification settings - Fork 483
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Major] Dataloader: Just-In-Time tabularization (#1529)
* minimal pytest * move_func_getitem * slicing * predict_mode * typos * lr-finder * drop_missing * predict_v2 * predict_v3 * samples * lagged regressor n_lags * preliminary: events, holidays * adjustes pytests * selective forecasting * black * ruff * lagged_regressors * Note down df path to TimeDataset * complete notes on TimeDataset, move meta * Big rewrite with real and pseudocode * create_target_start_end_mask * boolean mask * combine masks into map * notes for nan check * bypass NAN filter * rework index to point at prediction origin, not first forecast. * tabularize: converted time and lags to single sample extraction * convert lagged regressors * consolidate seasonality computation in one script * finish Seasonlity conversion * update todos * complete targets and future regressors * convert events * finish events and holidays conversion * debug timedataset * debugging * make_country_specific_holidays_df * remove uses of df.loc[...].values * debug time * debugging types * debug timedata * debugging time_dataset variable shapes * address indexing and slicing issues, .loc * fix dimensions except nonstationary components * integrate torch formatting into tabularize * check shapes * AirPassengers test working! * fix dataset generator * fixed all performance tests but Energy due to nonstationary components * fixed nonstationary issue. all performance tests running * refactor tabularize function * fix bug * initial build of GlobalTimeDataset * refactor TimeDataset not to use kwargs passthrough * debugged seasonal components call of TimeDataset * fix numpy object type error * fix seasonality condition bugs * fix events and future regressor cases * fixing prediction frequency filter * performance_test_energy * debug events * convert new energytest to daily data * fix events util reference * fix test_get_country_holidays * fix test_timedataset_minima * fix selective forecasting * cleanup timedataset * refactor tabularize_univariate * daily_data * start nan check for smaple mask * working on time nan2 * fix tests * finish nan-check * fix dims * pass self.df to indexing * fix zero dim lagged regressors * close figures in tests * fix typings * black * ruff * linting * linting * modify logs * add benchmarking script for computational time * speed up uncertainty tests * fix unit test multiple country * reduce tests log level to ERROR * reduce log level to ERROR and fix adding multiple countries * bypass intentional glocal test error log * fix prev * benchmark dataloader time * remove hourly energy test * add debug notebook for energy hourly * set to log model performance INFO * address config_regressors.regressors * clean up create_nan_mask * clean up create_nan_mask params * clean TimeDataframe * update prediction frequency documentation * improve prediction frequency documentation * further improve prediction frequency documentation * fix test errors * fix df_names call * fix selective prediction assertion * normalize holiday naes * fix linting * fix tests * update to use new holiday functions in event_utils.py * fix seasonality_local_reg test * limit holidays to less than 1.0 * changed holidays * update lock * changed tests * adjsuted tests * fix reserved names * fixed ruff lintint * changed test * translate holidays to english is possible * exclude py3.13 * update lock * Merge all holidays related tests in one file * add deterministic flag * fixed ruff linting issues * fixed glocal test * fix lock file * update poetry * moved the deterministic flag to the train method * update lock file --------- Co-authored-by: Simon W <[email protected]> Co-authored-by: MaiBe-ctrl <[email protected]> Co-authored-by: Maisa Ben Salah <[email protected]>
- Loading branch information
1 parent
e90bf5a
commit 1bfa633
Showing
29 changed files
with
4,251 additions
and
881 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,20 @@ | ||
NeuralProphet Class | ||
----------------------- | ||
Core Module Documentation | ||
========================== | ||
|
||
.. toctree:: | ||
:hidden: | ||
:maxdepth: 1 | ||
|
||
configure.py <configure> | ||
df_utils.py <df_utils> | ||
event_utils.py <event_utils> | ||
plot_forecast_plotly.py <plot_forecast_plotly> | ||
plot_forecast_matplotlib.py <plot_forecast_matplotlib> | ||
plot_model_parameters_plotly.py <plot_model_parameters_plotly> | ||
plot_model_parameters_matplotlib.py <plot_model_parameters_matplotlib> | ||
time_dataset.py <time_dataset> | ||
time_net.py <time_net> | ||
utils.py <utils> | ||
|
||
.. automodule:: neuralprophet.forecaster | ||
:members: |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from collections import defaultdict | ||
from typing import Iterable, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from holidays import country_holidays | ||
|
||
|
||
def get_holiday_names(country: Union[str, Iterable[str]], df=None): | ||
""" | ||
Return all possible holiday names for a list of countries over time period in df | ||
Parameters | ||
---------- | ||
country : str, list | ||
List of country names to retrieve country specific holidays | ||
df : pd.Dataframe | ||
Dataframe from which datestamps will be retrieved from | ||
Returns | ||
------- | ||
set | ||
All possible holiday names of given country | ||
""" | ||
if df is None: | ||
years = np.arange(1995, 2045) | ||
else: | ||
dates = df["ds"].copy(deep=True) | ||
years = pd.unique(dates.apply(lambda x: x.year)) | ||
# years = list({x.year for x in dates}) | ||
# support multiple countries, convert to list if not already | ||
if isinstance(country, str): | ||
country = [country] | ||
|
||
all_holidays = get_all_holidays(years=years, country=country) | ||
return set(all_holidays.keys()) | ||
|
||
|
||
def get_all_holidays(years, country): | ||
""" | ||
Make dataframe of country specific holidays for given years and countries | ||
Parameters | ||
---------- | ||
year_list : list | ||
List of years | ||
country : str, list, dict | ||
List of country names and optional subdivisions | ||
Returns | ||
------- | ||
pd.DataFrame | ||
Containing country specific holidays df with columns 'ds' and 'holiday' | ||
""" | ||
# convert to list if not already | ||
if isinstance(country, str): | ||
country = {country: None} | ||
elif isinstance(country, list): | ||
country = dict(zip(country, [None] * len(country))) | ||
|
||
all_holidays = defaultdict(list) | ||
# iterate over countries and get holidays for each country | ||
for single_country, subdivision in country.items(): | ||
# For compatibility with Turkey as "TU" cases. | ||
single_country = "TUR" if single_country == "TU" else single_country | ||
# get dict of dates and their holiday name | ||
single_country_specific_holidays = country_holidays( | ||
country=single_country, subdiv=subdivision, years=years, expand=True, observed=False, language="en" | ||
) | ||
# invert order - for given holiday, store list of dates | ||
for date, name in single_country_specific_holidays.items(): | ||
all_holidays[name].append(pd.to_datetime(date)) | ||
return all_holidays |
Oops, something went wrong.