From a49d3f6efa7044c237a9e809701c0f12934e2e25 Mon Sep 17 00:00:00 2001
From: Petros Toupas
Date: Fri, 17 Nov 2023 01:26:50 +0200
Subject: [PATCH] Add file_name when loading checkpoint from url using the
 load_state_dict_from_url

---
 models/action_recognition/ucf101.py |  4 ++--
 models/segmentation/camvid.py       | 10 +++++-----
 models/segmentation/cityscapes.py   |  2 +-
 quantization_example.py             |  1 +
 4 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/models/action_recognition/ucf101.py b/models/action_recognition/ucf101.py
index 472abab..0438a0b 100644
--- a/models/action_recognition/ucf101.py
+++ b/models/action_recognition/ucf101.py
@@ -25,7 +25,7 @@ def load_model(self, val=True):
             case "x3d_m":
                 config_path = os.path.join(
                     MMACTION_PATH, "configs/recognition/x3d/x3d_m_16x5x1_facebook-kinetics400-rgb.py")
-                checkpoint_path = "https://drive.google.com/uc?export=download&id=1l6x6LOmSfpugMOSuEZYb4foRIC8jXMQU"
+                checkpoint_path = "https://drive.google.com/uc?export=download&id=1_YgpEIb8SK6didDh8dv7Db0Rmb7GAdnn"
 
         cfg = Config.fromfile(config_path)
         # runner only load checkpoint when running inference, too late for compression, as model is already substituted
@@ -48,7 +48,7 @@ def load_model(self, val=True):
         # cfg.log_level = "WARNING"
         self.runner = Runner.from_cfg(cfg)
         self.model = self.runner.model
-        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
+        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, file_name=f"{self.model_name}.pth")[
             'state_dict']
         self.model.load_state_dict(state_dict)
         # load_checkpoint(self.model, checkpoint_path, map_location="cpu")
diff --git a/models/segmentation/camvid.py b/models/segmentation/camvid.py
index 45c9279..ed039e7 100644
--- a/models/segmentation/camvid.py
+++ b/models/segmentation/camvid.py
@@ -23,7 +23,7 @@ def load_model(self, eval=True, approx_transpose_conv=True):
         assert self.model_name == 'unet'
         self.model = UNet(input_size_hw=self.input_size[2:],
                           in_channels=self.input_size[1], n_classes=self.num_classes)
-        checkpoint = torch.hub.load_state_dict_from_url('https://storage.openvinotoolkit.org/repositories/nncf/models/v2.6.0/torch/unet_camvid.pth')
+        checkpoint = torch.hub.load_state_dict_from_url('https://storage.openvinotoolkit.org/repositories/nncf/models/v2.6.0/torch/unet_camvid.pth', file_name="unet_camvid.pth")
         state_dict = checkpoint['state_dict']
         # remove 'module.' prefix
         state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
@@ -34,16 +34,16 @@ def load_model(self, eval=True, approx_transpose_conv=True):
 
         if torch.cuda.is_available():
             self.model = self.model.cuda()
-        
+
     def load_data(self, batch_size, workers):
         # todo: download dataset
         # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
         CAMVID_PATH = os.environ.get("CAMVID_PATH", os.path.expanduser("~/dataset/CamVid"))
-        
+
         val_transforms = Compose([
             Resize(size=self.input_size[2:]),
             ToTensor(),
-            Normalize(mean=[0.39068785, 0.40521392, 0.41434407], std=[0.29652068, 0.30514979, 0.30080369]) 
+            Normalize(mean=[0.39068785, 0.40521392, 0.41434407], std=[0.29652068, 0.30514979, 0.30080369])
         ])
         val_data = CamVid(CAMVID_PATH, "val", transforms=val_transforms)
         test_data = CamVid(CAMVID_PATH, "test", transforms=val_transforms)
@@ -58,7 +58,7 @@ def load_data(self, batch_size, workers):
                                       batch_size=batch_size, shuffle=False,
                                       num_workers=workers, pin_memory=True, collate_fn=collate_fn)
-        
+
         self.data_loaders['calibrate'] = val_loader
         self.data_loaders['validate'] = val_loader
         self.data_loaders['test'] = test_loader
 
diff --git a/models/segmentation/cityscapes.py b/models/segmentation/cityscapes.py
index 9cd7ef0..a62d48a 100644
--- a/models/segmentation/cityscapes.py
+++ b/models/segmentation/cityscapes.py
@@ -29,7 +29,7 @@ def load_model(self, eval=True):
 
         self.runner = Runner.from_cfg(cfg)
         self.model = self.runner.model
-        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
+        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, file_name=f"{self.model_name}.pth")[
             'state_dict']
         self.model.load_state_dict(state_dict)
 
diff --git a/quantization_example.py b/quantization_example.py
index 416156c..2fb172d 100644
--- a/quantization_example.py
+++ b/quantization_example.py
@@ -48,6 +48,7 @@ def main():
     model_wrapper.inference("test")
     model_wrapper.generate_onnx_files(
         os.path.join(args.output_path, "float32"))
 
+    # TEST 2
     print("NETWORK FP16 Inference")
     # reload the model everytime a new quantization mode is tested
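
Note on the change: when file_name is omitted, torch.hub.load_state_dict_from_url names the cached file after the last component of the URL path, so the Google Drive links above ("…/uc?export=download&id=…") would all be cached as "uc" and overwrite one another. Passing file_name stores each checkpoint under a predictable, per-model name in the torch hub cache. A minimal standalone sketch of the pattern, outside the diff; the URL below is an illustrative placeholder, not one of the repository's checkpoints:

    import torch

    checkpoint_url = "https://example.com/uc?export=download&id=SOME_ID"  # placeholder URL
    model_name = "x3d_m"

    # With file_name set, the download is cached as
    # <torch hub dir>/checkpoints/x3d_m.pth (typically under ~/.cache/torch/hub/)
    # instead of a name derived from the URL path.
    checkpoint = torch.hub.load_state_dict_from_url(
        checkpoint_url, file_name=f"{model_name}.pth")
    state_dict = checkpoint["state_dict"]  # checkpoints in this patch keep weights under 'state_dict'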