[Major] Support Re-Training #1635

Open · wants to merge 17 commits into base: main
150 changes: 109 additions & 41 deletions neuralprophet/configure.py
@@ -23,6 +23,24 @@
@dataclass
class Model:
    lagged_reg_layers: Optional[List[int]]
    quantiles: Optional[List[float]] = None

    def setup_quantiles(self):
        # convert quantiles to empty list [] if None
        if self.quantiles is None:
            self.quantiles = []
        # assert quantiles is a list type
        assert isinstance(self.quantiles, list), "Quantiles must be provided as list."
        # check if quantiles are float values in (0, 1)
        assert all(
            0 < quantile < 1 for quantile in self.quantiles
        ), "The quantiles specified need to be floats in-between (0, 1)."
        # sort the quantiles
        self.quantiles.sort()
        # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
        self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
        # 0 is the median quantile index
        self.quantiles.insert(0, 0.5)
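
A quick sketch (not part of the diff) of what `setup_quantiles` produces, assuming the `Model` fields are exactly those shown in this hunk: values are validated as floats in (0, 1), sorted, anything close to 0.5 is dropped, and 0.5 is re-inserted at index 0 as the median.

```python
# Illustrative only; assumes the Model dataclass as shown above.
from neuralprophet.configure import Model

m = Model(lagged_reg_layers=None, quantiles=[0.9, 0.1, 0.5])
m.setup_quantiles()
print(m.quantiles)  # [0.5, 0.1, 0.9] -- median first, remaining quantiles sorted
```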


@dataclass
@@ -92,9 +110,9 @@
    batch_size: Optional[int]
    loss_func: Union[str, torch.nn.modules.loss._Loss, Callable]
    optimizer: Union[str, Type[torch.optim.Optimizer]]
    quantiles: List[float] = field(default_factory=list)
    # quantiles: List[float] = field(default_factory=list)
    optimizer_args: dict = field(default_factory=dict)
    scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
    scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
    scheduler_args: dict = field(default_factory=dict)
    newer_samples_weight: float = 1.0
    newer_samples_start: float = 0.0
@@ -104,18 +122,21 @@
    n_data: int = field(init=False)
    loss_func_name: str = field(init=False)
    lr_finder_args: dict = field(default_factory=dict)
    optimizer_state: dict = field(default_factory=dict)
Review comment (Owner Author): move to separate PR
    continue_training: bool = False
    trainer_config: dict = field(default_factory=dict)

    def __post_init__(self):
        # assert the uncertainty estimation params and then finalize the quantiles
        self.set_quantiles()
        assert self.newer_samples_weight >= 1.0
        assert self.newer_samples_start >= 0.0
        assert self.newer_samples_start < 1.0
        self.set_loss_func()
        self.set_optimizer()
        self.set_scheduler()
        # self.set_loss_func(self.quantiles)

    def set_loss_func(self):
        # called in TimeNet configure_optimizers:
        # self.set_optimizer()
        # self.set_scheduler()

    def set_loss_func(self, quantiles: List[float]):
        if isinstance(self.loss_func, str):
            if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]:
                # keeping 'huber' for backwards compatibility, though not identical
@@ -135,25 +156,8 @@
            self.loss_func_name = type(self.loss_func).__name__
        else:
            raise NotImplementedError(f"Loss function {self.loss_func} not found")
        if len(self.quantiles) > 1:
            self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)

    def set_quantiles(self):
        # convert quantiles to empty list [] if None
        if self.quantiles is None:
            self.quantiles = []
        # assert quantiles is a list type
        assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
        # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
        self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
        # check if quantiles are float values in (0, 1)
        assert all(
            0 < quantile < 1 for quantile in self.quantiles
        ), "The quantiles specified need to be floats in-between (0, 1)."
        # sort the quantiles
        self.quantiles.sort()
        # 0 is the median quantile index
        self.quantiles.insert(0, 0.5)
        if len(quantiles) > 1:
            self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles)
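
For reviewers: the net effect of the new signature is that quantile setup moves out of the training config (it now lives on `Model`), and loss wrapping is driven by the `quantiles` argument. A minimal sketch of that dispatch, assuming `PinballLoss` lives in `neuralprophet.custom_loss_metrics` (the import path is not shown in this hunk):

```python
import torch
from neuralprophet.custom_loss_metrics import PinballLoss  # assumed import path

quantiles = [0.5, 0.1, 0.9]  # as produced by Model.setup_quantiles()
loss_func = torch.nn.SmoothL1Loss(reduction="none")
if len(quantiles) > 1:
    # more than the median quantile requested: wrap the base loss
    loss_func = PinballLoss(loss_func=loss_func, quantiles=quantiles)
```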

    def set_auto_batch_epoch(
        self,
@@ -182,26 +186,87 @@
"""
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding
torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet.

Parameters
----------
optimizer_name : int
Object provided to NeuralProphet as optimizer.
optimizer_args : dict
Arguments for the optimizer.

"""
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config(
self.optimizer, self.optimizer_args
)
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adamw":
# Tends to overfit, but reliable
self.optimizer = torch.optim.AdamW
self.optimizer_args["weight_decay"] = 1e-3
elif self.optimizer.lower() == "sgd":
# better validation performance, but diverges sometimes
self.optimizer = torch.optim.SGD
self.optimizer_args["momentum"] = 0.9
self.optimizer_args["weight_decay"] = 1e-4
else:
raise ValueError(
f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class."
)
elif not issubclass(self.optimizer, torch.optim.Optimizer):
raise ValueError("The provided optimizer is not supported.")

    def set_scheduler(self):
        """
        Set the scheduler and scheduler args.
        Set the scheduler and scheduler args depending on the user selection.
        The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
        """
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR
        self.scheduler_args.update(
            {
                "pct_start": 0.3,
                "anneal_strategy": "cos",
                "div_factor": 10.0,
                "final_div_factor": 10.0,
                "three_phase": True,
            }
        )
        if self.continue_training:

Review comment (Owner Author): move to other PR

            if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance(
                self.scheduler, torch.optim.lr_scheduler.OneCycleLR
            ):
                log.warning(
                    "OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler"
                )
                self.scheduler = "exponentiallr"

        if self.scheduler is None:
            log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.")
            self.scheduler = "exponentiallr"

        if isinstance(self.scheduler, str):
            if self.scheduler.lower() == "onecyclelr":
                self.scheduler = torch.optim.lr_scheduler.OneCycleLR
                defaults = {
                    "pct_start": 0.3,
                    "anneal_strategy": "cos",
                    "div_factor": 10.0,
                    "final_div_factor": 10.0,
                    "three_phase": True,
                }
            elif self.scheduler.lower() == "steplr":
                self.scheduler = torch.optim.lr_scheduler.StepLR
                defaults = {
                    "step_size": 10,
                    "gamma": 0.1,
                }
            elif self.scheduler.lower() == "exponentiallr":
                self.scheduler = torch.optim.lr_scheduler.ExponentialLR
                defaults = {
                    "gamma": 0.95,
                }
            elif self.scheduler.lower() == "cosineannealinglr":
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
                defaults = {
                    "T_max": 50,
                }
            else:
                raise NotImplementedError(
                    f"Scheduler {self.scheduler} is not supported as a string. Please pass the scheduler class."
                )
            if self.scheduler_args is not None:
                defaults.update(self.scheduler_args)
            self.scheduler_args = defaults
        else:
            assert issubclass(
                self.scheduler, torch.optim.lr_scheduler.LRScheduler
            ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

    def set_lr_finder_args(self, dataset_size, num_batches):
        """
@@ -239,6 +304,9 @@
            delay_weight = 1
        return delay_weight

    def set_optimizer_state(self, optimizer_state: dict):

Review comment (Owner Author): move to other PR

        self.optimizer_state = optimizer_state
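
For context on how the stored state would be consumed during re-training (the consuming side is not shown in this diff; `state_dict`/`load_state_dict` are standard `torch.optim` API):

```python
import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
saved = optimizer.state_dict()  # e.g. captured at the end of a first fit

# ... later, when fitting continues with a freshly created optimizer ...
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer.load_state_dict(saved)  # restores per-parameter moment estimates
```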


@dataclass
class Trend:
@@ -304,7 +372,7 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

        if self.trend_local_reg < 0:

Check failure on line 375 in neuralprophet/configure.py (GitHub Actions / pyright): Operator "<" not supported for "None" (reportOptionalOperand)

            log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
            self.trend_local_reg = False

@@ -353,13 +421,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

        self.periods = OrderedDict(

Check failure on line 424 in neuralprophet/configure.py (GitHub Actions / pyright): No overloads for "__init__" match the provided arguments (reportCallIssue)

            {

Check failure on line 425 in neuralprophet/configure.py (GitHub Actions / pyright): Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)

                "yearly": Season(
                    resolution=6,
                    period=365.25,
                    arg=self.yearly_arg,
                    global_local=(

Check failure on line 430 in neuralprophet/configure.py (GitHub Actions / pyright): Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)

                        self.yearly_global_local
                        if self.yearly_global_local in ["global", "local"]
                        else self.global_local
@@ -370,7 +438,7 @@
                    resolution=3,
                    period=7,
                    arg=self.weekly_arg,
                    global_local=(

Check failure on line 441 in neuralprophet/configure.py (GitHub Actions / pyright): Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)

                        self.weekly_global_local
                        if self.weekly_global_local in ["global", "local"]
                        else self.global_local
@@ -381,7 +449,7 @@
                    resolution=6,
                    period=1,
                    arg=self.daily_arg,
                    global_local=(

Check failure on line 452 in neuralprophet/configure.py (GitHub Actions / pyright): Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)

                        self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
                    ),
                    condition_name=None,
@@ -389,7 +457,7 @@
            }
        )
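
For reviewers unfamiliar with the `Season` container: a hypothetical stand-alone construction of one entry of the `periods` mapping above. Field names are taken from this hunk; the import path and the "auto" value for `arg` are assumptions, not shown in the diff.

```python
from neuralprophet.configure import Season  # assumed location of the dataclass

yearly = Season(
    resolution=6,           # number of Fourier terms used for this seasonality
    period=365.25,          # cycle length in days
    arg="auto",
    global_local="global",  # anything other than "global"/"local" falls back to self.global_local
    condition_name=None,
)
```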

        assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg)

Check failure on line 460 in neuralprophet/configure.py (GitHub Actions / pyright): Operator ">=" not supported for "None" (reportOptionalOperand)

        if self.seasonality_local_reg is True:
            log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
@@ -407,7 +475,7 @@
            resolution=resolution,
            period=period,
            arg=arg,
            global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure on line 478 in neuralprophet/configure.py (GitHub Actions / pyright): Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__". Type "str" is incompatible with type "SeasonGlobalLocalMode": "str" is incompatible with "Literal['global']", "Literal['local']", and "Literal['glocal']" (reportArgumentType)

            condition_name=condition_name,
        )

@@ -483,7 +551,7 @@
    regressors: OrderedDict = field(init=False)  # contains RegressorConfig objects

    def __post_init__(self):
        self.regressors = None

Check failure on line 554 in neuralprophet/configure.py (GitHub Actions / pyright): Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*": "None" is incompatible with "OrderedDict[Unknown, Unknown]" (reportAttributeAccessIssue)


@dataclass