Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for LTO in Numba linker #11

Merged
merged 4 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,27 @@


class PatchedLinker(Linker):
def __init__(self, max_registers=None, lineinfo=False, cc=None):
def __init__(self, max_registers=None, lineinfo=False, cc=None,
lto=False, additional_flags=None):
if cc is None:
raise RuntimeError("PatchedLinker requires CC to be specified")
if not any(isinstance(cc, t) for t in [list, tuple]):
raise TypeError("`cc` must be a list or tuple of length 2")

sm_ver = f"{cc[0] * 10 + cc[1]}"
arch = f"-arch=sm_{sm_ver}"
opts = [arch]
options = [arch]
if max_registers:
opts.append(f"-maxrregcount={max_registers}")
options.append(f"-maxrregcount={max_registers}")
if lineinfo:
opts.append("-lineinfo")
options.append("-lineinfo")
if lto:
options.append('-lto')
if additional_flags is not None:
options.extend(additional_flags)

self._linker = NvJitLinker(*opts)
self._linker = NvJitLinker(*options)
self.options = options

@property
def info_log(self):
Expand All @@ -65,6 +71,15 @@ def error_log(self):
def add_ptx(self, ptx, name='<cudapy-ptx>'):
self._linker.add_ptx(ptx, name)

def add_fatbin(self, fatbin, name='<external-fatbin>'):
self._linker.add_fatbin(fatbin, name)

def add_ltoir(self, ltoir, name='<external-ltoir>'):
self._linker.add_ltoir(ltoir, name)

def add_object(self, obj, name='<external-object>'):
self._linker.add_object(obj, name)

def add_file(self, path, kind):
try:
with open(path, 'rb') as f:
Expand All @@ -81,6 +96,8 @@ def add_file(self, path, kind):
raise LinkerError("Don't know how to link archives")
elif kind == FILE_EXTENSION_MAP['ptx']:
return self.add_ptx(data, name)
elif kind == FILE_EXTENSION_MAP['o']:
fn = self._linker.add_object
else:
raise LinkerError(f"Don't know how to link {kind}")

Expand Down Expand Up @@ -110,8 +127,10 @@ def complete(self):
raise LinkerError from e


def new_patched_linker(max_registers=0, lineinfo=False, cc=None):
return PatchedLinker(max_registers, lineinfo, cc)
def new_patched_linker(max_registers=0, lineinfo=False, cc=None, lto=False,
additional_flags=None):
return PatchedLinker(max_registers=max_registers, lineinfo=lineinfo, cc=cc,
lto=lto, additional_flags=additional_flags)


def patch_numba_linker():
Expand Down
95 changes: 95 additions & 0 deletions pynvjitlink/tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2023, NVIDIA CORPORATION.

import pytest
import sys

from pynvjitlink import patch, NvJitLinkError
from pynvjitlink.patch import (PatchedLinker, patch_numba_linker,
new_patched_linker, required_numba_ver,
_numba_version_ok)
from unittest.mock import patch as mock_patch


def test_numba_patching_numba_not_ok():
with mock_patch.multiple(
patch,
_numba_version_ok=False,
_numba_error='<error>'):
with pytest.raises(RuntimeError, match='Cannot patch Numba: <error>'):
patch_numba_linker()


@pytest.mark.skipif(
not _numba_version_ok,
reason=f"Requires Numba == {required_numba_ver[0]}.{required_numba_ver[1]}"
)
def test_numba_patching():
# We import the linker here rather than at the top level because the import
# may fail if if Numba is not present or an unsupported version.
from numba.cuda.cudadrv.driver import Linker
patch_numba_linker()
assert Linker.new is new_patched_linker


def test_create():
patched_linker = PatchedLinker(cc=(7, 5))
assert "-arch=sm_75" in patched_linker.options


def test_create_no_cc_error():
# nvJitLink expects at least the architecture to be specified.
with pytest.raises(RuntimeError,
match='PatchedLinker requires CC to be specified'):
PatchedLinker()


def test_invalid_arch_error():
# CC 0.0 is not a valid compute capability
with pytest.raises(NvJitLinkError,
match='NVJITLINK_ERROR_UNRECOGNIZED_OPTION error'):
PatchedLinker(cc=(0, 0))


def test_invalid_cc_type_error():
with pytest.raises(TypeError,
match='`cc` must be a list or tuple of length 2'):
PatchedLinker(cc=0)


@pytest.mark.parametrize('max_registers', (None, 32))
@pytest.mark.parametrize('lineinfo', (False, True))
@pytest.mark.parametrize('lto', (False, True))
@pytest.mark.parametrize('additional_flags', (None, ('-g',), ('-g', '-time')))
def test_ptx_compile_options(max_registers, lineinfo, lto, additional_flags):
patched_linker = PatchedLinker(
cc=(7, 5),
max_registers=max_registers,
lineinfo=lineinfo,
lto=lto,
additional_flags=additional_flags
)

assert "-arch=sm_75" in patched_linker.options

if max_registers:
assert f"-maxrregcount={max_registers}" in patched_linker.options
else:
assert "-maxrregcount" not in patched_linker.options

if lineinfo:
assert "-lineinfo" in patched_linker.options
else:
assert "-lineinfo" not in patched_linker.options

if lto:
assert "-lto" in patched_linker.options
else:
assert "-lto" not in patched_linker.options

if additional_flags:
for flag in additional_flags:
assert flag in patched_linker.options


if __name__ == '__main__':
sys.exit(pytest.main())
Loading