From 3f8c182ae3423746206cb6a722e479f0ecfd47f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 13 Sep 2024 10:33:01 -0700 Subject: [PATCH] Add support for Tegra chips in hermetic CUDA rules. See [Github issue](https://github.com/tensorflow/tensorflow/issues/75353). PiperOrigin-RevId: 674350119 --- .../cuda_redist_init_repositories.bzl | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index c1e52f1ce83e0..0237a09993c01 100644 --- a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -25,12 +25,16 @@ load( OS_ARCH_DICT = { "amd64": "x86_64-unknown-linux-gnu", "aarch64": "aarch64-unknown-linux-gnu", + "tegra-aarch64": "tegra-aarch64-unknown-linux-gnu", } _REDIST_ARCH_DICT = { "linux-x86_64": "x86_64-unknown-linux-gnu", "linux-sbsa": "aarch64-unknown-linux-gnu", + "linux-aarch64": "tegra-aarch64-unknown-linux-gnu", } +TEGRA = "tegra" + SUPPORTED_ARCHIVE_EXTENSIONS = [ ".zip", ".jar", @@ -277,6 +281,15 @@ def _download_redistribution(repository_ctx, arch_key, path_prefix): ) repository_ctx.delete(file_name) +def _get_platform_architecture(repository_ctx): + host_arch = repository_ctx.os.arch + + if host_arch == "aarch64": + uname_result = repository_ctx.execute(["uname", "-a"]).stdout + if TEGRA in uname_result: + return "{}-{}".format(TEGRA, host_arch) + return host_arch + def _use_downloaded_cuda_redistribution(repository_ctx): # buildifier: disable=function-docstring-args """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" @@ -298,7 +311,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): return # Download archive only when GPU config is used. - arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + arch_key = OS_ARCH_DICT[_get_platform_architecture(repository_ctx)] if arch_key not in repository_ctx.attr.url_dict.keys(): fail( ("The supported platforms are {supported_platforms}." + @@ -374,7 +387,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx): return # Download archive only when GPU config is used. - arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + arch_key = OS_ARCH_DICT[_get_platform_architecture(repository_ctx)] if arch_key not in repository_ctx.attr.url_dict.keys(): arch_key = "cuda{version}_{arch}".format( version = cuda_version.split(".")[0], @@ -438,21 +451,24 @@ cudnn_repo = repository_rule( def _get_redistribution_urls(dist_info): url_dict = {} for arch in _REDIST_ARCH_DICT.keys(): - if "relative_path" in dist_info[arch]: + arch_key = arch + if arch_key == "linux-aarch64" and arch_key not in dist_info: + arch_key = "linux-sbsa" + if "relative_path" in dist_info[arch_key]: url_dict[_REDIST_ARCH_DICT[arch]] = [ - dist_info[arch]["relative_path"], - dist_info[arch].get("sha256", ""), + dist_info[arch_key]["relative_path"], + dist_info[arch_key].get("sha256", ""), ] continue - if "full_path" in dist_info[arch]: + if "full_path" in dist_info[arch_key]: url_dict[_REDIST_ARCH_DICT[arch]] = [ - dist_info[arch]["full_path"], - dist_info[arch].get("sha256", ""), + dist_info[arch_key]["full_path"], + dist_info[arch_key].get("sha256", ""), ] continue - for cuda_version, data in dist_info[arch].items(): + for cuda_version, data in dist_info[arch_key].items(): # CUDNN JSON might contain paths for each CUDA version. path_key = "relative_path" if path_key not in data.keys():