Skip to content

Commit

Permalink
merge main and update the options class
Browse files Browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Nov 13, 2024
1 parent 92aa731 commit 81086e0
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
170 changes: 170 additions & 0 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from cuda.core.experimental._module import ObjectCode
from cuda.core.experimental._utils import check_or_create_options
from dataclasses import dataclass
from typing import Optional
from cuda.bindings import nvjitlink


@dataclass
class LinkerOptions:
arch: str # /**< -arch=sm_<N> Pass SM architecture value. See nvcc for valid values of <N>. Can use compute_<N> value instead if only generating PTX. This is a required option. */
max_register_count: Optional[int] = None # /**< -maxrregcount=<N> Maximum register count. */
time: Optional[bool] = None # /**< -time Print timing information to InfoLog. */
verbose: Optional[bool] = None # /**< -verbose Print verbose messages to InfoLog. */
link_time_optimization: Optional[bool] = None # /**< -lto Do link time optimization. */
ptx: Optional[bool] = None # /**< -ptx Emit ptx after linking instead of cubin; only supported with -lto. */
optimization_level: Optional[int] = None # /**< -O<N> Optimization level. Only 0 and 3 are accepted. */
debug: Optional[bool] = None # /**< -g Generate debug information. */
lineinfo: Optional[bool] = None # /**< -lineinfo Generate line information. */
ftz: Optional[bool] = None # /**< -ftz=<n> Flush to zero. */
prec_div: Optional[bool] = None # /**< -prec-div=<n> Precise divide. */
prec_sqrt: Optional[bool] = None # /**< -prec-sqrt=<n> Precise square root. */
fma: Optional[bool] = None # /**< -fma=<n> Fast multiply add. */
kernels_used: Optional[list[str]] = None # /**< -kernels-used=<name> Pass list of kernels that are used; any not in the list can be removed. This option can be specified multiple times. */
variables_used: Optional[list[str]] = None # /**< -variables-used=<name> Pass list of variables that are used; any not in the list can be removed. This option can be specified multiple times. */
optimize_unused_variables: Optional[bool] = None # /**< -optimize-unused-variables Normally device code optimization is limited by not knowing what the host code references. With this option it can assume that if a variable is not referenced in device code then it can be removed. */
xptxas: Optional[list[str]] = None # /**< -Xptxas=<opt> Pass <opt> to ptxas. This option can be called multiple times. */
split_compile: Optional[int] = None # /**< -split-compile=<N> Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split compilation (default). */
split_compile_extended: Optional[int] = None # /**< -split-compile-extended=<N> A more aggressive form of split compilation available in LTO mode only. Accepts a maximum thread count value. Use 0 to use all available processors. Value of 1 disables extended split compilation (default). Note: This option can potentially impact performance of the compiled binary. */
jump_table_density: Optional[int] = None # /**< -jump-table-density=<N> When doing LTO, specify the case density percentage in switch statements, and use it as a minimal threshold to determine whether jump table(brx.idx instruction) will be used to implement a switch statement. Default value is 101. The percentage ranges from 0 to 101 inclusively. */
no_cache: Optional[bool] = None # /**< -no-cache Don’t cache the intermediate steps of nvJitLink. */
device_stack_protector: Optional[bool] = None # /**< -device-stack-protector Enable stack canaries in device code. Stack canaries make it more difficult to exploit certain types of memory safety bugs involving stack-local variables. The compiler uses heuristics to assess the risk of such a bug in each function. Only those functions which are deemed high-risk make use of a stack canary. */


