fixed (I think) the mistake for LeNet
joannapng committed Jun 21, 2024
1 parent aeff97d commit 71d8fc1
Showing 25 changed files with 119 additions and 88 deletions.
2 changes: 1 addition & 1 deletion exporter/Exporter.py
@@ -171,7 +171,7 @@ def preprocessing(model: ModelWrapper, cfg: build.DataflowBuildConfig):

model = model.transform(MergeONNXModels(preproc_model))
global_inp_name = model.graph.input[0].name
model.set_tensor_datatype(global_inp_name, DataType["UINT8"])
model.set_tensor_datatype(global_inp_name, DataType["INT8"])
model = tidy_up(model)

return model
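For context, a quick way to confirm the new input annotation on an exported graph; ModelWrapper and DataType are the qonnx classes already used in this file, while the file name below is only a placeholder:

from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper

wrapped = ModelWrapper("model_quant.onnx")          # placeholder path for an exported graph
inp_name = wrapped.graph.input[0].name
assert wrapped.get_tensor_datatype(inp_name) == DataType["INT8"]   # was UINT8 before this commit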
1 change: 0 additions & 1 deletion pretrain/models/LeNet5.py
@@ -35,7 +35,6 @@ def clip_weights(self, min_val = -1, max_val = 1):
mod.weight.data.clamp_(min_val, max_val)

def forward(self, x):
x = 2.0 * x - 1.0
for mod in self.conv_features:
x = mod(x)
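For anyone wondering what the deleted line did: it mapped ToTensor's [0, 1] range to [-1, 1] before the convolutional stack, while the export path in test.py feeds the model the raw image divided by 255.0. The snippet below only illustrates the removed arithmetic, it is not code from the repo:

import torch

x = torch.tensor([0.0, 0.5, 1.0])   # typical ToTensor output range
print(2.0 * x - 1.0)                # tensor([-1., 0., 1.]) -- what the deleted line produced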

Binary file removed pretrain/models/__pycache__/LeNet5.cpython-311.pyc
Binary file removed pretrain/models/__pycache__/__init__.cpython-311.pyc
3 changes: 0 additions & 3 deletions pretrain/trainer/Trainer.py
@@ -83,9 +83,7 @@ def init_dataset(self, args, config):

transformations = transforms.Compose([
transforms.RandomCrop(32, padding = 4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])
elif args.dataset == 'MNIST':
normalize = transforms.Normalize(mean = (0.1307, ), std = (0.3081, ))
@@ -98,7 +96,6 @@ def init_dataset(self, args, config):
transforms.Resize(28),
transforms.CenterCrop(28),
transforms.ToTensor(),
normalize
])

train_set = builder(root=args.datadir,
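If the line dropped from the MNIST branch is the Normalize step, as the parallel edits in train/finetune/Finetuner.py below suggest, the resulting pipeline keeps tensors in ToTensor's raw [0, 1] range. A hedged reconstruction, not a quote of the file:

from torchvision import transforms

# MNIST pipeline as it would read after this commit (assumed):
# no mean/std normalization, inputs stay in [0, 1].
mnist_transformations = transforms.Compose([
    transforms.Resize(28),
    transforms.CenterCrop(28),
    transforms.ToTensor(),
])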
Binary file removed pretrain/trainer/__pycache__/Trainer.cpython-311.pyc
Binary file removed pretrain/trainer/__pycache__/__init__.cpython-311.pyc
60 changes: 36 additions & 24 deletions test.py
@@ -1,12 +1,16 @@
from pkgutil import get_data
import attr
import onnx
import onnx.numpy_helper as nph
import torch
import brevitas

import math
import numpy as np
import argparse
from brevitas.graph.utils import get_module
import importlib_resources as importlib
from urllib3 import disable_warnings

from train.env import ModelEnv
from stable_baselines3.common.monitor import Monitor
@@ -30,6 +34,7 @@

model_names = ['LeNet5', 'resnet18', 'resnet34', 'resnet50', 'resnet100', 'resnet152']


def get_example_input(dataset):
if dataset == "MNIST":
raw_i = get_data("qonnx.data", "onnx/mnist-conv/test_data_set_0/input_0.pb")
@@ -127,42 +132,49 @@ def get_example_input(dataset):
img_shape = center_crop_shape

'''
#input_tensor_torch, _ = next(iter(env.finetuner.export_loader))
input_tensor_torch = next(iter(env.finetuner.export_loader))[0][0]
input_tensor_torch = input_tensor_torch[None, :, :, :]
input_tensor_npy = input_tensor_torch.cpu().numpy().astype(np.float32)
input_tensor_npy = np.transpose(input_tensor_npy, (0, 2, 3, 1)) # N, H, W, C
input_tensor_torch = input_tensor_torch.detach().to(env.finetuner.device) / 255.0
np.save("input.npy", input_tensor_npy)
output_golden = model.forward(input_tensor_torch).detach().cpu().numpy()
output_golden = np.argmax(output_golden, axis = 1)
np.save("expected_output.npy", output_golden)
'''

input_tensor_npy = get_example_input(args.dataset)
input_tensor_torch = torch.from_numpy(input_tensor_npy).float() / 255.0
input_tensor_torch = input_tensor_torch.detach().to(env.finetuner.device)
input_tensor_npy = np.transpose(input_tensor_npy, (0, 2, 3, 1)) # N, H, W, C
np.save("input.npy", input_tensor_npy)
np.save(f'{os.path.join(args.output_dir, "input.npy")}', input_tensor_npy)
'''

output_golden = model.forward(input_tensor_torch).detach().cpu().numpy()
print(output_golden)
output_golden = np.flip(output_golden.flatten().argsort())[:1]
print(output_golden)
np.save("expected_output.npy", output_golden)
input_tensor_torch, _ = next(iter(env.finetuner.export_loader))
input_tensor_numpy = input_tensor_torch.detach().cpu().numpy().astype(np.float32)
input_tensor_numpy = np.transpose(input_tensor_numpy, (0, 2, 3, 1))
np.save(f'{os.path.join(args.output_dir, "input.npy")}', input_tensor_numpy)

output_golden = model.forward(input_tensor_torch.to(env.finetuner.device)).detach().cpu().numpy()
output_golden = np.argmax(output_golden, axis = 1)
np.save(f'{os.path.join(args.output_dir, "expected_output.npy")}', output_golden)
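The two arrays saved here give a self-contained check for the exported graph later on. A hedged sketch of how they could be replayed with qonnx's reference executor; the file names are placeholders and the transpose simply undoes the NHWC layout applied above:

import numpy as np
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.onnx_exec import execute_onnx

qonnx_model = ModelWrapper("model_quant.onnx")          # placeholder for the exported file
inp_name = qonnx_model.graph.input[0].name
x = np.load("input.npy").transpose(0, 3, 1, 2)          # back from NHWC to NCHW
# note: whether x still needs the /255.0 scaling depends on whether the
# preprocessing node from exporter/Exporter.py has been merged into this graph
out_dict = execute_onnx(qonnx_model, {inp_name: x.astype(np.float32)})
pred = np.argmax(list(out_dict.values())[0], axis=1)
print(pred, np.load("expected_output.npy"))             # the two should agree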

# export quant model to qonnx
output = os.path.join(args.output_dir, args.onnx_output)
name = output + '_quant.onnx'
model.cpu()
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
ref_input = torch.randn(1, env.finetuner.in_channels, img_shape, img_shape, device = device, dtype = dtype)

bo.export_qonnx(model, input_t = ref_input, export_path = name, opset_version = 11)
from qonnx.core.modelwrapper import ModelWrapper
model = ModelWrapper(name)

# export original model to onnx
orig_model = env.orig_model
orig_model.eval()
output = os.path.join(args.output_dir, args.onnx_output)
orig_model.cpu()
name = output + '.onnx'
torch.onnx.export(orig_model, ref_input, name, export_params = False, opset_version=11)
torch.onnx.export(orig_model, ref_input, name, export_params = True, opset_version=11)

graph = model.graph
for node in graph.node:
if node.op_type in ["Quant", "BinaryQuant", "Trunc"] and "weight_quant" in node.name:
for attribute in node.attribute:
if attribute.name == "signed":
attribute.i = 1
elif attribute.name == "narrow":
attribute.i = 0

model.save(output + '_quant.onnx')
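The loop above forces every weight quantizer to signed, non-narrow range before the graph is written out. A quick, plain-onnx way to double-check the patched attributes in the saved file (the path stands in for output + '_quant.onnx'):

import onnx

patched = onnx.load("output_quant.onnx")        # placeholder for output + '_quant.onnx'
for node in patched.graph.node:
    if node.op_type in ("Quant", "BinaryQuant", "Trunc") and "weight_quant" in node.name:
        flags = {a.name: a.i for a in node.attribute if a.name in ("signed", "narrow")}
        print(node.name, flags)                  # expect {'signed': 1, 'narrow': 0}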

# export quant model to qonnx
name = output + '_quant.onnx'
bo.export_qonnx(model, ref_input, export_path = name, export_params = True, keep_initializers_as_inputs = False, opset_version=11)

8 changes: 8 additions & 0 deletions train/env/ModelEnv.py
@@ -254,6 +254,12 @@ def reset(self, seed = None, option = None):
self.strategy = []

obs = self.layer_embedding[0].copy()
self.quantizer = Quantizer(
self.model,
self.args.weight_bit_width,
self.args.act_bit_width,
self.args.bias_bit_width
)
return obs, {}

def step(self, action):
@@ -279,6 +285,8 @@ def step(self, action):
self.finetuner.finetune()

# validate model
self.model = deepcopy(self.finetuner.model)
self.finetuner.model = deepcopy(self.model)
acc = self.finetuner.validate()

reward = self.reward(acc)
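Since reset() now rebuilds the Quantizer and step() copies the finetuned model back, a usage sketch of the environment loop may help. Everything below is an assumption about the surrounding gymnasium-style API (the hunk only shows fragments of reset and step, and the args/config objects come from elsewhere in the repo), so treat it as illustrative only:

from stable_baselines3.common.monitor import Monitor
from train.env import ModelEnv

env = Monitor(ModelEnv(args, config))            # args/config assumed to be parsed elsewhere
obs, info = env.reset()                          # reset() -> (obs, {}) as seen in the hunk
done = False
while not done:
    action = env.action_space.sample()           # a trained SB3 agent would supply this instead
    obs, reward, terminated, truncated, info = env.step(action)   # 5-tuple return assumed
    done = terminated or truncated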
Binary file removed train/env/__pycache__/ModelEnv.cpython-311.pyc
Binary file removed train/env/__pycache__/__init__.cpython-311.pyc
Binary file removed train/env/__pycache__/utils.cpython-311.pyc
20 changes: 12 additions & 8 deletions train/finetune/Finetuner.py
@@ -82,6 +82,13 @@ def init_dataset(self, args, config):
transforms.ToTensor(),
normalize
])

export_transformations = transforms.Compose([
transforms.RandomCrop(32, padding = 4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])

elif args.dataset == 'MNIST':
normalize = transforms.Normalize(mean = (0.1307, ), std = (0.3081, ))

@@ -93,7 +100,6 @@ def init_dataset(self, args, config):
transforms.Resize(28),
transforms.CenterCrop(28),
transforms.ToTensor(),
normalize
])

export_transformations = transforms.Compose([
@@ -107,7 +113,6 @@ def init_dataset(self, args, config):
transforms.Resize(config['resize_shape']),
transforms.CenterCrop(config['center_crop_shape']),
transforms.ToTensor(),
normalize
])

export_transformations = transforms.Compose([
@@ -211,11 +216,10 @@ def init_loss(self):
elif self.args.loss == 'SqrHinge':
self.criterion = nn.SqrHingeLoss()

def check_accuracy(self, loader, model, eval = True):
def check_accuracy(self, loader, model):
num_correct = 0
num_samples = 0
if eval:
model.eval() # set model to evaluation mode
model.eval() # set model to evaluation mode

with torch.no_grad():
for x_val, y_val in loader:
@@ -254,11 +258,11 @@ def finetune(self):
#print("Training Complete")
# Testing accuracy in the testing dataset
print('-------- Testing Accuracy -------')
self.test_acc = self.check_accuracy(self.test_loader, self.model, eval = False)
self.test_acc = self.check_accuracy(self.test_loader, self.model)
return 0.0, self.model

def validate(self, eval = True):
return validate(self.model, val_loader=self.test_loader, eval = eval)
def validate(self):
return validate(self.model, val_loader=self.test_loader)

def calibrate(self):
calibrate(self.args, self.model, calib_loader=self.calib_loader)
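With the eval flag gone, check_accuracy always puts the model in evaluation mode. A hedged, standalone sketch of the whole routine for context; only the signature, model.eval(), the counters and the no_grad loop header appear in the hunk, the loop body is an assumption:

import torch

def check_accuracy(loader, model, device="cpu"):
    model.eval()                                  # always evaluation mode now
    num_correct, num_samples = 0, 0
    with torch.no_grad():
        for x_val, y_val in loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            preds = model(x_val).argmax(dim=1)
            num_correct += (preds == y_val).sum().item()
            num_samples += y_val.size(0)
    return 100.0 * num_correct / num_samples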
Binary file removed train/finetune/__pycache__/Finetuner.cpython-311.pyc
Binary file removed train/finetune/__pycache__/__init__.cpython-311.pyc
Binary file removed train/finetune/__pycache__/calibrate.cpython-311.pyc
Binary file removed train/finetune/__pycache__/validate.cpython-311.pyc
13 changes: 7 additions & 6 deletions train/finetune/calibrate.py
@@ -14,9 +14,10 @@ def calibrate(args, model, calib_loader):
images = images.to(dtype)
model(images)

if args.bias_corr:
with bias_correction_mode(model):
for i, (images, target) in enumerate(calib_loader):
images = images.to(device)
images = images.to(dtype)
model(images)
'''
with bias_correction_mode(model):
for i, (images, target) in enumerate(calib_loader):
images = images.to(device)
images = images.to(dtype)
model(images)
'''
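The bias-correction pass is now disabled outright: both the args.bias_corr branch and the unconditional loop sit inside the triple-quoted block. If it ever needs to come back, the disabled code corresponds to roughly the following; the Brevitas import path is an assumption about the pinned version:

import torch
from brevitas.graph.calibrate import bias_correction_mode

def apply_bias_correction(model, calib_loader):
    dtype = next(model.parameters()).dtype
    device = next(model.parameters()).device
    with torch.no_grad(), bias_correction_mode(model):
        for images, _ in calib_loader:
            images = images.to(device).to(dtype)
            model(images)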
5 changes: 2 additions & 3 deletions train/finetune/validate.py
@@ -52,15 +52,14 @@ def accuracy(output, target, topk=(1,), stable=False):
return res


def validate(model, val_loader, eval = True):
def validate(model, val_loader):
top1 = AverageMeter('Acc@1', ':6.2f')

def print_accuracy(top1, prefix=''):
print('{}Avg acc@1 {top1.avg:2.3f}'.format(prefix, top1=top1))

# do not set model to eval mode because it requires that the residual connections are handled
if eval:
model.eval()
model.eval()

dtype = next(model.parameters()).dtype
device = next(model.parameters()).device