Skip to content

Commit

Permalink
Add file_name when loading checkpoint from url using load_state_dict_from_url
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Nov 16, 2023
1 parent ccbf5d0 commit a49d3f6
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions models/action_recognition/ucf101.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load_model(self, val=True):
case "x3d_m":
config_path = os.path.join(
MMACTION_PATH, "configs/recognition/x3d/x3d_m_16x5x1_facebook-kinetics400-rgb.py")
checkpoint_path = "https://drive.google.com/uc?export=download&id=1l6x6LOmSfpugMOSuEZYb4foRIC8jXMQU"
checkpoint_path = "https://drive.google.com/uc?export=download&id=1_YgpEIb8SK6didDh8dv7Db0Rmb7GAdnn"

cfg = Config.fromfile(config_path)
# runner only load checkpoint when running inference, too late for compression, as model is already substituted
Expand All @@ -48,7 +48,7 @@ def load_model(self, val=True):
# cfg.log_level = "WARNING"
self.runner = Runner.from_cfg(cfg)
self.model = self.runner.model
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, file_name=f"{self.model_name}.pth")[
'state_dict']
self.model.load_state_dict(state_dict)
# load_checkpoint(self.model, checkpoint_path, map_location="cpu")
Expand Down
10 changes: 5 additions & 5 deletions models/segmentation/camvid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def load_model(self, eval=True, approx_transpose_conv=True):
assert self.model_name == 'unet'

self.model = UNet(input_size_hw=self.input_size[2:], in_channels=self.input_size[1], n_classes=self.num_classes)
checkpoint = torch.hub.load_state_dict_from_url('https://storage.openvinotoolkit.org/repositories/nncf/models/v2.6.0/torch/unet_camvid.pth')
checkpoint = torch.hub.load_state_dict_from_url('https://storage.openvinotoolkit.org/repositories/nncf/models/v2.6.0/torch/unet_camvid.pth', file_name="unet_camvid.pth")
state_dict = checkpoint['state_dict']
# remove 'module.' prefix
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
Expand All @@ -34,16 +34,16 @@ def load_model(self, eval=True, approx_transpose_conv=True):

if torch.cuda.is_available():
self.model = self.model.cuda()

def load_data(self, batch_size, workers):
# todo: download dataset
# https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
CAMVID_PATH = os.environ.get("CAMVID_PATH", os.path.expanduser("~/dataset/CamVid"))

val_transforms = Compose([
Resize(size=self.input_size[2:]),
ToTensor(),
Normalize(mean=[0.39068785, 0.40521392, 0.41434407], std=[0.29652068, 0.30514979, 0.30080369])
Normalize(mean=[0.39068785, 0.40521392, 0.41434407], std=[0.29652068, 0.30514979, 0.30080369])
])
val_data = CamVid(CAMVID_PATH, "val", transforms=val_transforms)
test_data = CamVid(CAMVID_PATH, "test", transforms=val_transforms)
Expand All @@ -58,7 +58,7 @@ def load_data(self, batch_size, workers):
batch_size=batch_size, shuffle=False,
num_workers=workers, pin_memory=True,
collate_fn=collate_fn)

self.data_loaders['calibrate'] = val_loader
self.data_loaders['validate'] = val_loader
self.data_loaders['test'] = test_loader
Expand Down
2 changes: 1 addition & 1 deletion models/segmentation/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def load_model(self, eval=True):

self.runner = Runner.from_cfg(cfg)
self.model = self.runner.model
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, file_name=f"{self.model_name}.pth")[
'state_dict']
self.model.load_state_dict(state_dict)

Expand Down
1 change: 1 addition & 0 deletions quantization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def main():
model_wrapper.inference("test")
model_wrapper.generate_onnx_files(
os.path.join(args.output_path, "float32"))

# TEST 2
print("NETWORK FP16 Inference")
    # reload the model every time a new quantization mode is tested
Expand Down

0 comments on commit a49d3f6

Please sign in to comment.