def __post_init__(self):
self.formatted_options = []
if self.arch is not None:
self.formatted_options.append(f"-arch={self.arch}")
if self.max_register_count is not None:
self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
if self.time is not None:
self.formatted_options.append("-time")
if self.verbose is not None:
self.formatted_options.append("-verbose")
if self.link_time_optimization is not None:
self.formatted_options.append("-lto")
if self.ptx is not None:
self.formatted_options.append("-ptx")
if self.optimization_level is not None:
self.formatted_options.append(f"-O{self.optimization_level}")
if self.debug is not None:
self.formatted_options.append("-g")
if self.lineinfo is not None:
self.formatted_options.append("-lineinfo")
if self.ftz is not None:
self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
if self.prec_div is not None:
self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
if self.prec_sqrt is not None:
self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
if self.fma is not None:
self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
if self.kernels_used is not None:
for kernel in self.kernels_used:
self.formatted_options.append(f"-kernels-used={kernel}")
if self.variables_used is not None:
for variable in self.variables_used:
self.formatted_options.append(f"-variables-used={variable}")
if self.optimize_unused_variables is not None:
self.formatted_options.append("-optimize-unused-variables")
if self.xptxas is not None:
for opt in self.xptxas:
self.formatted_options.append(f"-Xptxas={opt}")
if self.split_compile is not None:
self.formatted_options.append(f"-split-compile={self.split_compile}")
if self.split_compile_extended is not None:
self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
if self.jump_table_density is not None:
self.formatted_options.append(f"-jump-table-density={self.jump_table_density}")
if self.no_cache is not None:
self.formatted_options.append("-no-cache")
if self.device_stack_protector is not None:
self.formatted_options.append("-device-stack-protector")


class Linker:

__slots__ = ("_handle")

def __init__(self, options: LinkerOptions, object_codes = None):
self._handle = None
options = check_or_create_options(LinkerOptions, options, "Linker options")
self._handle = nvjitlink.create(len(options.formatted_options), options.formatted_options)

if object_codes is not None:
if object_codes.__iter__:
for code in object_codes:
self.add_code_object(code)
else:
self.add_code_object(object_codes)


def add_code_object(self, object_code: ObjectCode):
data = object_code._module
assert isinstance(data, bytes)
nvjitlink.add_data(
self._handle,
self._input_type_from_code_type(object_code._code_type),
data,
len(data),
f"{object_code._handle}_{object_code._code_type}",
)


def link(self, target_type) -> ObjectCode:
nvjitlink.complete(self._handle)
if target_type not in ["cubin", "ptx"]:
raise ValueError(f"Unsupported target type: {target_type}")
code = None
if target_type == "cubin":
cubin_size = nvjitlink.get_linked_cubin_size(self._handle)
code = bytearray(cubin_size)
nvjitlink.get_linked_cubin(self._handle, code)
else:
ptx_size = nvjitlink.get_linked_ptx_size(self._handle)
code = bytearray(ptx_size)
nvjitlink.get_linked_ptx(self._handle, code)

return ObjectCode(bytes(code), target_type)


def get_error_log(self) -> str:
log_size = nvjitlink.get_error_log_size(self._handle)
log = bytearray(log_size)
nvjitlink.get_error_log(self._handle, log)
return log.decode()


def get_info_log(self) -> str:
log_size = nvjitlink.get_info_log_size(self._handle)
log = bytearray(log_size)
nvjitlink.get_info_log(self._handle, log)
return log.decode()


def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
# this list is based on the supported values for code_type in the ObjectCode class definition. nvjitlink supports other options for input type
if code_type == "ptx":
return nvjitlink.InputType.PTX
elif code_type == "cubin":
return nvjitlink.InputType.CUBIN
elif code_type == "fatbin":
return nvjitlink.InputType.FATBIN
elif code_type == "ltoir":
return nvjitlink.InputType.LTOIR
elif code_type == "object":
return nvjitlink.InputType.OBJECT
else:
raise ValueError(
f"Unknown code_type associated with ObjectCode: {code_type}"
)


@property
def handle(self) -> int:
return self._handle

def __del__(self):
if self._handle is not None:
nvjitlink.destroy(self._handle)
self._handle = None
128 changes: 128 additions & 0 deletions cuda_core/tests/test_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import pytest
from cuda.core.experimental._linker import Linker, LinkerOptions
from cuda.core.experimental._module import ObjectCode
from cuda.core.experimental._program import Program

