diff --git a/models/segmentation/brats2020.py b/models/segmentation/brats2020.py index 899bdb8..1c38529 100644 --- a/models/segmentation/brats2020.py +++ b/models/segmentation/brats2020.py @@ -34,7 +34,7 @@ def load_model(self, eval=True, approx_transpose_conv=True): checkpoint_path = "https://drive.google.com/uc?export=download&id=1NiyVXIr5zcnd3F-zNi3FCj6PmYZnabH-" state_dict = torch.hub.load_state_dict_from_url( - checkpoint_path, file_name=f"{self.model_name}.pth") + checkpoint_path, file_name=f"{self.model_name}.pth", map_location=self.device) self.model.load_state_dict(state_dict, strict=True) @@ -88,9 +88,9 @@ def onnx_exporter(self, onnx_path): replace_dict = {} for module in self.model.modules(): if isinstance(module, nn.GroupNorm): - replace_dict[module] = nn.Identity() + replace_dict[module] = nn.BatchNorm3d(module.num_channels) self.replace_modules(replace_dict) - torch.onnx.export(self.model, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True) + torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True) model = onnx.load(onnx_path) model_simp, check = simplify(model) onnx.checker.check_model(model_simp)