Skip to content

Commit

Permalink
[Cuda] Save the Cuda native error code on adapter-specific errors
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeWeb committed Sep 13, 2024
1 parent eb63d1a commit 5feebf6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {

UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
std::ignore = pError;
*pError = ErrorAdapterNativeCode;
*ppMessage = ErrorMessage;
return ErrorMessageCode;
}
Expand Down
23 changes: 17 additions & 6 deletions source/adapters/cuda/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

#include "common.hpp"
#include "logger/ur_logger.hpp"
#include "ur_api.h"

#include <cuda.h>

#include <sstream>
#include <string.h>

ur_result_t mapErrorUR(CUresult Result) {
switch (Result) {
Expand Down Expand Up @@ -105,6 +107,7 @@ void detail::ur::assertion(bool Condition, const char *Message) {
// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
thread_local char ErrorMessage[MaxMessageSize];
thread_local int32_t ErrorAdapterNativeCode = 0;

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
Expand All @@ -114,16 +117,24 @@ thread_local char ErrorMessage[MaxMessageSize];
ErrorMessageCode = ErrorCode;
}

void setPluginSpecificMessage(CUresult cu_res) {
[[maybe_unused]] void setAdapterSpecificMessage(CUresult cu_res) {
ErrorAdapterNativeCode = static_cast<int32_t>(cu_res);
// according to the documentation of the cuGetErrorName and cuGetErrorString
// CUDA driver APIs, both error_name and error_string are null-terminated.
const char *error_string;
const char *error_name;
cuGetErrorName(cu_res, &error_name);
cuGetErrorString(cu_res, &error_string);
char *message = (char *)malloc(strlen(error_string) + strlen(error_name) + 2);
strcpy(message, error_name);
strcat(message, "\n");
strcat(message, error_string);
static constexpr char new_line[] = "\n";
// non-null-terminated sizes
const size_t error_string_size = std::strlen(error_string);
const size_t error_name_size = std::strlen(error_name);
char *message = reinterpret_cast<char *>(
std::malloc(error_string_size + error_name_size + sizeof(new_line)));
std::strcpy(message, error_name);
std::strcat(message, new_line);
std::strncat(message, error_string, error_string_size);

setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
free(message);
std::free(message);
}
3 changes: 2 additions & 1 deletion source/adapters/cuda/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ std::string getCudaVersionString();
constexpr size_t MaxMessageSize = 256;
extern thread_local ur_result_t ErrorMessageCode;
extern thread_local char ErrorMessage[MaxMessageSize];
extern thread_local int32_t ErrorAdapterNativeCode;

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
ur_result_t ErrorCode);

void setPluginSpecificMessage(CUresult cu_res);
void setAdapterSpecificMessage(CUresult cu_res);

/// ------ Error handling, matching OpenCL plugin semantics.
namespace detail {
Expand Down

0 comments on commit 5feebf6

Please sign in to comment.