threshold_relu_example.py
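"""Thresholded ReLU sweep example.

Quantizes a model, then repeatedly (``--runs`` times):
  * evaluates test accuracy and activation sparsity at the current ReLU thresholds,
  * exports a sparse ONNX model and runs (or reuses) a hardware optimiser checkpoint,
  * logs accuracy, sparsity, throughput, latency and resource usage,
  * raises the thresholds, either uniformly or only on the ThresholdedReLU layers
    feeding the slowest convolution layers (``--relu-policy``).
"""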
import argparse
import copy
import os
import pathlib
import random

import numpy as np
import torch

from models import initialize_wrapper
from optimiser_interface.utils import opt_cli_launcher, load_hardware_checkpoint
from quantization.utils import QuantMode, quantize_model
from sparsity.relu_utils import apply_threshold_relu
from sparsity.utils import measure_model_sparsity

# slowest conv layer with tunable input threshold (ThresholdedReLU before Conv)
def get_slowest_threshold_relu_conv(net):
    import fpgaconvnet.tools.graphs as graphs
    from fpgaconvnet.tools.layer_enum import LAYER_TYPE

    slowest_layers = []  # list, as there are multiple partitions
    for partition in net.partitions:
        partition.remove_squeeze()
        layers = []
        for layer in graphs.ordered_node_list(partition.graph):
            if partition.graph.nodes[layer]['type'] == LAYER_TYPE.Convolution:
                layers.append(layer)
        node_latencies = np.array([partition.graph.nodes[layer]['hw'].latency() for layer in layers])
        index = list(reversed(np.argsort(node_latencies, kind='mergesort')))[0]
        conv_layer = layers[index]
        for prev_layer in graphs.get_prev_nodes(partition.graph, conv_layer):
            if partition.graph.nodes[prev_layer]['type'] == LAYER_TYPE.ThresholdedReLU:
                slowest_layers.append(prev_layer)
            elif partition.graph.nodes[prev_layer]['type'] in [LAYER_TYPE.Split, LAYER_TYPE.Concat, LAYER_TYPE.EltWise]:
                for prev_prev_layer in graphs.get_prev_nodes(partition.graph, prev_layer):
                    if partition.graph.nodes[prev_prev_layer]['type'] == LAYER_TYPE.ThresholdedReLU:
                        slowest_layers.append(prev_prev_layer)

    relu_index_dict = {}
    for n in graphs.ordered_node_list(net.graph):
        if net.graph.nodes[n]['type'] == LAYER_TYPE.ThresholdedReLU:
            relu_index_dict[n] = len(relu_index_dict)
    slowest_layers_indices = [relu_index_dict[layer] for layer in slowest_layers]
    return slowest_layers_indices

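# Illustrative only: the per-partition selection above reduces to "index of the largest
# latency". A minimal, hypothetical sketch of that step on plain numbers (not used by
# the script):
#
#     latencies = np.array([3.0e3, 7.5e3, 1.2e3])
#     slowest = list(reversed(np.argsort(latencies, kind='mergesort')))[0]  # -> 1
#
# i.e. the conv layer with the highest modelled latency is picked, and the
# ThresholdedReLU feeding it (possibly through a Split/Concat/EltWise node) is returned
# as an index into the network-wide ordered list of ThresholdedReLU layers.
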
def main():
    parser = argparse.ArgumentParser(description='Thresholded ReLU Example')
    parser.add_argument('--dataset_name', default="imagenet", type=str,
                        help='dataset name')
    parser.add_argument('--dataset_path', metavar='DIR', default="~/dataset/ILSVRC2012_img",
                        help='path to dataset')
    parser.add_argument('--model_name', metavar='ARCH', default='resnet18',
                        help='model architecture')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')
    parser.add_argument('--relu-policy', choices=['slowest_node', 'uniform'], default="slowest_node", type=str)
    parser.add_argument('--fixed-hardware-checkpoint', default=None, type=str,
                        help='path of config.json file generated by optimiser')
    parser.add_argument('--runs', default=100, type=int,
                        help='how many runs')
    parser.add_argument('--threshold_inc', default=0.5, type=float,
                        help='threshold increment')
    parser.add_argument("--platform", default="u250", type=str)
    parser.add_argument("--optimiser_config", default="single_partition_throughput", type=str)
    args = parser.parse_args()

    if args.output_path is None:
        args.output_path = os.path.join(os.getcwd(),
                                        f"output/{args.dataset_name}/{args.model_name}")
    pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
    print(args)
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
    random.seed(0)
    torch.manual_seed(0)

    model_wrapper = initialize_wrapper(args.dataset_name, args.model_name,
                                       os.path.expanduser(args.dataset_path), args.batch_size, args.workers)

    # quantize the whole network to 8-bit weights and activations (QuantMode.NETWORK_FP)
    print("NETWORK_FP 8-bit Inference")
    quantize_model(model_wrapper, {'weight_width': 8, 'data_width': 8, 'mode': QuantMode.NETWORK_FP})
    model_copy = copy.deepcopy(model_wrapper.model)

    # initialise relu thresholds
    threshold = 0.0
    apply_threshold_relu(model_wrapper, threshold)
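
    # Sweep loop: each run evaluates the model at the current thresholds, rebuilds the
    # hardware estimate, logs the accuracy/sparsity/performance trade-off, and then
    # raises the thresholds according to the chosen --relu-policy.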
    for run in range(args.runs):
        # create log directory for this run
        log_dir = os.path.join(args.output_path, f"run_{run}")
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        # evaluate accuracy and activation sparsity at the current thresholds
        top1, top5 = model_wrapper.inference("test")
        avg_sparsity = measure_model_sparsity(model_wrapper)
        sparse_onnx_path = model_wrapper.generate_onnx_files(log_dir)

        # optimise
        print("Optimising...")
        opt_dir = os.path.join(log_dir, "optimiser")
        pathlib.Path(opt_dir).mkdir(parents=True, exist_ok=True)
        if not args.fixed_hardware_checkpoint:
            opt_cli_launcher(args.model_name, sparse_onnx_path, opt_dir, device=args.platform, opt_cfg=args.optimiser_config)
            checkpoint_path = os.path.join(opt_dir, "config.json")
        else:
            checkpoint_path = args.fixed_hardware_checkpoint
        fpgaconvnet_net, (thr, lat, rsc) = load_hardware_checkpoint(sparse_onnx_path, opt_dir, args.platform, checkpoint_path)

        # logging
        threshold_relu = model_wrapper.sideband_info['threshold_relu']
        log_info = threshold_relu | rsc | {"top1_accuracy": top1, "top5_accuracy": top5, "throughput": thr, "latency": lat, "sparsity": avg_sparsity}
        print("Logging:", log_info)
        # update threshold: reset the model before applying the new thresholds
        model_wrapper.model = copy.deepcopy(model_copy)
        if args.relu_policy == "uniform":
            # same threshold applied to every ThresholdedReLU layer
            threshold = round(threshold + args.threshold_inc, 4)
            apply_threshold_relu(model_wrapper, threshold)
        elif args.relu_policy == "slowest_node":
            # only raise the thresholds of the ThresholdedReLU layers feeding the slowest convs;
            # the per-layer values are written into sideband_info before re-applying
            slowest_layers_indices = get_slowest_threshold_relu_conv(fpgaconvnet_net)
            info = model_wrapper.sideband_info['threshold_relu']
            for i, (k, v) in enumerate(info.items()):
                if i in slowest_layers_indices:
                    info[k] = round(threshold + args.threshold_inc, 4)
            model_wrapper.sideband_info['threshold_relu'] = info
            apply_threshold_relu(model_wrapper)


if __name__ == '__main__':
    main()
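
# Example invocation (illustrative; paths and flag values are placeholders to adapt):
#   python threshold_relu_example.py \
#       --dataset_name imagenet --dataset_path ~/dataset/ILSVRC2012_img \
#       --model_name resnet18 --relu-policy slowest_node \
#       --runs 100 --threshold_inc 0.5 --platform u250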