Skip to content

Commit

Permalink
feat: implement cuda_component repo rule for deliverable ctk
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 30, 2024
1 parent faa178c commit 06bbedf
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 49 deletions.
52 changes: 47 additions & 5 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,39 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda/private:repositories.bzl", "local_cuda")
load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda")

cuda_component_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
"component_name": attr.string(doc = "Short name of the component defined in registry."),
"integrity": attr.string(
doc = "Expected checksum in Subresource Integrity format of the file downloaded. " +
"This must match the checksum of the file downloaded.",
),
"sha256": attr.string(
doc = "The expected SHA-256 of the file downloaded. " +
"This must match the SHA-256 of the file downloaded.",
),
"strip_prefix": attr.string(
doc = "A directory prefix to strip from the extracted files. " +
"Many archives contain a top-level directory that contains all of the useful files in archive.",
),
"urls": attr.string_list(
mandatory = True,
doc = "A list of URLs to a file that will be made available to Bazel. " +
"Each entry must be a file, http or https URL. Redirections are followed. " +
"Authentication is not supported. " +
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
})

cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
"components_mapping": attr.string_keyed_label_dict(
doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " +
"Only the repo part of the label is usefull",
),
})

def _find_modules(module_ctx):
Expand All @@ -25,10 +54,20 @@ def _find_modules(module_ctx):
def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _init(module_ctx):
def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolkits = root.tags.toolkit or rules_cuda.tags.toolkit
components = None
toolkits = None
if root.tags.toolkit:
components = root.tags.component
toolkits = root.tags.toolkit
else:
components = rules_cuda.tags.component
toolkits = rules_cuda.tags.toolkit

for component in components:
cuda_component(**_module_tag_to_dict(component))

registrations = {}
for toolkit in toolkits:
Expand All @@ -43,6 +82,9 @@ def _init(module_ctx):
local_cuda(**_module_tag_to_dict(toolkit))

toolchain = module_extension(
implementation = _init,
tag_classes = {"toolkit": cuda_toolkit_tag},
implementation = _impl,
tag_classes = {
"component": cuda_component_tag,
"toolkit": cuda_toolkit_tag,
},
)
173 changes: 142 additions & 31 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("//cuda/private:template_helper.bzl", "template_helper")
load("//cuda/private:templates/registry.bzl", "FULL_COMPONENT_NAME", "REGISTRY")
load("//cuda/private:toolchain.bzl", "register_detected_cuda_toolchains")

def _to_forward_slash(s):
return s.replace("\\", "/")

def _is_linux(ctx):
return ctx.os.name.startswith("linux")

def _is_windows(ctx):
return ctx.os.name.lower().startswith("windows")

def _get_nvcc_version(repository_ctx, cuda_path):
result = repository_ctx.execute([cuda_path + "/bin/nvcc", "--version"])
def _get_nvcc_version(repository_ctx, nvcc_root):
result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"])
if result.return_code != 0:
return [-1, -1]
for line in [line for line in result.stdout.split("\n") if ", release " in line]:
Expand All @@ -27,21 +25,7 @@ def _get_nvcc_version(repository_ctx, cuda_path):
return version[:2]
return [-1, -1]

def detect_cuda_toolkit(repository_ctx):
"""Detect CUDA Toolkit.
The path to CUDA Toolkit is determined as:
- the value of `toolkit_path` passed to local_cuda as an attribute
- taken from `CUDA_PATH` environment variable or
- determined through 'which ptxas' or
- defaults to '/usr/local/cuda'
Args:
repository_ctx: repository_ctx
Returns:
A struct contains the information of CUDA Toolkit.
"""
def _detect_local_cuda_toolkit(repository_ctx):
cuda_path = repository_ctx.attr.toolkit_path
if cuda_path == "":
cuda_path = repository_ctx.os.environ.get("CUDA_PATH", None)
Expand Down Expand Up @@ -97,6 +81,61 @@ def detect_cuda_toolkit(repository_ctx):
fatbinary_label = fatbinary,
)

