Skip to content

Commit 9895dde

Browse files
committed
Issue #4612 - add checks for error return types
cudart.cudaMemcpyAsync and cudart.cudaStreamSynchronize return a tuple which leads to AttributeError as cudart.cudaGetErrorString(err) expects a cudaError_t. This is solved by adding type check before passing it to the raise function. Signed-off-by: Gr0ly <[email protected]>
1 parent a833f79 commit 9895dde

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

quickstart/IntroNotebooks/onnx_helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def predict(self, batch): # result gets copied into output
8686
err = cudart.cudaMemcpyAsync(
8787
self.d_input, batch.ctypes.data, batch.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream
8888
)
89+
err = err[0] if isinstance(err, tuple) else err
8990
if err != cudart.cudaError_t.cudaSuccess:
9091
raise RuntimeError(f"Failed to copy input to device: {cudart.cudaGetErrorString(err)}")
9192

@@ -100,11 +101,13 @@ def predict(self, batch): # result gets copied into output
100101
cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
101102
self.stream,
102103
)
104+
err = err[0] if isinstance(err, tuple) else err
103105
if err != cudart.cudaError_t.cudaSuccess:
104106
raise RuntimeError(f"Failed to copy output from device: {cudart.cudaGetErrorString(err)}")
105107

106108
# synchronize threads
107109
err = cudart.cudaStreamSynchronize(self.stream)
110+
err = err[0] if isinstance(err, tuple) else err
108111
if err != cudart.cudaError_t.cudaSuccess:
109112
raise RuntimeError(f"Failed to synchronize stream: {cudart.cudaGetErrorString(err)}")
110113

0 commit comments

Comments
 (0)