diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 10eecd1de99b7c..82ad61a191674e 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -52,7 +52,7 @@ # ROCm dependencies rocmSupport ? config.rocmSupport, - rocmPackages, + rocmPackages_5, gpuTargets ? [ ] }: @@ -60,6 +60,8 @@ let inherit (lib) attrsets lists strings trivial; inherit (cudaPackages) cudaFlags cudnn nccl; + rocmPackages = rocmPackages_5; + setBool = v: if v then "1" else "0"; # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744