Skip to content

Commit

Permalink
parse toml
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Aug 15, 2023
1 parent 3cad40d commit b802921
Showing 1 changed file with 61 additions and 39 deletions.
100 changes: 61 additions & 39 deletions onnx_sparsity_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import onnx
import argparse
import csv
import toml

from utils import load_model, model_names, replace_modules

def torch_onnx_exporter(model, model_name, random_input, output_path):
Expand Down Expand Up @@ -33,31 +35,30 @@ def annotate_quantisation(model, weight_width, data_width, acc_width, block_floa
else:
set_nodeattr(node, "data_width", data_width)

def layer_name_translation(model_name, onnx_name):
onnx_name = onnx_name.split("/")
if model_name in ["resnet18", "resnet50"]:
if len(onnx_name) == 3: # first conv
torch_name = onnx_name[1]+ ".1"
else:
assert len(onnx_name) in [5,6]
torch_name = onnx_name[2] + "." +onnx_name[-2]+ ".1"
elif model_name == "mobilenet_v2":
if len(onnx_name) == 5: # first and last conv
def annotate_sparsity_from_numpy(model_name, onnx_model, data_path):
def _layer_name_translation(model_name, onnx_name):
onnx_name = onnx_name.split("/")
if model_name in ["resnet18", "resnet50"]:
if len(onnx_name) == 3: # first conv
torch_name = onnx_name[1]+ ".1"
else:
assert len(onnx_name) in [5,6]
torch_name = onnx_name[2] + "." +onnx_name[-2]+ ".1"
elif model_name == "mobilenet_v2":
if len(onnx_name) == 5: # first and last conv
torch_name = onnx_name[-2] + ".1"
else:
assert len(onnx_name) in [6,7]
torch_name = onnx_name[2] + "." + onnx_name[-2] + ".1"
elif model_name in ["alexnet", "vgg11", "vgg16"]:
torch_name = onnx_name[-2] + ".1"
else:
assert len(onnx_name) in [6,7]
torch_name = onnx_name[2] + "." + onnx_name[-2] + ".1"
elif model_name in ["alexnet", "vgg11", "vgg16"]:
torch_name = onnx_name[-2] + ".1"
elif model_name == "repvgg-a0":
torch_name = ".".join(onnx_name[1:-1]) + ".1"
return torch_name
elif model_name == "repvgg-a0":
torch_name = ".".join(onnx_name[1:-1]) + ".1"
return torch_name

def annotate_sparsity(model_name, onnx_model, data_path):
conv_layer_index = 0
for node in onnx_model.graph.node:
if node.op_type == 'Conv':
layer_name = layer_name_translation(model_name, node.name)
layer_name = _layer_name_translation(model_name, node.name)
np_path = os.path.join(data_path, model_name + "_" + layer_name + "_mean.npy")
num_of_zeros_mean = np.load(np_path)
for attr in node.attribute:
Expand All @@ -67,23 +68,44 @@ def annotate_sparsity(model_name, onnx_model, data_path):
sparsity_data = num_of_zeros_mean / np.prod(kernel_shape)
set_nodeattr(node, "input sparsity", sparsity_data)

parser = argparse.ArgumentParser(description='Export ONNX model with sparsity attribute')
parser.add_argument('-a', '--arch', metavar='ARCH', default='vgg16',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names))
parser.add_argument('--data', metavar='DIR', default="runlog/per_channel/vgg16_sparsity_run_50k_2023_03_13_10_02_17_996357",
help='path to onnx model')
parser.add_argument('--dense_onnx_path', metavar='DIR', default="models/vgg16.onnx",
help='path to onnx model')
parser.add_argument('--sparse_onnx_path', metavar='DIR', default="models/vgg16_sparse.onnx",
help='path to onnx model')
def annotate_sparsity_from_toml(model_name, onnx_model, data_path):
def _layer_name_translation(model_name, onnx_name):
onnx_name = onnx_name.split("/")
if model_name in ["resnet18"]:
if len(onnx_name) == 3: #first_conv
torch_name = onnx_name[1]
else:
torch_name = onnx_name[2] + "." + onnx_name[-2]
return torch_name

with open(data_path) as f:
toml_data = toml.load(f)

for node in onnx_model.graph.node:
if node.op_type == 'Conv':
layer_name = _layer_name_translation(model_name, node.name)
sparsity_data = toml_data[layer_name]["avg"]
set_nodeattr(node, "input sparsity", sparsity_data)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Export ONNX model with sparsity attribute')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',choices=model_names)
parser.add_argument('--state_dict', metavar='DIR', default="/home/zy18/Downloads/Pruning Results-20230815T143526Z-001/Pruning Results/weight_sparse_50/resnet18_classification_imagenet_2023-08-12/software/transform/transformed_ckpt/state_dict.pt")
parser.add_argument('--data_path', metavar='DIR', default="/home/zy18/Downloads/Pruning Results-20230815T143526Z-001/Pruning Results/weight_sparse_50/resnet18_classification_imagenet_2023-08-12/software/transform/prune/activation_report.toml")
parser.add_argument('--export_path', metavar='DIR', default="models")

args = parser.parse_args()
args = parser.parse_args()

torch_model = load_model(args.arch)
torch_onnx_exporter(torch_model, args.arch, torch.randn(1, 3, 224, 224), args.dense_onnx_path)
onnx_model = onnx.load(args.dense_onnx_path)
annotate_quantisation(onnx_model, 16, 16, 32, False)
annotate_sparsity(args.arch, onnx_model, args.data)
onnx.save(onnx_model, args.sparse_onnx_path)
torch_model = load_model(args.arch)
if args.state_dict is not None:
torch_model.load_state_dict(torch.load(args.state_dict, map_location="cpu"))
dense_onnx_path = os.path.join(args.export_path, args.arch + ".onnx")
sparse_onnx_path = os.path.join(args.export_path, args.arch + "_sparse.onnx")
torch_onnx_exporter(torch_model, args.arch, torch.randn(1, 3, 224, 224), dense_onnx_path)
onnx_model = onnx.load(dense_onnx_path)
annotate_quantisation(onnx_model, 16, 16, 32, False)
if args.data_path.endswith(".toml"):
annotate_sparsity_from_toml(args.arch, onnx_model, args.data_path)
else:
annotate_sparsity_from_numpy(args.arch, onnx_model, args.data_path)
onnx.save(onnx_model, sparse_onnx_path)

0 comments on commit b802921

Please sign in to comment.