From e103f5f737cce2526819514ec0afd048a8e51df5 Mon Sep 17 00:00:00 2001 From: Petros Toupas Date: Mon, 8 Jan 2024 18:39:09 +0200 Subject: [PATCH] add simplified onnx export model for unet3d case --- models/segmentation/brats2020.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/models/segmentation/brats2020.py b/models/segmentation/brats2020.py index c5520c3..738f37b 100644 --- a/models/segmentation/brats2020.py +++ b/models/segmentation/brats2020.py @@ -2,11 +2,13 @@ import nibabel as nib import numpy as np +import onnx import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from albumentations import Compose +from onnxsim import simplify from skimage.transform import resize from sklearn.model_selection import StratifiedKFold from torch.utils import data @@ -79,6 +81,12 @@ def _compute_outputs(images: torch.Tensor, dice_score, iou_score = meter.get_metrics() print("Dice: {:.4f}, IoU: {:.4f}".format(dice_score, iou_score)) + def onnx_exporter(self, onnx_path): + super().onnx_exporter(onnx_path) + model = onnx.load(onnx_path) + model_simp, check = simplify(model) + onnx.checker.check_model(model_simp) + onnx.save(model_simp, onnx_path) class BraTS2020(data.Dataset): def __init__(self, dataset_path: str, phase: str = "validation", is_resize: bool = False, seed: int = 55):