Skip to content

Commit

Permalink
add simplified onnx export model for unet3d case
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Jan 8, 2024
1 parent cf1bc4f commit e103f5f
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions models/segmentation/brats2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import nibabel as nib
import numpy as np
import onnx
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations import Compose
from onnxsim import simplify
from skimage.transform import resize
from sklearn.model_selection import StratifiedKFold
from torch.utils import data
Expand Down Expand Up @@ -79,6 +81,12 @@ def _compute_outputs(images: torch.Tensor,
dice_score, iou_score = meter.get_metrics()
print("Dice: {:.4f}, IoU: {:.4f}".format(dice_score, iou_score))

def onnx_exporter(self, onnx_path):
super().onnx_exporter(onnx_path)
model = onnx.load(onnx_path)
model_simp, check = simplify(model)
onnx.checker.check_model(model_simp)
onnx.save(model_simp, onnx_path)

class BraTS2020(data.Dataset):
def __init__(self, dataset_path: str, phase: str = "validation", is_resize: bool = False, seed: int = 55):
Expand Down

0 comments on commit e103f5f

Please sign in to comment.