diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index ce1fec7a57b..311304ef544 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -53,7 +53,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None: if _device.type == "cuda": torch.set_default_device(_device.type) - torch.set_default_dtype(torch.cuda.FloatTensor) + torch.set_default_dtype(torch.float32) else: torch.set_default_dtype(torch.float32) logger.debug(f"default Torch device: {_device}") diff --git a/ml-agents/mlagents/trainers/tests/test_torch_utils.py b/ml-agents/mlagents/trainers/tests/test_torch_utils.py index f0d69e35643..a8e15a4a26c 100644 --- a/ml-agents/mlagents/trainers/tests/test_torch_utils.py +++ b/ml-agents/mlagents/trainers/tests/test_torch_utils.py @@ -11,8 +11,8 @@ "device_str, expected_type, expected_index, expected_tensor_type", [ ("cpu", "cpu", None, torch.float32), - ("cuda", "cuda", None, torch.cuda.FloatTensor), - ("cuda:42", "cuda", 42, torch.cuda.FloatTensor), + ("cuda", "cuda", None, torch.float32), + ("cuda:42", "cuda", 42, torch.float32), ("opengl", "opengl", None, torch.float32), ], )