Skip to content

Commit

Permalink
Fixes corruption of memory datatypes caused by short circuit logic in…
Browse files Browse the repository at this point in the history
… `modeMemory_t::slice` (#727)
  • Loading branch information
kris-rowe authored Dec 8, 2023
1 parent 3612b5d commit 5226290
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
4 changes: 0 additions & 4 deletions src/occa/internal/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ namespace occa {

modeMemory_t* modeMemory_t::slice(const dim_t offset_,
const udim_t bytes) {

//quick return if we're not really slicing
if ((offset_ == 0) && (bytes == size)) return this;

OCCA_ERROR("ModeMemory not initialized or has been freed",
modeBuffer != NULL);

Expand Down
17 changes: 17 additions & 0 deletions tests/src/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
void testMalloc();
void testSlice();
void testUnwrap();
void testCast();

int main(const int argc, const char **argv) {
testMalloc();
testSlice();
testUnwrap();
testCast();

return 0;
}
Expand Down Expand Up @@ -154,3 +156,18 @@ void testUnwrap() {

delete[] host_memory;
}

void testCast() {
occa::device occa_device({{"mode", "Serial"}});

occa::memory occa_memory = occa_device.malloc<double>(10);

ASSERT_TRUE(occa::dtype::double_ == occa_memory.dtype());

occa::memory casted_memory = occa_memory.cast(occa::dtype::byte);

ASSERT_TRUE(occa::dtype::double_ == occa_memory.dtype());
ASSERT_TRUE(occa::dtype::byte == casted_memory.dtype());

ASSERT_EQ(occa_memory.size(), casted_memory.size());
}

0 comments on commit 5226290

Please sign in to comment.