Skip to content

Commit

Permalink
cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Nov 26, 2023
1 parent a9315e4 commit 200da97
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
13 changes: 8 additions & 5 deletions test/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List
import unittest, os, ctypes

c_char_p_p = ctypes.POINTER(ctypes.POINTER(ctypes.c_char))
def to_char_p_p(options: List[str]):
c_options = (ctypes.POINTER(ctypes.c_char) * len(options))()
c_options[:] = [ctypes.cast(ctypes.create_string_buffer(o.encode("utf-8")), ctypes.POINTER(ctypes.c_char)) for o in options]
Expand All @@ -15,8 +14,12 @@ def wrapper(func):
CI = os.getenv("CI", "") != ""

def get_bytes(arg, get_sz, get_str, check) -> bytes:
    """Fetch a variable-length byte blob through a size-then-data getter pair.

    Args:
        arg: Handle forwarded as the first argument to both getters.
        get_sz: Callable ``get_sz(arg, size_ref)`` that writes the byte count
            into a ``ctypes.c_size_t`` passed by reference and returns a
            status code.
        get_str: Callable ``get_str(arg, buffer)`` that fills ``buffer`` and
            returns a status code.
        check: Callable invoked with every status code returned by the
            getters; it decides whether a status is an error.

    Returns:
        Exactly the reported number of bytes, copied out of the buffer
        (embedded NUL bytes are preserved because the size is explicit).
    """
    # Query the size once, then fetch the payload once into a buffer of
    # exactly that size.
    size = ctypes.c_size_t()
    check(get_sz(arg, ctypes.byref(size)))
    buffer = ctypes.create_string_buffer(size.value)
    check(get_str(arg, buffer))
    return ctypes.string_at(buffer, size=size.value)

def compile(prg, options, f, check):
    """Compile source text ``prg`` through the runtime-compiler bindings in ``f``.

    Args:
        prg: Program source as a ``str``; it is encoded before being passed on.
        options: List of compiler option strings.
        f: Namespace exposing ``new``, ``create``, ``compile``, ``getLogSize``,
            ``getLog``, ``getCodeSize`` and ``getCode``.
        check: Callable invoked with the status codes of the create call and
            of the log/code getters.

    Returns:
        The compiled code as ``bytes``.

    Raises:
        RuntimeError: If ``f.compile`` returns a non-zero status; the message
            carries the compiler log.
    """
    program = f.new()
    check(f.create(ctypes.pointer(program), prg.encode(), "<null>".encode(), 0, None, None))

    c_options = to_char_p_p(options)
    if f.compile(program, len(options), c_options) != 0:
        log = get_bytes(program, f.getLogSize, f.getLog, check)
        raise RuntimeError(f"compile failed: {log}")

    return get_bytes(program, f.getCodeSize, f.getCode, check)
23 changes: 11 additions & 12 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import ctypes
import unittest
import gpuctypes.cuda as cuda
from helpers import to_char_p_p, CI, get_bytes
from helpers import CI, compile

def check(status):
    """Raise ``RuntimeError`` when ``status`` is non-zero.

    The error text is looked up with ``cuda.cuGetErrorString``. If that lookup
    itself fails (or yields a NULL string), a generic message is used so the
    exception still reports the ORIGINAL ``status``.

    Args:
        status: Integer status code returned by a CUDA call.

    Raises:
        RuntimeError: If ``status`` is not ``0``.
    """
    if status == 0:
        return

    error = ctypes.POINTER(ctypes.c_char)()
    # Do not route the lookup result back through check(): a failed lookup
    # would then raise for the lookup's own status and hide the status that
    # is actually being reported.
    lookup_status = cuda.cuGetErrorString(status, ctypes.byref(error))
    if lookup_status == 0 and error:
        message = ctypes.string_at(error).decode()
    else:
        message = "unknown error"
    raise RuntimeError(f"CUDA Error {status}, {message}")

def _test_compile(prg):
    """Compile ``prg`` with NVRTC for ``sm_35`` and return the CUBIN bytes.

    Raises:
        RuntimeError: If ``nvrtcCompileProgram`` returns a non-zero status;
            the message carries the program log.
    """
    program = cuda.nvrtcProgram()
    check(cuda.nvrtcCreateProgram(ctypes.pointer(program), prg.encode(), "<null>".encode(), 0, None, None))

    options = ["--gpu-architecture=sm_35"]
    c_options = to_char_p_p(options)
    if cuda.nvrtcCompileProgram(program, len(options), c_options) == 0:
        return get_bytes(program, cuda.nvrtcGetCUBINSize, cuda.nvrtcGetCUBIN, check)

    log = get_bytes(program, cuda.nvrtcGetProgramLogSize, cuda.nvrtcGetProgramLog, check)
    raise RuntimeError(f"CUDA compile failed: {log}")
