diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py index c71eea3c951..ada328171a7 100644 --- a/sky/utils/accelerator_registry.py +++ b/sky/utils/accelerator_registry.py @@ -25,6 +25,7 @@ 'A10G', 'Gaudi HL-205', 'Inferentia', + 'Trainium', 'K520', 'K80', 'M60', diff --git a/tests/test_list_accelerators.py b/tests/test_list_accelerators.py index cee92ba5350..777f7a92c10 100644 --- a/tests/test_list_accelerators.py +++ b/tests/test_list_accelerators.py @@ -8,6 +8,7 @@ def test_list_accelerators(): assert 'V100' in result, result assert 'tpu-v3-8' in result, result assert 'Inferentia' not in result, result + assert 'Trainium' not in result, result assert 'A100-80GB' in result, result @@ -16,6 +17,7 @@ def test_list_ccelerators_all(): assert 'V100' in result, result assert 'tpu-v3-8' in result, result assert 'Inferentia' in result, result + assert 'Trainium' in result, result assert 'A100-80GB' in result, result