From 61877158808e18deac35314572a14173de4bceee Mon Sep 17 00:00:00 2001 From: Aryan Kothari <87589047+thearyadev@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:09:31 -0500 Subject: [PATCH] map model to device using `map_location` parameter --- heroes/hero_comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heroes/hero_comparison.py b/heroes/hero_comparison.py index de4c1780..669d3757 100644 --- a/heroes/hero_comparison.py +++ b/heroes/hero_comparison.py @@ -145,7 +145,7 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str: if model_directory.name not in model_cache.keys(): - model_cache[model_directory.name] = torch.load(model_directory / "model.pth") + model_cache[model_directory.name] = torch.load(model_directory / "model.pth", map_location=device) st_dict = model_cache[model_directory.name]