class CUDACompile:
    """NVRTC entry points under the generic names the ``compile`` helper expects."""

    # Program handle type and lifecycle calls.
    new = cuda.nvrtcProgram
    create = cuda.nvrtcCreateProgram
    compile = cuda.nvrtcCompileProgram

    # Size/data getter pair for the compiler log.
    getLogSize = cuda.nvrtcGetProgramLogSize
    getLog = cuda.nvrtcGetProgramLog

    # Size/data getter pair for the compiled output (CUBIN).
    getCodeSize = cuda.nvrtcGetCUBINSize
    getCode = cuda.nvrtcGetCUBIN

class TestCUDA(unittest.TestCase):
def test_has_methods(self):
Expand All @@ -27,10 +26,10 @@ def test_has_methods(self):

def test_compile_fail(self):
    """Malformed source must make the shared ``compile`` helper raise."""
    # Only one call belongs inside assertRaises: anything after the first
    # raising statement would never execute.
    with self.assertRaises(RuntimeError):
        compile("__device__ void test() { {", ["--gpu-architecture=sm_35"], CUDACompile, check)

def test_compile(self):
    """Valid source compiles to a non-trivial binary blob."""
    prg = compile("__device__ int test() { return 42; }", ["--gpu-architecture=sm_35"], CUDACompile, check)
    # assertGreater (unlike a bare assert) is not stripped under ``python -O``.
    self.assertGreater(len(prg), 10)

@unittest.skipIf(CI, "cuda doesn't work in CI")
Expand Down
23 changes: 11 additions & 12 deletions test/test_hip.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import unittest
import ctypes
import gpuctypes.hip as hip
from helpers import to_char_p_p, CI, expectedFailureIf, get_bytes
from helpers import CI, expectedFailureIf, compile

def check(status):
    """Raise ``RuntimeError`` with the HIP error string when ``status`` is non-zero.

    Args:
        status: Integer status code returned by a HIP call.

    Raises:
        RuntimeError: If ``status`` is not ``0``.
    """
    if status == 0:
        return
    message = ctypes.string_at(hip.hipGetErrorString(status)).decode()
    raise RuntimeError(f"HIP Error {status}, {message}")

def _test_compile(prg):
    """Compile ``prg`` with HIPRTC for ``gfx1100`` and return the code bytes.

    Raises:
        RuntimeError: If ``hiprtcCompileProgram`` returns a non-zero status;
            the message carries the program log.
    """
    program = hip.hiprtcProgram()
    check(hip.hiprtcCreateProgram(ctypes.pointer(program), prg.encode(), "<null>".encode(), 0, None, None))

    options = ["--offload-arch=gfx1100"]
    c_options = to_char_p_p(options)
    if hip.hiprtcCompileProgram(program, len(options), c_options) == 0:
        return get_bytes(program, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)

    log = get_bytes(program, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check)
    raise RuntimeError(f"HIP compile failed: {log}")
class HIPCompile:
    """HIPRTC entry points under the generic names the ``compile`` helper expects."""

    # Program handle type and lifecycle calls.
    new = hip.hiprtcProgram
    create = hip.hiprtcCreateProgram
    compile = hip.hiprtcCompileProgram

    # Size/data getter pair for the compiler log.
    getLogSize = hip.hiprtcGetProgramLogSize
    getLog = hip.hiprtcGetProgramLog

    # Size/data getter pair for the compiled code.
    getCodeSize = hip.hiprtcGetCodeSize
    getCode = hip.hiprtcGetCode

class TestHIP(unittest.TestCase):
def test_has_methods(self):
Expand All @@ -24,10 +23,10 @@ def test_has_methods(self):

def test_compile_fail(self):
    """Malformed source must make the shared ``compile`` helper raise."""
    # Only one call belongs inside assertRaises: anything after the first
    # raising statement would never execute.
    with self.assertRaises(RuntimeError):
        compile("void test() { {", ["--offload-arch=gfx1100"], HIPCompile, check)

def test_compile(self):
    """Valid source compiles to a non-trivial binary blob."""
    prg = compile("int test() { return 42; }", ["--offload-arch=gfx1100"], HIPCompile, check)
    # assertGreater (unlike a bare assert) is not stripped under ``python -O``.
    self.assertGreater(len(prg), 10)

class TestHIPDevice(unittest.TestCase):
Expand Down

0 comments on commit 200da97

Please sign in to comment.