def _detect_deliverable_cuda_toolkit(repository_ctx):
# NOTE: component nvcc contains some headers that will be used.
required_components = ["cccl", "cudart", "nvcc"]
for rc in required_components:
if rc not in repository_ctx.attr.components_mapping:
fail('component "{}" is required.'.format(rc))

is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@")
canonical_nvcc_repo_name = repository_ctx.attr.components_mapping["nvcc"].repo_name
nvcc_repo = ("@@{}" if is_bzlmod_enabled else "@{}").format(canonical_nvcc_repo_name)

bin_ext = ".exe" if _is_windows(repository_ctx) else ""
nvlink = str(Label(nvcc_repo + "//:nvcc/bin/nvlink{}".format(bin_ext)))
link_stub = str(Label(nvcc_repo + "//:nvcc/bin/crt/link.stub"))
bin2c = str(Label(nvcc_repo + "//:nvcc/bin/bin2c{}".format(bin_ext)))
fatbinary = str(Label(nvcc_repo + "//:nvcc/bin/fatbinary{}".format(bin_ext)))

nvcc_root = Label(nvcc_repo).workspace_root + "/nvcc"
nvcc_version_major, nvcc_version_minor = _get_nvcc_version(repository_ctx, nvcc_root)

return struct(
path = nvcc_root,
# this should have been extracted from cuda.h, reuse nvcc for now
version_major = nvcc_version_major,
version_minor = nvcc_version_minor,
# this is extracted from `nvcc --version`
nvcc_version_major = nvcc_version_major,
nvcc_version_minor = nvcc_version_minor,
nvlink_label = nvlink,
link_stub_label = link_stub,
bin2c_label = bin2c,
fatbinary_label = fatbinary,
)

def detect_cuda_toolkit(repository_ctx):
"""Detect CUDA Toolkit.
The path to CUDA Toolkit is determined as:
- use nvcc component from deliverable
- the value of `toolkit_path` passed to local_cuda as an attribute
- taken from `CUDA_PATH` environment variable or
- determined through 'which ptxas' or
- defaults to '/usr/local/cuda'
Args:
repository_ctx: repository_ctx
Returns:
A struct contains the information of CUDA Toolkit.
"""
if repository_ctx.attr.components_mapping != []:
return _detect_deliverable_cuda_toolkit(repository_ctx)
else:
return _detect_local_cuda_toolkit(repository_ctx)

def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
"""Generate `@local_cuda//BUILD` and `@local_cuda//defs.bzl` and `@local_cuda//toolchain/BUILD`
Expand All @@ -105,30 +144,40 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
cuda: The struct returned from detect_cuda_toolkit
"""

# True: locally installed cuda toolkit
# False: hermatic cuda toolkit (components)
# True: locally installed cuda toolkit (@local_cuda with full install of local CTK)
# False: hermatic cuda toolkit (@local_cuda with alias of components)
# None: cuda toolkit is not presented
is_local_cuda = None
if cuda.path != None:
is_local_ctk = None

if len(repository_ctx.attr.components_mapping) != 0:
is_local_ctk = False

if is_local_ctk == None and cuda.path != None:
# When using a special cuda toolkit path install, need to manually fix up the lib64 links
if cuda.path == "/usr/lib/nvidia-cuda-toolkit":
repository_ctx.symlink(cuda.path + "/bin", "cuda/bin")
repository_ctx.symlink("/usr/lib/x86_64-linux-gnu", "cuda/lib64")
else:
repository_ctx.symlink(cuda.path, "cuda")
is_local_cuda = True
is_local_ctk = True

# Generate @local_cuda//BUILD
if is_local_cuda == None:
if is_local_ctk == None:
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.local_cuda_disabled"), "BUILD")
elif is_local_cuda:
elif is_local_ctk:
libpath = "lib64" if _is_linux(repository_ctx) else "lib"
template_helper.generate_build(repository_ctx, libpath)
else:
fail("hermatic cuda toolchain is not implemented")
template_helper.generate_build(
repository_ctx,
libpath = "lib",
components = repository_ctx.attr.components_mapping,
is_local_cuda = True,
is_deliverable = True,
)

