From ce85a7b2a9b58bdf1c6dffa29ee52d62e35bea51 Mon Sep 17 00:00:00 2001 From: Petros Toupas Date: Wed, 10 Jan 2024 01:50:43 +0200 Subject: [PATCH] Update brats2020.py --- models/segmentation/brats2020.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/segmentation/brats2020.py b/models/segmentation/brats2020.py index 1c38529..5575788 100644 --- a/models/segmentation/brats2020.py +++ b/models/segmentation/brats2020.py @@ -88,7 +88,7 @@ def onnx_exporter(self, onnx_path): replace_dict = {} for module in self.model.modules(): if isinstance(module, nn.GroupNorm): - replace_dict[module] = nn.BatchNorm3d(module.num_channels) + replace_dict[module] = nn.BatchNorm3d(module.num_channels).to(self.device) self.replace_modules(replace_dict) torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True) model = onnx.load(onnx_path)