Skip to content

Commit

Permalink
Add support for Tegra chips in hermetic CUDA rules.
Browse files Browse the repository at this point in the history
See [Github issue](tensorflow/tensorflow#75353).

PiperOrigin-RevId: 674350119
  • Loading branch information
tensorflower-gardener authored and Google-ML-Automation committed Sep 13, 2024
1 parent 571cedd commit 3f8c182
Showing 1 changed file with 25 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ load(
OS_ARCH_DICT = {
"amd64": "x86_64-unknown-linux-gnu",
"aarch64": "aarch64-unknown-linux-gnu",
"tegra-aarch64": "tegra-aarch64-unknown-linux-gnu",
}
_REDIST_ARCH_DICT = {
"linux-x86_64": "x86_64-unknown-linux-gnu",
"linux-sbsa": "aarch64-unknown-linux-gnu",
"linux-aarch64": "tegra-aarch64-unknown-linux-gnu",
}

TEGRA = "tegra"

SUPPORTED_ARCHIVE_EXTENSIONS = [
".zip",
".jar",
Expand Down Expand Up @@ -277,6 +281,15 @@ def _download_redistribution(repository_ctx, arch_key, path_prefix):
)
repository_ctx.delete(file_name)

def _get_platform_architecture(repository_ctx):
host_arch = repository_ctx.os.arch

if host_arch == "aarch64":
uname_result = repository_ctx.execute(["uname", "-a"]).stdout
if TEGRA in uname_result:
return "{}-{}".format(TEGRA, host_arch)
return host_arch

def _use_downloaded_cuda_redistribution(repository_ctx):
# buildifier: disable=function-docstring-args
""" Downloads CUDA redistribution and initializes hermetic CUDA repository."""
Expand All @@ -298,7 +311,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx):
return

# Download archive only when GPU config is used.
arch_key = OS_ARCH_DICT[repository_ctx.os.arch]
arch_key = OS_ARCH_DICT[_get_platform_architecture(repository_ctx)]
if arch_key not in repository_ctx.attr.url_dict.keys():
fail(
("The supported platforms are {supported_platforms}." +
Expand Down Expand Up @@ -374,7 +387,7 @@ def _use_downloaded_cudnn_redistribution(repository_ctx):
return

# Download archive only when GPU config is used.
arch_key = OS_ARCH_DICT[repository_ctx.os.arch]
arch_key = OS_ARCH_DICT[_get_platform_architecture(repository_ctx)]
if arch_key not in repository_ctx.attr.url_dict.keys():
arch_key = "cuda{version}_{arch}".format(
version = cuda_version.split(".")[0],
Expand Down Expand Up @@ -438,21 +451,24 @@ cudnn_repo = repository_rule(
def _get_redistribution_urls(dist_info):
url_dict = {}
for arch in _REDIST_ARCH_DICT.keys():
if "relative_path" in dist_info[arch]:
arch_key = arch
if arch_key == "linux-aarch64" and arch_key not in dist_info:
arch_key = "linux-sbsa"
if "relative_path" in dist_info[arch_key]:
url_dict[_REDIST_ARCH_DICT[arch]] = [
dist_info[arch]["relative_path"],
dist_info[arch].get("sha256", ""),
dist_info[arch_key]["relative_path"],
dist_info[arch_key].get("sha256", ""),
]
continue

if "full_path" in dist_info[arch]:
if "full_path" in dist_info[arch_key]:
url_dict[_REDIST_ARCH_DICT[arch]] = [
dist_info[arch]["full_path"],
dist_info[arch].get("sha256", ""),
dist_info[arch_key]["full_path"],
dist_info[arch_key].get("sha256", ""),
]
continue

for cuda_version, data in dist_info[arch].items():
for cuda_version, data in dist_info[arch_key].items():
# CUDNN JSON might contain paths for each CUDA version.
path_key = "relative_path"
if path_key not in data.keys():
Expand Down

0 comments on commit 3f8c182

Please sign in to comment.