diff --git a/heroes/hero_comparison.py b/heroes/hero_comparison.py index de4c1780..669d3757 100644 --- a/heroes/hero_comparison.py +++ b/heroes/hero_comparison.py @@ -145,7 +145,7 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str: if model_directory.name not in model_cache.keys(): - model_cache[model_directory.name] = torch.load(model_directory / "model.pth") + model_cache[model_directory.name] = torch.load(model_directory / "model.pth", map_location=device) st_dict = model_cache[model_directory.name]