diff --git a/smartsim/_core/_install/buildenv.py b/smartsim/_core/_install/buildenv.py index 4212a223a..72e09a4b9 100644 --- a/smartsim/_core/_install/buildenv.py +++ b/smartsim/_core/_install/buildenv.py @@ -218,6 +218,8 @@ def ml_extras_required(self) -> t.Dict[str, t.List[str]]: "onnxmltools": "1.12.0", "scikit-learn": "1.3.2", "torchvision": "0.15.2", + "torch_cpu_suffix": "+cpu", + "torch_cuda_suffix": "+cu117", } # remove torch-related fields as they are subject to change