forked from wqi/WIMP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
118 lines (89 loc) Β· 5.26 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
import torch
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from src.data.argoverse_datamodule import ArgoverseDataModule
from src.data.dummy_datamodule import DummyDataModule
from src.models.WIMP import WIMP
def parse_arguments():
    """Build the complete command-line parser and parse ``sys.argv``.

    Parsing runs in two passes. A first ``parse_known_args`` pass reads only
    ``--dataset`` and ``--model-name``. Those values select which dataset- and
    model-specific arguments get registered before the final, strict
    ``parse_args`` call.

    Returns:
        argparse.Namespace: the fully parsed arguments.

    Raises:
        NotImplementedError: if ``--dataset`` or ``--model-name`` names a
            dataset/model this script does not support.
    """
    parser = ArgumentParser()
    # Experiment- and trainer-specific arguments.
    parser = add_experimental_args(parser)

    # Select which dataset and model the remaining arguments belong to.
    parser.add_argument('--dataset', type=str, default='argoverse', help='Name of dataset to use')
    parser.add_argument('--model-name', type=str, default='WIMP', help='Name of model to load')
    temp_args, _ = parser.parse_known_args()

    # Dataset-specific arguments.
    if temp_args.dataset == 'argoverse':
        parser = ArgoverseDataModule.add_data_specific_args(parser)
    else:
        raise NotImplementedError(f"Unsupported dataset: {temp_args.dataset!r}")

    # Model-specific arguments.
    if temp_args.model_name == 'WIMP':
        parser = WIMP.add_model_specific_args(parser)
    else:
        raise NotImplementedError(f"Unsupported model: {temp_args.model_name!r}")

    args = parser.parse_args()
    # NOTE: a previously saved configuration can be restored by overwriting
    # the namespace from JSON, e.g. ``args.__dict__ = json.load(f)``.
    return args
def add_experimental_args(parent_parser):
    """Return a child parser extending *parent_parser* with experiment options.

    The child is built with ``parents=[parent_parser]``, so it inherits every
    argument already registered on the parent. It uses ``add_help=False`` so
    the parent's help flag is not registered a second time. On top of that it
    registers general, feature-space, datamodule, trainer, logging and test
    options, in that order.

    Args:
        parent_parser: the ``ArgumentParser`` to extend.

    Returns:
        A new ``ArgumentParser`` containing the parent's and the added options.
    """
    child = ArgumentParser(parents=[parent_parser], add_help=False)

    # Each entry is (option strings, keyword arguments for add_argument).
    # Registration order matters for --help output, so keep the table ordered.
    option_specs = (
        # General params
        (('--mode',), dict(required=True, type=str,
                           choices=['train', 'val', 'trainval', 'oracle-val'],
                           help='Mode to run forecasting model in')),
        (('--seed',), dict(type=int, help="Seed the random parameter generation")),
        # I/O and feature-space params (boolean switches)
        (('--predict-delta',), dict(action='store_true',
                                    help="Predict delta-xy coordinates instead of absolute")),
        (('--IFC',), dict(action='store_true',
                          help="Compute centerline features when predicting trajectory")),
        (('--map-features',), dict(action='store_true', help="Compute map features")),
        (('--no-heuristic',), dict(action='store_true',
                                   help="Don't use heuristic centerline features")),
        (('--use-oracle',), dict(action='store_true',
                                 help='Whether to use features obtained from oracle')),
        # Datamodule params
        (('--dataroot',), dict(required=True, help="Path to the processed dataset folder")),
        (('--batch-size',), dict(type=int, default=25, help="Training batch size")),
        (('--workers',), dict(type=int, default=8, help="Number of dataloader workers")),
        # Trainer params
        (('--gpus',), dict(type=int, default=1, help='# of GPUs to use for training')),
        (('--check-val-every-n-epoch',), dict(type=int, default=3,
                                              help="# of training epochs between val")),
        (('--max-epochs',), dict(type=int, default=120, help="Max # of training epochs")),
        (('--early-stop-threshold',), dict(
            type=int, default=5,
            help="Number of consecutive val epochs without improvement before termination")),
        (('--distributed-backend',), dict(default=None, help='Trainer backend')),
        (('--num-nodes',), dict(default=1, type=int, help='Number of nodes used')),
        (('--precision',), dict(default=32, type=int, help='Precision employed in weights')),
        (('--resume-from-checkpoint',), dict(help='Path to checkpoint to resume training from')),
        (('--ckpt_path',), dict(type=str, required=True)),
        # Logging params
        (('--experiment-name',), dict(type=str, help='Save file prefix')),
        # Test params
        (('--save_dir',), dict(type=str, required=True)),
        (('--save_json',), dict(action='store_true')),
    )

    for flags, options in option_specs:
        child.add_argument(*flags, **options)
    return child
def cli_main(args):
    """Evaluate the model checkpoint at ``args.ckpt_path`` on the test split.

    Builds the Argoverse datamodule, the selected model, a TensorBoard logger
    and a Lightning ``Trainer``. It then runs ``trainer.fit`` with
    ``max_epochs=0`` followed by ``trainer.test`` on the datamodule's test
    dataloader.

    Args:
        args: namespace produced by ``parse_arguments``.

    Raises:
        NotImplementedError: if ``args.model_name`` is not a supported model.
    """
    # Global random seed; ``args.seed`` is None when --seed was not given.
    pl.seed_everything(args.seed)

    # Data module
    dm = ArgoverseDataModule(args)

    # Model selection. Without the explicit else-branch, ``model`` would be
    # unbound for any other name and ``trainer.fit`` would fail with an
    # opaque UnboundLocalError.
    if args.model_name == 'WIMP':
        model = WIMP(args)
    else:
        raise NotImplementedError(f"Unsupported model: {args.model_name!r}")

    # Trainer
    logger = TensorBoardLogger(os.getcwd(), name='experiments', version=args.experiment_name)
    early_stop_cb = EarlyStopping(patience=args.early_stop_threshold, verbose=True)
    # max_epochs=0: no training epochs are run here. NOTE(review): ``fit`` is
    # presumably called only to attach the model/datamodule to the trainer
    # before ``test``; confirm against the Lightning version in use.
    trainer = pl.Trainer(gpus=args.gpus, check_val_every_n_epoch=args.check_val_every_n_epoch,
                         max_epochs=0, default_root_dir=os.getcwd(),
                         distributed_backend=args.distributed_backend, num_nodes=args.num_nodes,
                         precision=args.precision, resume_from_checkpoint=args.resume_from_checkpoint,
                         logger=logger, callbacks=[early_stop_cb])

    # Output directory. NOTE(review): presumably consumed by the model's test
    # hooks (save_dir / save_json); verify in the model code.
    os.makedirs(args.save_dir, exist_ok=True)

    print("============================================================================================START============================================================================================")
    trainer.fit(model, dm)
    trainer.test(ckpt_path=args.ckpt_path, test_dataloaders=dm.test_dataloader())
    print("============================================================================================END============================================================================================")
if __name__ == '__main__':
    # Script entry point: parse the command line once and hand the
    # resulting namespace straight to the evaluation routine.
    cli_main(parse_arguments())