Fixed histogram statistics for grouped convolutions
Krish Agrawal committed Aug 9, 2023
1 parent b4174c7 commit 0ffd22b
Showing 3 changed files with 15 additions and 17 deletions.
imagenet_main.py (1 addition, 1 deletion)

@@ -39,7 +39,7 @@

 parser.add_argument('--ma_window_size', default=None, type=int,
                     help='')
-parser.add_argument('--calibration-size', default=1000, type=int,
+parser.add_argument('--calibration-size', default=2500, type=int,
                     help='')
 parser.add_argument('--relu_threshold', default=0, type=str,
                     help='')
relu_main.py (10 additions, 13 deletions)

@@ -138,15 +138,13 @@ def layer_name_translation(model_name, onnx_name):
     # Note accuracy
     with open(args.accuracy_path, 'r') as f:
         lines = f.read().splitlines()
-        for line in lines[1:]:
-            line_vals = line.split(",")
-            if float(line_vals[1]) == threshold:
-                top1 = float(line_vals[-3])
-                top5 = float(line_vals[-2])
-                sparsity = float(line_vals[-1])
-                break
+        line = lines[run + 1]
+        line_vals = line.split(",")
+        top1 = float(line_vals[-3])
+        top5 = float(line_vals[-2])
+        sparsity = float(line_vals[-1])

-    sparsity_dir = args.sparsity_path
+    sparsity_dir = args.sparsity_path + "/uniform_relu_" + str(threshold)

 #Else collect sparsity
 else:
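The hunk above is safer, not just shorter: the removed loop located a row by exact float equality (float(line_vals[1]) == threshold), which finds nothing once the accumulated threshold drifts off the value written to disk, while the new code indexes the row by run number. A minimal sketch of the failure mode, using a made-up CSV layout (run, thr, top1, top5, sparsity) purely for illustration:

# Exact float equality on an accumulated threshold is fragile.
threshold = 0.0
for _ in range(3):
    threshold += 0.1
print(threshold == 0.3)   # False: threshold is 0.30000000000000004

# Indexing by run number, as the new code does, needs no float comparison;
# 'run + 1' skips the header row.
lines = ["run,thr,top1,top5,sparsity",
         "0,0.1,70.1,89.5,0.31",
         "1,0.2,69.8,89.2,0.38"]
run = 1
line_vals = lines[run + 1].split(",")
top1, top5, sparsity = (float(v) for v in line_vals[-3:])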
@@ -177,7 +175,6 @@ def layer_name_translation(model_name, onnx_name):
     top1 = float(last_line.split(",")[-3])
     top5 = float(last_line.split(",")[-2])
     sparsity = float(last_line.split(",")[-1])
-    sparsity_dir = args.sparsity_path



@@ -193,7 +190,7 @@ def layer_name_translation(model_name, onnx_name):
     throughput, latency = get_new_throughput(args.arch, net, sparsity_dir)

     log_info = relu_thresholds | {"top1_accuracy": top1, "top5_accuracy": top5, "throughput": throughput, "latency": latency, "network_sparsity": sparsity}
-
+    print("Logging:", log_info)

 #Else annotate sparsity, run optimiser, note resources, throughput, and latency
 else:
@@ -256,12 +253,12 @@ def layer_name_translation(model_name, onnx_name):


 #Update based on relu-policy
-threshold += THRESHOLD_INC
+threshold = round(threshold + THRESHOLD_INC, 4)

 if args.relu_policy == "uniform":
     for name, module in model.named_modules():
         if isinstance(module, nn.ReLU):
-            relu_thresholds[name + ".1"] = 0.0
+            relu_thresholds[name + ".1"] = round(threshold, 4)
 elif args.relu_policy == "slowest_node":
     assert args.fixed_hardware

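Both replacements in this hunk address the same floating-point drift: repeatedly adding THRESHOLD_INC to a binary float accumulates representation error, so thresholds stop matching the values used in file names and log rows. Rounding to four decimals keeps every step on the intended grid. A short illustration, assuming THRESHOLD_INC = 0.1 only for demonstration:

THRESHOLD_INC = 0.1          # illustrative value, not the project's constant
threshold = 0.0
for _ in range(3):
    threshold += THRESHOLD_INC
print(threshold)             # 0.30000000000000004: the drift the old code hit

threshold = 0.0
for _ in range(3):
    threshold = round(threshold + THRESHOLD_INC, 4)
print(threshold)             # 0.3: rounding keeps thresholds on a 4-decimal grid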
@@ -285,7 +282,7 @@ def layer_name_translation(model_name, onnx_name):
     if (partition.graph.nodes[layer]['type'] == LAYER_TYPE.Convolution):
         layer_latency = partition.graph.nodes[layer]['hw'].latency()
         if previous_relu != None:
-            previous_layer = layer_name_translation(previous_relu)
+            previous_layer = layer_name_translation(args.arch, previous_relu)
         if layer_latency > max_latency and len(partition.graph.nodes[layer]['hw'].sparsity):
             max_latency = layer_latency
             replace_layer = previous_layer
sparsity_utils.py (4 additions, 3 deletions)

@@ -189,11 +189,12 @@ def forward(self, x):

 zeros_hists = F.one_hot(num_of_zeros, num_classes = self.kk + 1) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, groups, out_channels//groups, in_channels//groups, bins)

-#All groups and out_channels have the input feature map and therefore same sparsity, can squeeze those dimensions
-zeros_hists = zeros_hists[:, :, :, 0, 0].squeeze(4).squeeze(3) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels//groups, bins)
+#All out_channels within a group see the same input feature map and therefore the same sparsity, so that dimension can be squeezed
+zeros_hists = zeros_hists[:, :, :, :, 0].squeeze(4) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, groups, in_channels//groups, bins)

 #NOTE: Toggle the commenting for the following 5 lines for per window
-zeros_hists = zeros_hists.sum(dim = (0, 1, 2)) # (in_channels//groups, bins)
+zeros_hists = zeros_hists.reshape(batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels, self.kk + 1)
+zeros_hists = zeros_hists.sum(dim = (0, 1, 2)) # (in_channels, bins)
 self.statistics.histograms += zeros_hists
 # zeros_hists = zeros_hists.sum(dim = 0) # (h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels//groups, bins)
 # zeros_hists = zeros_hists.permute(2, 0, 1, 3) # (in_channels//groups, h_windows//self.roll_factor, w_windows//self.roll_factor, bins)
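The change above is the fix named in the commit title. For a grouped convolution, every out_channel within a group convolves the same input windows, so the out_channel axis of the one-hot histogram is redundant; but each group reads a different slice of the input channels, so the groups axis must survive, and the old [:, :, :, 0, 0] indexing threw it away. A self-contained sketch of the corrected accumulation, with made-up shapes and the module's roll_factor and statistics bookkeeping omitted:

import torch
import torch.nn.functional as F

batch_size, h_windows, w_windows = 2, 4, 4
groups, out_per_group, in_per_group = 2, 8, 3
kk = 9  # kernel height * kernel width

# Stand-in for the unfolded input windows of a grouped convolution, with
# roughly half the activations zeroed so the histogram has something to count.
windows = torch.randn(batch_size, h_windows, w_windows,
                      groups, out_per_group, in_per_group, kk)
windows = windows * (torch.rand_like(windows) > 0.5).float()

num_of_zeros = (windows == 0).sum(dim=-1)            # zeros per kernel window
zeros_hists = F.one_hot(num_of_zeros, num_classes=kk + 1)
# (batch, h_windows, w_windows, groups, out_per_group, in_per_group, bins)

# Drop the redundant out_channel axis but keep the groups axis: each group
# sees different input channels, so their sparsity statistics differ.
zeros_hists = zeros_hists[:, :, :, :, 0]
# (batch, h_windows, w_windows, groups, in_per_group, bins)

# Fold (groups, in_channels//groups) back into in_channels, then accumulate.
in_channels = groups * in_per_group
zeros_hists = zeros_hists.reshape(batch_size, h_windows, w_windows,
                                  in_channels, kk + 1)
histograms = zeros_hists.sum(dim=(0, 1, 2))          # (in_channels, bins)
print(histograms.shape)                              # torch.Size([6, 10])

Each row of histograms counts, for one input channel, how many kernel-sized windows contained exactly b zero activations.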