# Generate @local_cuda//defs.bzl
template_helper.generate_defs_bzl(repository_ctx, is_local_cuda)
template_helper.generate_defs_bzl(repository_ctx, is_local_ctk == True)

# Generate @local_cuda//toolchain/BUILD
template_helper.generate_toolchain_build(repository_ctx, cuda)
Expand Down Expand Up @@ -187,13 +236,73 @@ def _local_cuda_impl(repository_ctx):

local_cuda = repository_rule(
implementation = _local_cuda_impl,
attrs = {"toolkit_path": attr.string(mandatory = False)},
attrs = {
"toolkit_path": attr.string(mandatory = False),
"components_mapping": attr.string_keyed_label_dict(),
},
configure = True,
local = True,
environ = ["CUDA_PATH", "PATH", "CUDA_CLANG_PATH", "BAZEL_LLVM"],
# remotable = True,
)

def _cuda_component_impl(repository_ctx):
name_fragments = repository_ctx.name.split("local_cuda_")
if len(name_fragments) != 2 or (name_fragments[0] != "" and not name_fragments[0].endswith("~")):
fail("cuda_component(name='{}') is expected to have a repo name starts with local_cuda_".format(repository_ctx.name))

component_name = None
if repository_ctx.attr.component_name:
component_name = repository_ctx.attr.component_name
if component_name not in REGISTRY:
fail("invalid component '{}', available: {}".format(component_name, repr(REGISTRY.keys())))
else:
component_name = repository_ctx.name[len("local_cuda_"):]
if component_name not in REGISTRY:
fail("invalid derived component '{}', available: {}, ".format(component_name, repr(REGISTRY.keys())) +
" if derivation result is unexpected, please specify `component_name` attribute manually")

if not repository_ctx.attr.url and not repository_ctx.attr.urls:
fail("either attribute `url` or `urls` must be filled")
if repository_ctx.attr.url and repository_ctx.attr.urls:
fail("attributes `url` and `urls` cannot be used at the same time")

repository_ctx.download_and_extract(
url = repository_ctx.attr.url or repository_ctx.attr.urls,
output = component_name,
integrity = repository_ctx.attr.integrity,
sha256 = repository_ctx.attr.sha256,
stripPrefix = repository_ctx.attr.strip_prefix,
)

template_helper.generate_build(
repository_ctx,
libpath = "lib",
components = {component_name: repository_ctx.name},
is_local_cuda = False,
is_deliverable = True,
)

cuda_component = repository_rule(
implementation = _cuda_component_impl,
attrs = {
"component_name": attr.string(),
"url": attr.string(),
"urls": attr.string_list(),
"integrity": attr.string(),
"sha256": attr.string(),
"strip_prefix": attr.string(),
},
)

def default_components_mapping(components):
"""Create a default components_mapping from list of component names.
Args:
components: list of string, a list of component names.
"""
return {c: "@local_cuda_" + c for c in components}

def rules_cuda_dependencies():
"""Populate the dependencies for rules_cuda. This will setup other bazel rules as workspace dependencies"""
maybe(
Expand All @@ -216,17 +325,19 @@ def rules_cuda_dependencies():
],
)

def rules_cuda_toolchains(toolkit_path = None, register_toolchains = False):
def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, register_toolchains = False):
"""Populate the local_cuda repo.
Args:
toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically.
components_mapping: dict mapping from component_name to its corresponding cuda_component's repo_name
register_toolchains: Register the toolchains if enabled.
"""

local_cuda(
name = "local_cuda",
toolkit_path = toolkit_path,
components_mapping = components_mapping,
)

if register_toolchains:
Expand Down
Loading

0 comments on commit 06bbedf

Please sign in to comment.