This repository contains the official implementation for the paper: Distillation Enhanced Time Series Forecasting Network with Momentum Contrastive Learning.
The recommended requirements for DETSMCL are specified as follows:
- Python 3.6/3.8
- torch==1.10.3
- torchvision==0.11.2
- scikit_learn==0.24.2
- scipy==1.6.1
- numpy==1.21.5
- numpy-base==1.23.5
- pandas==1.0.1
- Bottleneck==1.3.1
The dependencies can be installed by:
pip install -r requirements.txt
The datasets can be obtained and put into the datasets/ folder in the following way:
- The 3 ETT datasets should be placed at datasets/ETTh1.csv, datasets/ETTh2.csv and datasets/ETTm1.csv.
- The Electricity dataset should be preprocessed using datasets/preprocess_electricity.py and placed at datasets/electricity.csv.
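Before training, it can help to confirm that the files are where the steps above expect them. The following is a minimal sketch and not part of the repository: the helper name `missing_datasets` is hypothetical, and the only paths it assumes are the four listed above, resolved relative to the repository root.

```python
from pathlib import Path

# Paths taken from the dataset list above, relative to the repository root.
EXPECTED_DATASETS = [
    "datasets/ETTh1.csv",
    "datasets/ETTh2.csv",
    "datasets/ETTm1.csv",
    "datasets/electricity.csv",
]


def missing_datasets(root="."):
    """Return the expected dataset paths that do not exist under root.

    Hypothetical helper for illustration; it only checks that each file
    exists and does not validate its contents.
    """
    root = Path(root)
    return [p for p in EXPECTED_DATASETS if not (root / p).is_file()]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset files:")
        for path in missing:
            print("  " + path)
    else:
        print("All expected dataset files are present.")
```

An empty result means all four files were found. A missing datasets/electricity.csv usually means the preprocessing script has not been run yet.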
To train and evaluate DETSMCL on a dataset, run the following command:
python train.py <dataset_name> <run_name> --loader <loader> --batch-size <batch_size> --repr-dims <repr_dims> --gpu <gpu> --eval
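The placeholders in the command above can also be filled in from a script, for example when launching several runs in a row. The following is a minimal sketch that assumes it is run from the repository root. The helpers `build_train_command` and `run_training` are hypothetical and not part of the repository. They use only the flags shown in the command above, and every argument value must be supplied by the caller.

```python
import subprocess
import sys


def build_train_command(dataset_name, run_name, loader, batch_size, repr_dims, gpu):
    """Assemble the argument list for the train.py invocation shown above.

    Hypothetical helper: it only mirrors the documented flags and does not
    check that the values are accepted by train.py.
    """
    return [
        sys.executable, "train.py",
        dataset_name, run_name,
        "--loader", str(loader),
        "--batch-size", str(batch_size),
        "--repr-dims", str(repr_dims),
        "--gpu", str(gpu),
        "--eval",
    ]


def run_training(**kwargs):
    """Run train.py; raises CalledProcessError if it exits with an error."""
    subprocess.run(build_train_command(**kwargs), check=True)
```

Passing the arguments as a list rather than a single shell string avoids quoting problems when a run name contains spaces, and `sys.executable` keeps the run in the same Python environment where the requirements were installed.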