diff --git a/ml-agents/mlagents/trainers/tests/test_torch_utils.py b/ml-agents/mlagents/trainers/tests/test_torch_utils.py index f0d69e3564..a8e15a4a26 100644 --- a/ml-agents/mlagents/trainers/tests/test_torch_utils.py +++ b/ml-agents/mlagents/trainers/tests/test_torch_utils.py @@ -11,8 +11,8 @@ "device_str, expected_type, expected_index, expected_tensor_type", [ ("cpu", "cpu", None, torch.float32), - ("cuda", "cuda", None, torch.cuda.FloatTensor), - ("cuda:42", "cuda", 42, torch.cuda.FloatTensor), + ("cuda", "cuda", None, torch.float32), + ("cuda:42", "cuda", 42, torch.float32), ("opengl", "opengl", None, torch.float32), ], )