@@ -31,6 +31,8 @@ def __init__(
3131 scheduler : str = "reduce_on_plateau" ,
3232 scheduler_params : Dict [str , Any ] = None ,
3333 log_freq : int = 100 ,
34+ auto_lr_finder : bool = False ,
35+ ** kwargs ,
3436 ) -> None :
3537 """Segmentation model training experiment.
3638
@@ -64,6 +66,7 @@ def __init__(
6466 optim paramas like learning rates, weight decays etc for diff parts of
6567 the network.
6668 E.g. {"encoder": {"weight_decay: 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
69+ or {"learning_rate": 0.005, "weight_decay": 0.03}
6770 lookahead : bool, default=False
6871 Flag whether the optimizer uses lookahead.
6972 scheduler : str, default="reduce_on_plateau"
@@ -75,6 +78,8 @@ def __init__(
7578 for the possible scheduler arguments.
7679 log_freq : int, default=100
7780 Return logs every n batches in logging callbacks.
81+ auto_lr_finder : bool, default=False
82+ Flag, whether to use the lightning in-built auto-lr-finder.
7883
7984 Raises
8085 ------
@@ -83,6 +88,8 @@ def __init__(
8388 ValueError if illegal metric names are given.
8489 ValueError if illegal optimizer name is given.
8590 ValueError if illegal scheduler name is given.
91+ KeyError if `auto_lr_finder` is set to True and `optim_params` does not
92+ contain `lr`-key.
8693 """
8794 super ().__init__ ()
8895 self .model = model
@@ -95,6 +102,16 @@ def __init__(
95102 self .scheduler = scheduler
96103 self .scheduler_params = scheduler_params
97104 self .lookahead = lookahead
105+ self .auto_lr_finder = auto_lr_finder
106+
107+ if auto_lr_finder :
108+ try :
109+ self .lr = optim_params ["lr" ]
110+ except KeyError :
111+ raise KeyError (
112+ "To use lightning in-built auto_lr_finder, the `optim_params` "
113+ "config variable has to contain 'lr'-key for learning-rate."
114+ )
98115
99116 self .branch_losses = branch_losses
100117 self .branch_metrics = branch_metrics
@@ -309,15 +326,20 @@ def configure_optimizers(self):
309326 f"Illegal scheduler given. Got { self .scheduler } . Allowed: { allowed } ."
310327 )
311328
312- # set sensible default if None.
313- if self .optim_params is None :
314- self .optim_params = {
315- "encoder" : {"lr" : 0.00005 , "weight_decay" : 0.00003 },
316- "decoder" : {"lr" : 0.0005 , "weight_decay" : 0.0003 },
317- }
329+ if not self .auto_lr_finder :
330+ # set sensible default if None.
331+ if self .optim_params is None :
332+ self .optim_params = {
333+ "encoder" : {"lr" : 0.00005 , "weight_decay" : 0.00005 },
334+ "decoder" : {"lr" : 0.0005 , "weight_decay" : 0.0005 },
335+ }
318336
319- params = adjust_optim_params (self .model , self .optim_params )
320- optimizer = OPTIM_LOOKUP [self .optimizer ](params )
337+ params = adjust_optim_params (self .model , self .optim_params )
338+ optimizer = OPTIM_LOOKUP [self .optimizer ](params )
339+ else :
340+ optimizer = OPTIM_LOOKUP [self .optimizer ](
341+ self .model .parameters (), lr = self .lr
342+ )
321343
322344 if self .lookahead :
323345 optimizer = OPTIM_LOOKUP ["lookahead" ](optimizer , k = 5 , alpha = 0.5 )
0 commit comments