ARCH = "sm_80" # use sm_80 for testing the oop nvJitLink wrapper
empty_entrypoint_kernel = "__global__ void A() {}"
empty_kernel = "__device__ void B() {}"
addition_kernel = "__device__ int C(int a, int b) { return a + b; }"


@pytest.fixture(scope="module")
def compile_ptx_functions(init_cuda):

object_code_a_ptx = Program(empty_entrypoint_kernel, "c++").compile("ptx")
object_code_b_ptx = Program(empty_kernel, "c++").compile("ptx")
object_code_c_ptx = Program(addition_kernel, "c++").compile("ptx")

return object_code_a_ptx, object_code_b_ptx, object_code_c_ptx


@pytest.fixture(scope="module")
def compile_ltoir_functions(init_cuda):
object_code_a_ltoir = Program(empty_entrypoint_kernel, "c++").compile("ltoir", options=("-dlto",))
object_code_b_ltoir = Program(empty_kernel, "c++").compile("ltoir", options=("-dlto",))
object_code_c_ltoir = Program(addition_kernel, "c++").compile("ltoir", options=("-dlto",))

return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir


def test_linker_init_valid_options():
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
assert linker.handle is not None

def test_linker_init_invalid_arch():
options = LinkerOptions(arch=None)
with pytest.raises(ValueError):
Linker(options)

def test_linker_init(compile_ptx_functions):
combinations = [
LinkerOptions(arch=ARCH),
LinkerOptions(arch=ARCH, max_register_count=32),
LinkerOptions(arch=ARCH, time=True),
LinkerOptions(arch=ARCH, verbose=True),
LinkerOptions(arch=ARCH, optimization_level=3),
LinkerOptions(arch=ARCH, debug=True),
LinkerOptions(arch=ARCH, lineinfo=True),
LinkerOptions(arch=ARCH, ftz=True),
LinkerOptions(arch=ARCH, prec_div=True),
LinkerOptions(arch=ARCH, prec_sqrt=True),
LinkerOptions(arch=ARCH, fma=True),
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
LinkerOptions(arch=ARCH, variables_used=["var1"]),
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
LinkerOptions(arch=ARCH, xptxas=["-v"]),
LinkerOptions(arch=ARCH, split_compile=0),
LinkerOptions(arch=ARCH, split_compile_extended=1),
LinkerOptions(arch=ARCH, jump_table_density=100),
LinkerOptions(arch=ARCH, no_cache=True)
]

# Try the combinations, with and without providing object codes to the constructor
for i, options in enumerate(combinations):
linker = Linker(options, object_codes=compile_ptx_functions)
object_code = linker.link("cubin")
assert isinstance(object_code, ObjectCode)

def test_linker_add_code_object(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
functions = compile_ptx_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])

def test_linker_link_ptx(compile_ltoir_functions):
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
linker = Linker(options)
functions = compile_ltoir_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])
linked_code = linker.link("ptx")
assert isinstance(linked_code, ObjectCode)

def test_linker_link_cubin(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
functions = compile_ptx_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])
linked_code = linker.link("cubin")
assert isinstance(linked_code, ObjectCode)

def test_linker_link_invalid_target_type(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
functions = compile_ptx_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])
with pytest.raises(ValueError):
linker.link("invalid_target")

def test_linker_get_error_log(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
functions = compile_ptx_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])
linker.link("cubin")
log = linker.get_error_log()
assert isinstance(log, str)

def test_linker_get_info_log(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(options)
functions = compile_ptx_functions
linker.add_code_object(functions[0])
linker.add_code_object(functions[1])
linker.add_code_object(functions[2])
linker.link("cubin")
log = linker.get_info_log()
assert isinstance(log, str)

0 comments on commit 81086e0

Please sign in to comment.