# test_classifier.py
# Evaluate a trained classifier on the CIFAR-10 test split.
import os
import argparse
import torch
import yaml
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import CIFAR10
from classifier.tester import Tester
def argparser(argv=None):
    """
    Parse the command-line arguments of the testing script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default) argparse reads
        ``sys.argv[1:]``, which is the original behaviour.

    Returns
    -------
    argparse.Namespace with the attributes ``model_name``, ``device`` and ``config``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True,
                        help='File name of the model to be used during testing')
    # main() still falls back to CPU when CUDA is not available, whatever is chosen here.
    parser.add_argument('--device', type=str, default='cuda:0', choices=['cuda:0', 'cpu'], required=False,
                        help='Specify the device on which to run the testing.')
    parser.add_argument('--config', type=str, default='./config/classifier/alexnet_cGAN_epoch100.yaml',
                        required=False, help='Path to the configuration file.')
    return parser.parse_args(argv)
def get_config(config: str):
    """
    Load the configuration file.

    Parameters
    ----------
    config : str
        Path of the YAML configuration file.

    Returns
    -------
    The parsed YAML document. main() indexes it as a mapping
    (e.g. 'dataset_path', 'batch_size', 'num_workers').
    """
    # Explicit encoding: without it open() uses the locale's preferred
    # encoding, so a UTF-8 config containing non-ASCII text can fail to
    # decode on platforms whose default encoding is not UTF-8.
    with open(config, 'r', encoding='utf-8') as f:
        # NOTE(review): FullLoader accepts more YAML tags than SafeLoader.
        # If the config path can ever come from an untrusted source,
        # prefer yaml.safe_load.
        return yaml.load(f, Loader=yaml.FullLoader)
def main(args):
    """
    Run the Tester on the CIFAR-10 test split.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. ``config`` and ``device`` are read here,
        and the whole namespace is forwarded to the Tester.
    """
    config = get_config(args.config)

    # Resizing to 224 works for both AlexNet and ResNet.
    preprocessing = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    test_set = CIFAR10(root=config['dataset_path'], train=False,
                       transform=preprocessing, download=True)
    test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
                             shuffle=False, num_workers=config['num_workers'])

    # Use the requested device only when CUDA is available; otherwise run on CPU.
    if torch.cuda.is_available():
        chosen_device = args.device
    else:
        chosen_device = "cpu"
    device = torch.device(chosen_device)
    print(f"Code will be executed on {device}")

    Tester(test_loader=test_loader, device=device, args=args, config=config).test()
if __name__ == "__main__":
    # Suppress TensorFlow warnings.
    # NOTE(review): this is set after the module-level imports have already
    # run. If any of them loads TensorFlow at import time, the setting may
    # come too late — confirm.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    main(argparser())