diff --git a/pynvjitlink/patch.py b/pynvjitlink/patch.py index 1928e050..f979f58f 100644 --- a/pynvjitlink/patch.py +++ b/pynvjitlink/patch.py @@ -38,7 +38,8 @@ class PatchedLinker(Linker): - def __init__(self, max_registers=None, lineinfo=False, cc=None): + def __init__(self, max_registers=None, lineinfo=False, cc=None, + lto=False, additional_flags=None): if cc is None: raise RuntimeError("PatchedLinker requires CC to be specified") if not any(isinstance(cc, t) for t in [list, tuple]): @@ -46,13 +47,18 @@ def __init__(self, max_registers=None, lineinfo=False, cc=None): sm_ver = f"{cc[0] * 10 + cc[1]}" arch = f"-arch=sm_{sm_ver}" - opts = [arch] + options = [arch] if max_registers: - opts.append(f"-maxrregcount={max_registers}") + options.append(f"-maxrregcount={max_registers}") if lineinfo: - opts.append("-lineinfo") + options.append("-lineinfo") + if lto: + options.append('-lto') + if additional_flags is not None: + options.extend(additional_flags) - self._linker = NvJitLinker(*opts) + self._linker = NvJitLinker(*options) + self.options = options @property def info_log(self): @@ -65,6 +71,15 @@ def error_log(self): def add_ptx(self, ptx, name=''): self._linker.add_ptx(ptx, name) + def add_fatbin(self, fatbin, name=''): + self._linker.add_fatbin(fatbin, name) + + def add_ltoir(self, ltoir, name=''): + self._linker.add_ltoir(ltoir, name) + + def add_object(self, obj, name=''): + self._linker.add_object(obj, name) + def add_file(self, path, kind): try: with open(path, 'rb') as f: @@ -81,6 +96,8 @@ def add_file(self, path, kind): raise LinkerError("Don't know how to link archives") elif kind == FILE_EXTENSION_MAP['ptx']: return self.add_ptx(data, name) + elif kind == FILE_EXTENSION_MAP['o']: + fn = self._linker.add_object else: raise LinkerError(f"Don't know how to link {kind}") @@ -110,8 +127,10 @@ def complete(self): raise LinkerError from e -def new_patched_linker(max_registers=0, lineinfo=False, cc=None): - return PatchedLinker(max_registers, lineinfo, cc) +def new_patched_linker(max_registers=0, lineinfo=False, cc=None, lto=False, + additional_flags=None): + return PatchedLinker(max_registers=max_registers, lineinfo=lineinfo, cc=cc, + lto=lto, additional_flags=additional_flags) def patch_numba_linker(): diff --git a/pynvjitlink/tests/test_patch.py b/pynvjitlink/tests/test_patch.py new file mode 100644 index 00000000..6d76000d --- /dev/null +++ b/pynvjitlink/tests/test_patch.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. + +import pytest +import sys + +from pynvjitlink import patch, NvJitLinkError +from pynvjitlink.patch import (PatchedLinker, patch_numba_linker, + new_patched_linker, required_numba_ver, + _numba_version_ok) +from unittest.mock import patch as mock_patch + + +def test_numba_patching_numba_not_ok(): + with mock_patch.multiple( + patch, + _numba_version_ok=False, + _numba_error=''): + with pytest.raises(RuntimeError, match='Cannot patch Numba: '): + patch_numba_linker() + + +@pytest.mark.skipif( + not _numba_version_ok, + reason=f"Requires Numba == {required_numba_ver[0]}.{required_numba_ver[1]}" +) +def test_numba_patching(): + # We import the linker here rather than at the top level because the import + # may fail if if Numba is not present or an unsupported version. + from numba.cuda.cudadrv.driver import Linker + patch_numba_linker() + assert Linker.new is new_patched_linker + + +def test_create(): + patched_linker = PatchedLinker(cc=(7, 5)) + assert "-arch=sm_75" in patched_linker.options + + +def test_create_no_cc_error(): + # nvJitLink expects at least the architecture to be specified. + with pytest.raises(RuntimeError, + match='PatchedLinker requires CC to be specified'): + PatchedLinker() + + +def test_invalid_arch_error(): + # CC 0.0 is not a valid compute capability + with pytest.raises(NvJitLinkError, + match='NVJITLINK_ERROR_UNRECOGNIZED_OPTION error'): + PatchedLinker(cc=(0, 0)) + + +def test_invalid_cc_type_error(): + with pytest.raises(TypeError, + match='`cc` must be a list or tuple of length 2'): + PatchedLinker(cc=0) + + +@pytest.mark.parametrize('max_registers', (None, 32)) +@pytest.mark.parametrize('lineinfo', (False, True)) +@pytest.mark.parametrize('lto', (False, True)) +@pytest.mark.parametrize('additional_flags', (None, ('-g',), ('-g', '-time'))) +def test_ptx_compile_options(max_registers, lineinfo, lto, additional_flags): + patched_linker = PatchedLinker( + cc=(7, 5), + max_registers=max_registers, + lineinfo=lineinfo, + lto=lto, + additional_flags=additional_flags + ) + + assert "-arch=sm_75" in patched_linker.options + + if max_registers: + assert f"-maxrregcount={max_registers}" in patched_linker.options + else: + assert "-maxrregcount" not in patched_linker.options + + if lineinfo: + assert "-lineinfo" in patched_linker.options + else: + assert "-lineinfo" not in patched_linker.options + + if lto: + assert "-lto" in patched_linker.options + else: + assert "-lto" not in patched_linker.options + + if additional_flags: + for flag in additional_flags: + assert flag in patched_linker.options + + +if __name__ == '__main__': + sys.exit(pytest.main())