Skip to content

Commit

Permalink
i love the walrus
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Nov 26, 2023
1 parent 200da97 commit dd0a06b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
8 changes: 4 additions & 4 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def wrapper(func):
CI = os.getenv("CI", "") != ""

def get_bytes(arg, get_sz, get_str, check) -> bytes:
check(get_sz(arg, ctypes.byref((sz := ctypes.c_size_t()))))
check(get_str(arg, (mstr := ctypes.create_string_buffer(sz.value))))
check(get_sz(arg, ctypes.byref(sz := ctypes.c_size_t())))
check(get_str(arg, mstr := ctypes.create_string_buffer(sz.value)))
return ctypes.string_at(mstr, size=sz.value)

def compile(prg, options, f, check):
check(f.create(ctypes.pointer((prog := f.new())), prg.encode(), "<null>".encode(), 0, None, None))
def cuda_compile(prg, options, f, check):
check(f.create(ctypes.pointer(prog := f.new()), prg.encode(), "<null>".encode(), 0, None, None))
status = f.compile(prog, len(options), to_char_p_p(options))
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, f.getLogSize, f.getLog, check)}")
return get_bytes(prog, f.getCodeSize, f.getCode, check)
12 changes: 5 additions & 7 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
import unittest
import gpuctypes.cuda as cuda
from helpers import CI, compile
from helpers import CI, cuda_compile

def check(status):
if status != 0:
Expand All @@ -26,10 +26,10 @@ def test_has_methods(self):

def test_compile_fail(self):
with self.assertRaises(RuntimeError):
compile("__device__ void test() { {", ["--gpu-architecture=sm_35"], CUDACompile, check)
cuda_compile("__device__ void test() { {", ["--gpu-architecture=sm_35"], CUDACompile, check)

def test_compile(self):
prg = compile("__device__ int test() { return 42; }", ["--gpu-architecture=sm_35"], CUDACompile, check)
prg = cuda_compile("__device__ int test() { return 42; }", ["--gpu-architecture=sm_35"], CUDACompile, check)
assert len(prg) > 10

@unittest.skipIf(CI, "cuda doesn't work in CI")
Expand All @@ -44,14 +44,12 @@ def setUpClass(cls):

# NOTE: this requires cuInit, so it doesn't run in CI
def test_device_count(self):
count = ctypes.c_int()
check(cuda.cuDeviceGetCount(ctypes.byref(count)))
check(cuda.cuDeviceGetCount(ctypes.byref(count := ctypes.c_int())))
print(f"got {count.value} devices")
assert count.value > 0

def test_malloc(self):
ptr = ctypes.c_ulong()
check(cuda.cuMemAlloc_v2(ctypes.byref(ptr), 16))
check(cuda.cuMemAlloc_v2(ctypes.byref(ptr := ctypes.c_ulong()), 16))
assert ptr.value != 0
print(ptr.value)

Expand Down
12 changes: 9 additions & 3 deletions test/test_hip.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import ctypes
import gpuctypes.hip as hip
from helpers import CI, expectedFailureIf, compile
from helpers import CI, expectedFailureIf, cuda_compile

def check(status):
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
Expand All @@ -23,10 +23,10 @@ def test_has_methods(self):

def test_compile_fail(self):
with self.assertRaises(RuntimeError):
compile("void test() { {", ["--offload-arch=gfx1100"], HIPCompile, check)
cuda_compile("void test() { {", ["--offload-arch=gfx1100"], HIPCompile, check)

def test_compile(self):
prg = compile("int test() { return 42; }", ["--offload-arch=gfx1100"], HIPCompile, check)
prg = cuda_compile("int test() { return 42; }", ["--offload-arch=gfx1100"], HIPCompile, check)
assert len(prg) > 10

class TestHIPDevice(unittest.TestCase):
Expand All @@ -37,6 +37,12 @@ def test_malloc(self):
assert ptr.value != 0
check(hip.hipFree(ptr))

@expectedFailureIf(CI)
def test_device_count(self):
check(hip.hipGetDeviceCount(ctypes.byref(count := ctypes.c_int())))
print(f"got {count.value} devices")
assert count.value > 0

@expectedFailureIf(CI)
def test_get_device_properties(self) -> hip.hipDeviceProp_t:
device_properties = hip.hipDeviceProp_t()
Expand Down

0 comments on commit dd0a06b

Please sign in to comment.