Skip to content

Commit

Permalink
map model to device using map_location parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
thearyadev committed Feb 29, 2024
1 parent 145fac3 commit 6187715
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion heroes/hero_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str:


if model_directory.name not in model_cache.keys():
model_cache[model_directory.name] = torch.load(model_directory / "model.pth")
model_cache[model_directory.name] = torch.load(model_directory / "model.pth", map_location=device)
st_dict = model_cache[model_directory.name]


Expand Down

0 comments on commit 6187715

Please sign in to comment.