This is the official PyTorch implementation of MLR-SNet: Transferable LR Schedules for Heterogeneous Tasks. Please contact Jun Shu ([email protected]) or Deyu Meng ([email protected]).
@ARTICLE{shu2021mlrsnet,
author={Shu, Jun and Zhu, Yanwen and Zhao, Qian and Meng, Deyu and Xu, Zongben},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={MLR-SNet: Transferable LR Schedules for Heterogeneous Tasks},
year={2022},
volume={},
number={},
pages={1-16},
doi={10.1109/TPAMI.2022.3184315}}
The learning rate (LR) is one of the most important hyperparameters in the stochastic gradient descent (SGD) algorithm for training deep neural networks (DNNs). However, current hand-designed LR schedules require a manually pre-specified fixed form. Because training dynamics vary significantly, this limits their ability to adapt to practical non-convex optimization problems. Meanwhile, proper LR schedules usually have to be searched from scratch for each new task, and they often differ greatly across task variations such as data modalities, network architectures, or training data sizes. To address these LR-schedule setting issues, we propose to parameterize LR schedules with an explicit mapping formulation, called MLR-SNet. The learnable parameterized structure gives MLR-SNet more flexibility to learn a proper LR schedule that complies with the training dynamics of DNNs. Image and text classification benchmark experiments substantiate the capability of our method to achieve proper LR schedules. Moreover, the explicit parameterized structure makes the meta-learned LR schedules transferable and plug-and-play, so they can be easily generalized to new heterogeneous tasks. We transfer our meta-learned MLR-SNet to query tasks that differ from the training ones in training epochs, network architectures, data modalities, and dataset sizes. On these tasks it achieves comparable or even better performance than hand-designed LR schedules specifically designed for the query tasks. The robustness of MLR-SNet is also substantiated when the training data are biased with corrupted noise. We further prove the convergence of the SGD algorithm equipped with the LR schedule produced by our MLR-SNet, with a convergence rate comparable to the best-known rates of the algorithm for solving the problem.
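To make the idea of an explicit, learnable LR mapping concrete, here is a minimal standard-library sketch under stated assumptions. `ToyLRSchedule` is a hypothetical two-parameter function of the current training loss, not the actual MLR-SNet architecture. The "network" is a single scalar parameter on a 1-D quadratic. The sketch shows only that SGD queries the schedule at every step instead of following a pre-specified form.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class ToyLRSchedule:
    """Hypothetical stand-in for MLR-SNet (NOT its real architecture).

    Maps the current training loss to an LR: lr = gamma * sigmoid(w * loss + b),
    so every predicted LR lies strictly inside (0, gamma).
    """

    def __init__(self, w: float, b: float, gamma: float = 1.0):
        self.w, self.b, self.gamma = w, b, gamma

    def __call__(self, loss: float) -> float:
        return self.gamma * sigmoid(self.w * loss + self.b)


def sgd_with_schedule(schedule, theta0: float, target: float, steps: int):
    """Plain SGD on the toy objective 0.5 * (theta - target)**2."""
    theta, history = theta0, []
    for _ in range(steps):
        loss = 0.5 * (theta - target) ** 2
        lr = schedule(loss)        # LR is queried from the schedule at every step
        grad = theta - target      # exact gradient of the toy objective
        theta -= lr * grad
        history.append((loss, lr))
    return theta, history


schedule = ToyLRSchedule(w=0.5, b=-1.0)
theta, history = sgd_with_schedule(schedule, theta0=5.0, target=1.0, steps=50)
print(f"final theta = {theta:.4f}, first LR = {history[0][1]:.3f}, last LR = {history[-1][1]:.3f}")
```

With these illustrative values, the schedule takes large steps while the loss is high and smaller steps near the optimum. In MLR-SNet, the mapping's parameters are meta-learned rather than hand-picked.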
- Python 3.7 (Anaconda)
- PyTorch >= 1.2.0
- Torchvision >= 0.2.1
Use the meta-train sub-folder to meta-learn the MLR-SNet meta-model. Here is an example for image classification:
python meta-train.py --network resnet --dataset cifar10 --lr 1e-3
The --lr argument is the learning rate of the Adam meta-optimizer; we suggest setting it to …
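Meta-training updates the schedule's own parameters with an Adam meta-optimizer whose step size is the --lr argument. The following standard-library toy sketch illustrates this idea. It rests on three assumptions that do not come from this repository's code:

- The real MLR-SNet is replaced by a hypothetical two-parameter schedule, lr = gamma * sigmoid(w * loss + c).
- The meta-gradient is a one-step lookahead hypergradient.
- The training and meta (validation) losses are the same 1-D quadratic, chosen for simplicity.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def meta_train_toy_schedule(meta_lr: float = 1e-3, episodes: int = 20,
                            inner_steps: int = 20, theta0: float = 5.0,
                            target: float = 1.0, gamma: float = 1.0):
    """Meta-learn phi = [w, c] of the hypothetical schedule lr = gamma * sigmoid(w * loss + c).

    Toy assumption: training and meta losses are both 0.5 * (theta - target)**2.
    """
    phi = [0.0, 0.0]                  # [w, c]
    m, v = [0.0, 0.0], [0.0, 0.0]     # Adam first/second moment estimates
    beta1, beta2, eps, t = 0.9, 0.999, 1e-8, 0
    for _ in range(episodes):
        theta = theta0                # restart the toy "model" each episode
        for _ in range(inner_steps):
            loss = 0.5 * (theta - target) ** 2
            grad = theta - target
            s = sigmoid(phi[0] * loss + phi[1])
            lr = gamma * s
            # One-step lookahead: theta_hat(phi) = theta - lr(phi) * grad.
            theta_hat = theta - lr * grad
            # Chain rule: d meta_loss / d lr = (theta_hat - target) * (-grad).
            dmeta_dlr = (theta_hat - target) * (-grad)
            dlr_dz = gamma * s * (1.0 - s)    # derivative of gamma * sigmoid(z)
            hypergrad = [dmeta_dlr * dlr_dz * loss, dmeta_dlr * dlr_dz]
            # Adam step on phi; meta_lr plays the role of --lr.
            t += 1
            for i in range(2):
                m[i] = beta1 * m[i] + (1 - beta1) * hypergrad[i]
                v[i] = beta2 * v[i] + (1 - beta2) * hypergrad[i] ** 2
                m_hat = m[i] / (1 - beta1 ** t)
                v_hat = v[i] / (1 - beta2 ** t)
                phi[i] -= meta_lr * m_hat / (math.sqrt(v_hat) + eps)
            # Update the model with the LR from the freshly updated schedule.
            theta -= gamma * sigmoid(phi[0] * loss + phi[1]) * grad
    return phi


w, c = meta_train_toy_schedule()
print(f"w = {w:.4f}, c = {c:.4f}, lr at loss 1.0 = {sigmoid(w + c):.4f}")
```

In this toy, the best one-step LR is 1, so the hypergradient always pushes the predicted LR upward from its initial value of 0.5. A real meta-training run uses mini-batch losses of an actual network and a separate meta objective.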
Use the meta-test sub-folder to evaluate the transferability and generalization capability of the LR schedules meta-learned by MLR-SNet. The meta-test sub-folder also provides the MLR-SNet models we meta-learned: mlr_snet 1.pth, mlr_snet 100.pth, and mlr_snet 200.pth. Here is an example of transferring the learned MLR-SNet to a different network architecture:
python meta-test.py --network shufflenetv2 --dataset cifar10
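At meta-test time, an already meta-learned schedule is reused without retraining. The plug-and-play idea can be sketched with a frozen version of the hypothetical two-parameter schedule, under stated assumptions:

- `FROZEN_PHI` holds illustrative values. They are not loaded from the provided .pth files.
- Each "query task" is a 1-D quadratic with a different starting point, minimizer, curvature, and step budget.
- These tasks stand in only loosely for the task changes described above.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical "meta-learned" parameters, frozen at test time.
# Illustrative only; NOT read from the provided .pth files.
FROZEN_PHI = (0.5, 1.0)   # (w, c) in lr = sigmoid(w * loss + c)


def frozen_lr(loss: float) -> float:
    w, c = FROZEN_PHI
    return sigmoid(w * loss + c)


def run_query_task(theta0: float, target: float, scale: float, steps: int) -> float:
    """SGD on the toy task 0.5 * scale * (theta - target)**2 using the frozen schedule."""
    theta = theta0
    for _ in range(steps):
        loss = 0.5 * scale * (theta - target) ** 2
        grad = scale * (theta - target)
        theta -= frozen_lr(loss) * grad    # same schedule, no re-tuning per task
    return 0.5 * scale * (theta - target) ** 2


# (theta0, target, curvature, number of steps) for three hypothetical query tasks.
TASKS = [(5.0, 1.0, 1.0, 30), (-3.0, 2.0, 1.5, 60), (10.0, 0.0, 0.5, 100)]
for task in TASKS:
    print(task, "final loss:", run_query_task(*task))
```

Because the predicted LR stays in (0, 1) and every curvature here is below 2, each update contracts the error on every toy task. Real networks offer no such guarantee, which is why the transfer results above are established empirically.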
We also transferred the learned MLR-SNet to train ResNet-50 on the ImageNet dataset, achieving performance similar to that of the state-of-the-art hand-designed method.
This project is licensed under the terms of the MIT license.