Skip to content

Commit

Permalink
add support for mac m1/m2 chips mps gpu
Browse files Browse the repository at this point in the history
need upgrade onnx and protobuf, add setup default device for torch, so user can use command params
"--torch-device mps:0" to launch gpu training on mac
  • Loading branch information
左云龙 committed Mar 31, 2024
1 parent fb2af76 commit 1853609
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ml-agents/mlagents/torch_utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def set_torch_config(torch_settings: TorchSettings) -> None:

if _device.type == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
elif _device.type == 'mps':
torch.set_default_device(device_str)
else:
torch.set_default_tensor_type(torch.FloatTensor)
logger.debug(f"default Torch device: {_device}")
Expand Down

0 comments on commit 1853609

Please sign in to comment.