Skip to content

Commit

Permalink
[C] Added magic header to occaType for kernel launch safety
Browse files Browse the repository at this point in the history
  • Loading branch information
dmed256 committed Mar 25, 2018
1 parent e2a352b commit 9b41254
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/occa/c/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@
# endif
#endif

#define OCCA_C_TYPE_MAGIC_HEADER 0x5514E455

#endif
1 change: 1 addition & 0 deletions include/occa/c/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ typedef struct {
} occaDim;

typedef struct {
int magicHeader;
int type;
occaUDim_t bytes;

Expand Down
3 changes: 3 additions & 0 deletions src/c/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ void OCCA_RFUNC occaKernelRunN(occaKernel kernel,
occa::kernelArg kArg;

occaType arg = va_arg(args, occaType);
OCCA_ERROR("A non-occaType argument was passed",
arg.magicHeader == OCCA_C_TYPE_MAGIC_HEADER);

switch (arg.type) {
case occa::c::typeType::none: {
kArg.add(NULL, false, false); break;
Expand Down
26 changes: 22 additions & 4 deletions src/c/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@
namespace occa {
namespace c {
occaType defaultOccaType() {
occaType type;
type.type = occa::c::typeType::none;
type.value.ptr = NULL;
return type;
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = occa::c::typeType::none;
oType.value.ptr = NULL;
return oType;
}

template <>
occaType newOccaType(const bool &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::bool_;
oType.bytes = sizeof(int8_t);
oType.value.int8_ = value;
Expand All @@ -43,6 +45,7 @@ namespace occa {
template <>
occaType newOccaType(const int8_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::int8_;
oType.bytes = sizeof(int8_t);
oType.value.int8_ = value;
Expand All @@ -52,6 +55,7 @@ namespace occa {
template <>
occaType newOccaType(const uint8_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::uint8_;
oType.bytes = sizeof(uint8_t);
oType.value.uint8_ = value;
Expand All @@ -61,6 +65,7 @@ namespace occa {
template <>
occaType newOccaType(const int16_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::int16_;
oType.bytes = sizeof(int16_t);
oType.value.int16_ = value;
Expand All @@ -70,6 +75,7 @@ namespace occa {
template <>
occaType newOccaType(const uint16_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::uint16_;
oType.bytes = sizeof(uint16_t);
oType.value.uint16_ = value;
Expand All @@ -79,6 +85,7 @@ namespace occa {
template <>
occaType newOccaType(const int32_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::int32_;
oType.bytes = sizeof(int32_t);
oType.value.int32_ = value;
Expand All @@ -88,6 +95,7 @@ namespace occa {
template <>
occaType newOccaType(const uint32_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::uint32_;
oType.bytes = sizeof(uint32_t);
oType.value.uint32_ = value;
Expand All @@ -97,6 +105,7 @@ namespace occa {
template <>
occaType newOccaType(const int64_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::int64_;
oType.bytes = sizeof(int64_t);
oType.value.int64_ = value;
Expand All @@ -106,6 +115,7 @@ namespace occa {
template <>
occaType newOccaType(const uint64_t &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::uint64_;
oType.bytes = sizeof(uint64_t);
oType.value.uint64_ = value;
Expand All @@ -115,6 +125,7 @@ namespace occa {
template <>
occaType newOccaType(const float &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::float_;
oType.bytes = sizeof(float);
oType.value.float_ = value;
Expand All @@ -124,6 +135,7 @@ namespace occa {
template <>
occaType newOccaType(const double &value) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::double_;
oType.bytes = sizeof(double);
oType.value.double_ = value;
Expand All @@ -132,6 +144,7 @@ namespace occa {

occaType newOccaType(occa::device device) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::device;
oType.bytes = sizeof(void*);
oType.value.ptr = (char*) device.getDHandle();
Expand All @@ -140,6 +153,7 @@ namespace occa {

occaType newOccaType(occa::kernel kernel) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::kernel;
oType.bytes = sizeof(void*);
oType.value.ptr = (char*) kernel.getKHandle();
Expand All @@ -148,6 +162,7 @@ namespace occa {

occaType newOccaType(occa::memory memory) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::memory;
oType.bytes = sizeof(void*);
oType.value.ptr = (char*) memory.getMHandle();
Expand All @@ -156,6 +171,7 @@ namespace occa {

occaType newOccaType(occa::properties &properties) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = typeType::properties;
oType.bytes = sizeof(void*);
oType.value.ptr = (char*) &properties;
Expand Down Expand Up @@ -335,6 +351,7 @@ OCCA_LFUNC occaType OCCA_RFUNC occaDouble(double value) {
OCCA_LFUNC occaType OCCA_RFUNC occaStruct(void *value,
occaUDim_t bytes) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = occa::c::typeType::struct_;
oType.bytes = bytes;
oType.value.ptr = (char*) value;
Expand All @@ -343,6 +360,7 @@ OCCA_LFUNC occaType OCCA_RFUNC occaStruct(void *value,

OCCA_LFUNC occaType OCCA_RFUNC occaString(const char *str) {
occaType oType;
oType.magicHeader = OCCA_C_TYPE_MAGIC_HEADER;
oType.type = occa::c::typeType::string;
oType.bytes = strlen(str);
oType.value.ptr = const_cast<char*>(str);
Expand Down

0 comments on commit 9b41254

Please sign in to comment.