diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index ce1fec7a57..311304ef54 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -53,7 +53,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None: if _device.type == "cuda": torch.set_default_device(_device.type) - torch.set_default_dtype(torch.cuda.FloatTensor) + torch.set_default_dtype(torch.float32) else: torch.set_default_dtype(torch.float32) logger.debug(f"default Torch device: {_device}")