Skip to content

Commit

Permalink
fix opset_version
Browse files Browse the repository at this point in the history
  • Loading branch information
KudoKhang committed Oct 11, 2022
1 parent eca2cac commit 2187399
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 39 deletions.
4 changes: 2 additions & 2 deletions inferenceONNX.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from networks import *

# predictor = BSNONNXPredict(pretrained='checkpoints/bisenet_no_opt.onnx')
predictor = BSNONNXPredict()
predictor = BSNONNXPredict(pretrained='checkpoints/bisenet.onnx')

root = 'dataset/Figaro_1k/test/images/'
list_path = [root + name for name in os.listdir(root) if name.endswith(('jpg'))]
Expand All @@ -11,7 +11,7 @@
# cv2.imshow('result', result)
# cv2.waitKey(0)

img = cv2.imread(root + '568.jpg')
img = cv2.imread(root + '959.jpg')
result = predictor.predict(img)
cv2.imshow('result', result)
cv2.waitKey(0)
18 changes: 2 additions & 16 deletions networks/BSNONNXPredict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@ def __init__(self, pretrained='checkpoints/bisenet.onnx', is_draw_bbox=False):
def to_numpy(self, tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def shift(self, image, shift_x=22, shift_y=10):
shift_to_right_or_down = 1
shift_to_left_or_top = -1
for i in range(image.shape[1] - 1, image.shape[1] - shift_x, -1):
image = np.roll(image, shift_to_right_or_down, axis=1)
image[:, -1] = 0

for i in range(image.shape[0] - 1, image.shape[0] - shift_y, -1):
image = np.roll(image, 1, axis=0)
image[-1, :] = 0

return image

def reverse_one_hot(self, image):
# Convert output of model to predicted class
image = image.permute(1, 2, 0)
Expand All @@ -38,7 +25,6 @@ def process_output(self, label, img):
label = np.array(label.cpu(), dtype='uint8')
label = (1 - label) * 255
label = cv2.resize(label, img.shape[:2][::-1])
label = self.shift(label)
return label

def process_input(self, image_path):
Expand Down Expand Up @@ -70,13 +56,13 @@ def visualize(self, img, label, color = (0, 255, 0)):
label[:,:,2][np.where(label[:,:,2]==255)] = color[2]
return cv2.addWeighted(img, 0.6, label, 0.4, 0)

def predict(self, image, is_visualize=True):
def predict(self, image, visualize=True):
image_processed = self.process_input(image)
inputs = {self.model.get_inputs()[0].name: self.to_numpy(image_processed)}
outputs = self.model.run(None, inputs)
mask = self.process_output(outputs, image)

if is_visualize:
if visualize:
mask = self.visualize(image, mask)

return mask
25 changes: 4 additions & 21 deletions torch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
from networks import *
import onnx

Expand All @@ -15,10 +14,10 @@ def Convert_ONNX():
# Export the model
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
"checkpoints/bisenet_no_opt.onnx", # where to save the model
"checkpoints/bisenet.onnx", # where to save the model
verbose=True, # Show progress
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
export_params=False, # store the trained parameter weights inside the model file
opset_version=12, # the ONNX version to export the model to
do_constant_folding=False, # whether to execute constant folding for optimization
input_names = ['modelInput'], # the model's input names
output_names = ['modelOutput'], # the model's output names
Expand All @@ -27,26 +26,10 @@ def Convert_ONNX():
print(" ")
print('Model has been converted to ONNX')

def check_model(weight):
# load model from onnx
cnn = onnx.load(weight)
decoder = onnx.load(weight)
encoder = onnx.load(weight)

# confirm model has valid schema
onnx.checker.check_model(cnn)
onnx.checker.check_model(decoder)
onnx.checker.check_model(encoder)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(encoder.graph))


if __name__ == "__main__":
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BiSeNet(num_classes=2)
weight = "checkpoints/lastest_model_CeFiLa.pth"
model.load_state_dict(torch.load(weight, map_location=torch.device(device))['state_dict'])
model = model.to(device)
Convert_ONNX()
# check_model('checkpoints/bisenet.onnx')
Convert_ONNX()

0 comments on commit 2187399

Please sign in to comment.