From 57ca4875e3b194f8e31dd063591dd973fa75c461 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Mar 2024 10:49:03 +0000 Subject: [PATCH 1/3] Fix: changed bnn-pynq models to have a 24bit bias. Updated average pool to do rounding. --- src/brevitas_examples/bnn_pynq/models/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/bnn_pynq/models/resnet.py b/src/brevitas_examples/bnn_pynq/models/resnet.py index 14efdf498..366cc4de9 100644 --- a/src/brevitas_examples/bnn_pynq/models/resnet.py +++ b/src/brevitas_examples/bnn_pynq/models/resnet.py @@ -9,7 +9,7 @@ import brevitas.nn as qnn from brevitas.quant import Int8WeightPerChannelFloat from brevitas.quant import Int8WeightPerTensorFloat -from brevitas.quant import Int32Bias +from brevitas.quant import Int24Bias from brevitas.quant import TruncTo8bit from brevitas.quant_tensor import QuantTensor @@ -120,8 +120,8 @@ def __init__( num_classes=10, act_bit_width=8, weight_bit_width=8, - round_average_pool=False, - last_layer_bias_quant=Int32Bias, + round_average_pool=True, + last_layer_bias_quant=Int24Bias, weight_quant=Int8WeightPerChannelFloat, first_layer_weight_quant=Int8WeightPerChannelFloat, last_layer_weight_quant=Int8WeightPerTensorFloat): From d880e67d0b324a84d8ce3246c5cffe4d3997936a Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Mar 2024 10:50:52 +0000 Subject: [PATCH 2/3] feat: Added ONNX export for BNN-PYNQ examples --- .../bnn_pynq/bnn_pynq_train.py | 4 +++- src/brevitas_examples/bnn_pynq/trainer.py | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py index 5aa316aea..1b7c26b9f 100644 --- a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py +++ b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py @@ -73,6 +73,8 @@ def parse_args(args): "--state_dict_to_pth", action='store_true', help="Saves a model state_dict into a pth and then exits") + parser.add_argument("--export_qonnx", action='store_true', help="Export QONNX Model") + parser.add_argument("--export_qcdq_onnx", action='store_true', help="Export QCDQ ONNX Model") return parser.parse_args(args) @@ -110,7 +112,7 @@ def launch(cmd_args): # Avoid creating new folders etc. if args.evaluate: - args.dry_run = True + args.dry_run = True # Comment out to export ONNX models from pre-trained # Init trainer trainer = Trainer(args) diff --git a/src/brevitas_examples/bnn_pynq/trainer.py b/src/brevitas_examples/bnn_pynq/trainer.py index 78c1db97e..729668fd3 100644 --- a/src/brevitas_examples/bnn_pynq/trainer.py +++ b/src/brevitas_examples/bnn_pynq/trainer.py @@ -18,6 +18,8 @@ from torchvision.datasets import CIFAR10 from torchvision.datasets import MNIST +from brevitas.export import export_onnx_qcdq, export_qonnx + from .logger import EvalEpochMeters from .logger import Logger from .logger import TrainingEpochMeters @@ -149,6 +151,27 @@ def __init__(self, args): self.logger.info("Saving checkpoint model to {}".format(new_path)) exit(0) + if args.export_qonnx: + name = args.network.lower() + path = os.path.join(self.checkpoints_dir_path, name) + export_qonnx(model, self.train_loader.dataset[0][0].unsqueeze(0), path) + with open(path, "rb") as f: + bytes = f.read() + readable_hash = sha256(bytes).hexdigest()[:8] + new_path = os.path.join(self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash)) + os.rename(path, new_path) + self.logger.info("Exporting QONNX to {}".format(new_path)) + if args.export_qcdq_onnx: + name = args.network.lower() + path = os.path.join(self.checkpoints_dir_path, name) + export_onnx_qcdq(model, self.train_loader.dataset[0][0].unsqueeze(0), path) + with open(path, "rb") as f: + bytes = f.read() + readable_hash = sha256(bytes).hexdigest()[:8] + new_path = os.path.join(self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash)) + os.rename(path, new_path) + self.logger.info("Exporting QCDQ ONNX to {}".format(new_path)) + if args.gpus is not None and len(args.gpus) == 1: model = model.to(device=self.device) if args.gpus is not None and len(args.gpus) > 1: From 9c0569357c230d2de59c2ce20820c99e00705f9c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Mar 2024 14:04:54 +0000 Subject: [PATCH 3/3] Fix: Style to conform to precommit --- src/brevitas_examples/bnn_pynq/bnn_pynq_train.py | 2 +- src/brevitas_examples/bnn_pynq/trainer.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py index 1b7c26b9f..505617745 100644 --- a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py +++ b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py @@ -112,7 +112,7 @@ def launch(cmd_args): # Avoid creating new folders etc. if args.evaluate: - args.dry_run = True # Comment out to export ONNX models from pre-trained + args.dry_run = True # Comment out to export ONNX models from pre-trained # Init trainer trainer = Trainer(args) diff --git a/src/brevitas_examples/bnn_pynq/trainer.py b/src/brevitas_examples/bnn_pynq/trainer.py index 729668fd3..90ff99f5c 100644 --- a/src/brevitas_examples/bnn_pynq/trainer.py +++ b/src/brevitas_examples/bnn_pynq/trainer.py @@ -18,7 +18,8 @@ from torchvision.datasets import CIFAR10 from torchvision.datasets import MNIST -from brevitas.export import export_onnx_qcdq, export_qonnx +from brevitas.export import export_onnx_qcdq +from brevitas.export import export_qonnx from .logger import EvalEpochMeters from .logger import Logger @@ -158,7 +159,8 @@ def __init__(self, args): with open(path, "rb") as f: bytes = f.read() readable_hash = sha256(bytes).hexdigest()[:8] - new_path = os.path.join(self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash)) + new_path = os.path.join( + self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash)) os.rename(path, new_path) self.logger.info("Exporting QONNX to {}".format(new_path)) if args.export_qcdq_onnx: @@ -168,7 +170,8 @@ def __init__(self, args): with open(path, "rb") as f: bytes = f.read() readable_hash = sha256(bytes).hexdigest()[:8] - new_path = os.path.join(self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash)) + new_path = os.path.join( + self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash)) os.rename(path, new_path) self.logger.info("Exporting QCDQ ONNX to {}".format(new_path))