Skip to content

Commit

Permalink
fix unet3d onnx naming
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Jan 9, 2024
1 parent 41b550c commit ae5db03
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions models/segmentation/brats2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def load_model(self, eval=True, approx_transpose_conv=True):

checkpoint_path = "https://drive.google.com/uc?export=download&id=1NiyVXIr5zcnd3F-zNi3FCj6PmYZnabH-"
state_dict = torch.hub.load_state_dict_from_url(
checkpoint_path, file_name=f"{self.model_name}.pth")
checkpoint_path, file_name=f"{self.model_name}.pth", map_location=self.device)

self.model.load_state_dict(state_dict, strict=True)

Expand Down Expand Up @@ -88,9 +88,9 @@ def onnx_exporter(self, onnx_path):
replace_dict = {}
for module in self.model.modules():
if isinstance(module, nn.GroupNorm):
replace_dict[module] = nn.Identity()
replace_dict[module] = nn.BatchNorm3d(module.num_channels)
self.replace_modules(replace_dict)
torch.onnx.export(self.model, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)
torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)
model = onnx.load(onnx_path)
model_simp, check = simplify(model)
onnx.checker.check_model(model_simp)
Expand Down

0 comments on commit ae5db03

Please sign in to comment.