Skip to content

Commit

Permalink
Fix GroupNorm issue on unet3d onnx model
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Jan 8, 2024
1 parent d6a0e71 commit 9015322
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion models/segmentation/brats2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,15 @@ def _compute_outputs(images: torch.Tensor,
print("Dice: {:.4f}, IoU: {:.4f}".format(dice_score, iou_score))

def onnx_exporter(self, onnx_path):
super().onnx_exporter(onnx_path)
random_input = torch.randn(self.input_size)
if torch.cuda.is_available():
random_input = random_input.cuda()
replace_dict = {}
for module in self.model.modules():
if isinstance(module, nn.GroupNorm):
replace_dict[module] = nn.Identity()
self.replace_modules(replace_dict)
torch.onnx.export(self.model, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)
model = onnx.load(onnx_path)
model_simp, check = simplify(model)
onnx.checker.check_model(model_simp)
Expand Down

0 comments on commit 9015322

Please sign in to comment.