-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Major] Support Re-Training #1635
base: main
Are you sure you want to change the base?
Changes from all commits
2ae4506
900c8d5
f1355eb
da3a6d5
492dee9
f996928
f9a77f8
7ad761d
b14d20b
9fe3401
00f2e25
5f103d8
e043201
df74dc3
63c935c
6a74680
420f8a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,24 @@ | |
@dataclass | ||
class Model: | ||
lagged_reg_layers: Optional[List[int]] | ||
quantiles: Optional[List[float]] = None | ||
|
||
def setup_quantiles(self): | ||
# convert quantiles to empty list [] if None | ||
if self.quantiles is None: | ||
self.quantiles = [] | ||
# assert quantiles is a list type | ||
assert isinstance(self.quantiles, list), "Quantiles must be provided as list." | ||
# check if quantiles are float values in (0, 1) | ||
assert all( | ||
0 < quantile < 1 for quantile in self.quantiles | ||
), "The quantiles specified need to be floats in-between (0, 1)." | ||
# sort the quantiles | ||
self.quantiles.sort() | ||
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index | ||
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] | ||
# 0 is the median quantile index | ||
self.quantiles.insert(0, 0.5) | ||
|
||
|
||
@dataclass | ||
|
@@ -92,9 +110,9 @@ | |
batch_size: Optional[int] | ||
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] | ||
optimizer: Union[str, Type[torch.optim.Optimizer]] | ||
quantiles: List[float] = field(default_factory=list) | ||
# quantiles: List[float] = field(default_factory=list) | ||
optimizer_args: dict = field(default_factory=dict) | ||
scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None | ||
scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None | ||
scheduler_args: dict = field(default_factory=dict) | ||
newer_samples_weight: float = 1.0 | ||
newer_samples_start: float = 0.0 | ||
|
@@ -104,18 +122,21 @@ | |
n_data: int = field(init=False) | ||
loss_func_name: str = field(init=False) | ||
lr_finder_args: dict = field(default_factory=dict) | ||
optimizer_state: dict = field(default_factory=dict) | ||
continue_training: bool = False | ||
trainer_config: dict = field(default_factory=dict) | ||
|
||
def __post_init__(self): | ||
# assert the uncertainty estimation params and then finalize the quantiles | ||
self.set_quantiles() | ||
assert self.newer_samples_weight >= 1.0 | ||
assert self.newer_samples_start >= 0.0 | ||
assert self.newer_samples_start < 1.0 | ||
self.set_loss_func() | ||
self.set_optimizer() | ||
self.set_scheduler() | ||
# self.set_loss_func(self.quantiles) | ||
|
||
def set_loss_func(self): | ||
# called in TimeNet configure_optimizers: | ||
# self.set_optimizer() | ||
# self.set_scheduler() | ||
|
||
def set_loss_func(self, quantiles: List[float]): | ||
if isinstance(self.loss_func, str): | ||
if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]: | ||
# keeping 'huber' for backwards compatiblility, though not identical | ||
|
@@ -135,25 +156,8 @@ | |
self.loss_func_name = type(self.loss_func).__name__ | ||
else: | ||
raise NotImplementedError(f"Loss function {self.loss_func} not found") | ||
if len(self.quantiles) > 1: | ||
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles) | ||
|
||
def set_quantiles(self): | ||
# convert quantiles to empty list [] if None | ||
if self.quantiles is None: | ||
self.quantiles = [] | ||
# assert quantiles is a list type | ||
assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar." | ||
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index | ||
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] | ||
# check if quantiles are float values in (0, 1) | ||
assert all( | ||
0 < quantile < 1 for quantile in self.quantiles | ||
), "The quantiles specified need to be floats in-between (0, 1)." | ||
# sort the quantiles | ||
self.quantiles.sort() | ||
# 0 is the median quantile index | ||
self.quantiles.insert(0, 0.5) | ||
if len(quantiles) > 1: | ||
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles) | ||
|
||
def set_auto_batch_epoch( | ||
self, | ||
|
@@ -182,26 +186,87 @@ | |
""" | ||
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding | ||
torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet. | ||
|
||
Parameters | ||
---------- | ||
optimizer_name : int | ||
Object provided to NeuralProphet as optimizer. | ||
optimizer_args : dict | ||
Arguments for the optimizer. | ||
|
||
""" | ||
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config( | ||
self.optimizer, self.optimizer_args | ||
) | ||
if isinstance(self.optimizer, str): | ||
if self.optimizer.lower() == "adamw": | ||
# Tends to overfit, but reliable | ||
self.optimizer = torch.optim.AdamW | ||
self.optimizer_args["weight_decay"] = 1e-3 | ||
elif self.optimizer.lower() == "sgd": | ||
# better validation performance, but diverges sometimes | ||
self.optimizer = torch.optim.SGD | ||
self.optimizer_args["momentum"] = 0.9 | ||
self.optimizer_args["weight_decay"] = 1e-4 | ||
else: | ||
raise ValueError( | ||
f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class." | ||
) | ||
elif not issubclass(self.optimizer, torch.optim.Optimizer): | ||
raise ValueError("The provided optimizer is not supported.") | ||
|
||
def set_scheduler(self): | ||
""" | ||
Set the scheduler and scheduler args. | ||
Set the scheduler and scheduler arg depending on the user selection. | ||
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet. | ||
""" | ||
self.scheduler = torch.optim.lr_scheduler.OneCycleLR | ||
self.scheduler_args.update( | ||
{ | ||
"pct_start": 0.3, | ||
"anneal_strategy": "cos", | ||
"div_factor": 10.0, | ||
"final_div_factor": 10.0, | ||
"three_phase": True, | ||
} | ||
) | ||
if self.continue_training: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move to other PR |
||
if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance( | ||
self.scheduler, torch.optim.lr_scheduler.OneCycleLR | ||
): | ||
log.warning( | ||
"OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler" | ||
) | ||
self.scheduler = "exponentiallr" | ||
|
||
if self.scheduler is None: | ||
log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.") | ||
self.scheduler = "exponentiallr" | ||
|
||
if isinstance(self.scheduler, str): | ||
if self.scheduler.lower() == "onecyclelr": | ||
self.scheduler = torch.optim.lr_scheduler.OneCycleLR | ||
defaults = { | ||
"pct_start": 0.3, | ||
"anneal_strategy": "cos", | ||
"div_factor": 10.0, | ||
"final_div_factor": 10.0, | ||
"three_phase": True, | ||
} | ||
elif self.scheduler.lower() == "steplr": | ||
self.scheduler = torch.optim.lr_scheduler.StepLR | ||
defaults = { | ||
"step_size": 10, | ||
"gamma": 0.1, | ||
} | ||
elif self.scheduler.lower() == "exponentiallr": | ||
self.scheduler = torch.optim.lr_scheduler.ExponentialLR | ||
defaults = { | ||
"gamma": 0.95, | ||
} | ||
elif self.scheduler.lower() == "cosineannealinglr": | ||
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR | ||
defaults = { | ||
"T_max": 50, | ||
} | ||
else: | ||
raise NotImplementedError( | ||
f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class." | ||
) | ||
if self.scheduler_args is not None: | ||
defaults.update(self.scheduler_args) | ||
self.scheduler_args = defaults | ||
else: | ||
assert issubclass( | ||
self.scheduler, torch.optim.lr_scheduler.LRScheduler | ||
), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler" | ||
|
||
def set_lr_finder_args(self, dataset_size, num_batches): | ||
""" | ||
|
@@ -239,6 +304,9 @@ | |
delay_weight = 1 | ||
return delay_weight | ||
|
||
def set_optimizer_state(self, optimizer_state: dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move to other PR |
||
self.optimizer_state = optimizer_state | ||
|
||
|
||
@dataclass | ||
class Trend: | ||
|
@@ -304,7 +372,7 @@ | |
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local)) | ||
self.trend_global_local = "global" | ||
|
||
if self.trend_local_reg < 0: | ||
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg)) | ||
self.trend_local_reg = False | ||
|
||
|
@@ -353,13 +421,13 @@ | |
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local)) | ||
self.global_local = "global" | ||
|
||
self.periods = OrderedDict( | ||
{ | ||
"yearly": Season( | ||
resolution=6, | ||
period=365.25, | ||
arg=self.yearly_arg, | ||
global_local=( | ||
Check failure on line 430 in neuralprophet/configure.py GitHub Actions / pyright
|
||
self.yearly_global_local | ||
if self.yearly_global_local in ["global", "local"] | ||
else self.global_local | ||
|
@@ -370,7 +438,7 @@ | |
resolution=3, | ||
period=7, | ||
arg=self.weekly_arg, | ||
global_local=( | ||
Check failure on line 441 in neuralprophet/configure.py GitHub Actions / pyright
|
||
self.weekly_global_local | ||
if self.weekly_global_local in ["global", "local"] | ||
else self.global_local | ||
|
@@ -381,7 +449,7 @@ | |
resolution=6, | ||
period=1, | ||
arg=self.daily_arg, | ||
global_local=( | ||
Check failure on line 452 in neuralprophet/configure.py GitHub Actions / pyright
|
||
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local | ||
), | ||
condition_name=None, | ||
|
@@ -389,7 +457,7 @@ | |
} | ||
) | ||
|
||
assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg) | ||
|
||
if self.seasonality_local_reg is True: | ||
log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") | ||
|
@@ -407,7 +475,7 @@ | |
resolution=resolution, | ||
period=period, | ||
arg=arg, | ||
global_local=global_local if global_local in ["global", "local"] else self.global_local, | ||
Check failure on line 478 in neuralprophet/configure.py GitHub Actions / pyright
|
||
condition_name=condition_name, | ||
) | ||
|
||
|
@@ -483,7 +551,7 @@ | |
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects | ||
|
||
def __post_init__(self): | ||
self.regressors = None | ||
|
||
|
||
@dataclass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move to separate PR