Skip to content

Commit

Permalink
test cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Nov 26, 2023
1 parent f7f425c commit 4be6a99
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
5 changes: 4 additions & 1 deletion gpuctypes/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def char_pointer_cast(string, encoding='utf-8'):

_libraries = {}
_libraries['libcuda.so'] = ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libcuda.so')
_libraries['libnvrtc.so'] = ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libnvrtc.so')
try:
_libraries['libnvrtc.so'] = ctypes.CDLL('/usr/lib/x86_64-linux-gnu/libnvrtc.so')
except OSError:
_libraries['libnvrtc.so'] = ctypes.CDLL('/usr/local/cuda/targets/x86_64-linux/lib/libnvrtc.so')


cuuint32_t = ctypes.c_uint32
Expand Down
28 changes: 27 additions & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
import ctypes
import unittest
import gpuctypes.cuda as cuda

def check(status):
if status != 0:
error = ctypes.POINTER(ctypes.c_char)()
check(cuda.cuGetErrorString(status, ctypes.byref(error)))
raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(error).decode()}")

class TestCUDA(unittest.TestCase):
pass
@classmethod
def setUpClass(cls):
check(cuda.cuInit(0))
cls.device = cuda.CUdevice()
check(cuda.cuDeviceGet(ctypes.byref(cls.device), 0))
cls.context = cuda.CUcontext()
check(cuda.cuCtxCreate_v2(ctypes.byref(cls.context), 0, cls.device))

def test_device_count(self):
count = ctypes.c_int()
check(cuda.cuDeviceGetCount(ctypes.byref(count)))
print(f"got {count.value} devices")
assert count.value > 0

def test_malloc(self):
ptr = ctypes.c_ulong()
check(cuda.cuMemAlloc_v2(ctypes.byref(ptr), 16))
assert ptr.value != 0
print(ptr.value)

if __name__ == '__main__':
unittest.main()

0 comments on commit 4be6a99

Please sign in to comment.