From 0ffd22b2ff972d51409f12f29a0b8337c64899e7 Mon Sep 17 00:00:00 2001 From: Krish Agrawal Date: Wed, 9 Aug 2023 14:54:54 +0100 Subject: [PATCH] Fixed histogram statistics for grouped convolutions --- imagenet_main.py | 2 +- relu_main.py | 23 ++++++++++------------- sparsity_utils.py | 7 ++++--- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/imagenet_main.py b/imagenet_main.py index ac9f7ef..d5cd81b 100644 --- a/imagenet_main.py +++ b/imagenet_main.py @@ -39,7 +39,7 @@ parser.add_argument('--ma_window_size', default=None, type=int, help='') -parser.add_argument('--calibration-size', default=1000, type=int, +parser.add_argument('--calibration-size', default=2500, type=int, help='') parser.add_argument('--relu_threshold', default=0, type=str, help='') diff --git a/relu_main.py b/relu_main.py index 3484e80..fe36bbb 100644 --- a/relu_main.py +++ b/relu_main.py @@ -138,15 +138,13 @@ def layer_name_translation(model_name, onnx_name): # Note accuracy with open(args.accuracy_path, 'r') as f: lines = f.read().splitlines() - for line in lines[1:]: - line_vals = line.split(",") - if float(line_vals[1]) == threshold: - top1 = float(line_vals[-3]) - top5 = float(line_vals[-2]) - sparsity = float(line_vals[-1]) - break + line = lines[run + 1] + line_vals = line.split(",") + top1 = float(line_vals[-3]) + top5 = float(line_vals[-2]) + sparsity = float(line_vals[-1]) - sparsity_dir = args.sparsity_path + sparsity_dir = args.sparsity_path + "/uniform_relu_" + str(threshold) #Else collect sparsity else: @@ -177,7 +175,6 @@ def layer_name_translation(model_name, onnx_name): top1 = float(last_line.split(",")[-3]) top5 = float(last_line.split(",")[-2]) sparsity = float(last_line.split(",")[-1]) - sparsity_dir = args.sparsity_path @@ -193,7 +190,7 @@ def layer_name_translation(model_name, onnx_name): throughput, latency = get_new_throughput(args.arch, net, sparsity_dir) log_info = relu_thresholds | {"top1_accuracy": top1, "top5_accuracy": top5, "throughput": throughput, "latency": latency, "network_sparsity": sparsity} - + print("Logging:", log_info) #Else annotate sparsity, run optimiser, note resources, throughput, and latency else: @@ -256,12 +253,12 @@ def layer_name_translation(model_name, onnx_name): #Update based on relu-policy - threshold += THRESHOLD_INC + threshold = round(threshold + THRESHOLD_INC, 4) if args.relu_policy == "uniform": for name, module in model.named_modules(): if isinstance(module, nn.ReLU): - relu_thresholds[name + ".1"] = 0.0 + relu_thresholds[name + ".1"] = round(threshold, 4) elif args.relu_policy == "slowest_node": assert args.fixed_hardware @@ -285,7 +282,7 @@ def layer_name_translation(model_name, onnx_name): if (partition.graph.nodes[layer]['type'] == LAYER_TYPE.Convolution): layer_latency = partition.graph.nodes[layer]['hw'].latency() if previous_relu != None: - previous_layer = layer_name_translation(previous_relu) + previous_layer = layer_name_translation(args.arch, previous_relu) if layer_latency > max_latency and len(partition.graph.nodes[layer]['hw'].sparsity): max_latency = layer_latency replace_layer = previous_layer diff --git a/sparsity_utils.py b/sparsity_utils.py index 6fcd616..8fd4d55 100644 --- a/sparsity_utils.py +++ b/sparsity_utils.py @@ -189,11 +189,12 @@ def forward(self, x): zeros_hists = F.one_hot(num_of_zeros, num_classes = self.kk + 1) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, groups, out_channels//groups, in_channels//groups, bins) - #All groups and out_channels have the input feature map and therefore same sparsity, can squeeze those dimensions - zeros_hists = zeros_hists[:, :, :, 0, 0].squeeze(4).squeeze(3) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels//groups, bins) + #All out_channels have the input feature map and therefore same sparsity, can squeeze those dimensions + zeros_hists = zeros_hists[:, :, :, :, 0].squeeze(4) # (batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, groups, in_channels//groups, bins) #NOTE: Toggle the commenting for the following 5 lines for per window - zeros_hists = zeros_hists.sum(dim = (0, 1, 2)) # (in_channels//groups, bins) + zeros_hists = zeros_hists.reshape(batch_size, h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels, self.kk + 1) + zeros_hists = zeros_hists.sum(dim = (0, 1, 2)) # (in_channels, bins) self.statistics.histograms += zeros_hists # zeros_hists = zeros_hists.sum(dim = 0) # (h_windows//self.roll_factor, w_windows//self.roll_factor, in_channels//groups, bins) # zeros_hists = zeros_hists.permute(2, 0, 1, 3) # (in_channels//groups, h_windows//self.roll_factor, w_windows//self.roll_factor, bins)