Skip to content

Commit

Permalink
Changed default cuda dtype to torch.float32.
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelalonsojr committed Oct 4, 2024
1 parent c683e77 commit 54664fc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ml-agents/mlagents/torch_utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def set_torch_config(torch_settings: TorchSettings) -> None:

if _device.type == "cuda":
torch.set_default_device(_device.type)
torch.set_default_dtype(torch.cuda.FloatTensor)
torch.set_default_dtype(torch.float32)
else:
torch.set_default_dtype(torch.float32)
logger.debug(f"default Torch device: {_device}")
Expand Down

0 comments on commit 54664fc

Please sign in to comment.