Skip to content

Commit

Permalink
[Linalg] Fixed reduction RETTYPE locations
Browse files (browse the repository at this point in the history)
  • Loading branch information
dmed256 committed Jul 4, 2018
1 parent e672972 commit 4bee423
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 42 deletions.
32 changes: 16 additions & 16 deletions include/occa/array/kernels/linalg.okl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#define CPU_CHUNK ((entries + CPU_DOT_OUTER - 1) / CPU_DOT_OUTER)
#define CPU_REDUCTION_BODY(INIT, OPERATION, RED_OPERATION) \
for (int oi = 0; oi < CPU_DOT_OUTER; ++oi; @outer) { \
VTYPE r_red = INIT(oi); \
RETTYPE r_red = INIT(oi); \
for (int i = 0; i < CPU_CHUNK; ++i; @inner) { \
if ((oi * CPU_CHUNK + i) < entries) { \
OPERATION(r_red, oi * CPU_CHUNK + i); \
Expand Down Expand Up @@ -125,10 +125,10 @@

#define GPU_REDUCTION_BODY(INIT, OPERATION, RED_OPERATION) \
for (int oi = 0; oi < GPU_DOT_OUTER; ++oi; @outer) { \
@shared VTYPE s_red[GPU_DOT_INNER]; \
@shared RETTYPE s_red[GPU_DOT_INNER]; \
\
for (int i = 0; i < GPU_DOT_INNER; ++i; @inner) { \
VTYPE r_red = INIT(oi); \
RETTYPE r_red = INIT(oi); \
for (int j = (oi*GPU_DOT_INNER + i); j < entries; j += GPU_DOT_BLOCK) { \
OPERATION(r_red, j); \
} \
Expand Down Expand Up @@ -166,17 +166,17 @@
red += part

#define MAX_RED_OPERATION(red, part) \
const VTYPE r_red2 = red; \
const VTYPE r_part2 = part; \
const RETTYPE r_red2 = red; \
const RETTYPE r_part2 = part; \
red = r_red2 > r_part2 ? r_red2 : r_part2

#define MIN_RED_OPERATION(red, part) \
const VTYPE r_red2 = red; \
const VTYPE r_part2 = part; \
const RETTYPE r_red2 = red; \
const RETTYPE r_part2 = part; \
red = r_red2 < r_part2 ? r_red2 : r_part2

@kernel void l1Norm(const int entries,
const VTYPE * vec,
const VTYPE * vec,
RETTYPE * vecReduction) {
#define L1_NORM_OPERATION(out, idx) \
out += ABS_FUNC((VTYPE) vec[idx])
Expand All @@ -187,7 +187,7 @@
const VTYPE * vec,
RETTYPE * vecReduction) {
#define L2_NORM_OPERATION(out, idx) \
const VTYPE vec_i = vec[idx]; \
const RETTYPE vec_i = vec[idx]; \
out += vec_i * vec_i;
REDUCTION_BODY(INIT_ZERO, L2_NORM_OPERATION, SUM_RED_OPERATION);
}
Expand All @@ -197,18 +197,18 @@
const VTYPE * vec,
RETTYPE * vecReduction) {
#define LP_NORM_OPERATION(out, idx) \
const VTYPE vec_i = vec[idx]; \
const RETTYPE vec_i = vec[idx]; \
out += pow((VTYPE) vec_i, (VTYPE) p)
REDUCTION_BODY(INIT_ZERO, LP_NORM_OPERATION, SUM_RED_OPERATION);
}

@kernel void lInfNorm(const int entries,
const VTYPE * vec,
RETTYPE * vecReduction) {
#define LINF_NORM_OPERATION(out, idx) \
const VTYPE vec_i = ABS_FUNC((VTYPE) vec[idx]); \
if (out < vec_i) { \
out = vec_i; \
#define LINF_NORM_OPERATION(out, idx) \
const RETTYPE vec_i = ABS_FUNC((VTYPE) vec[idx]); \
if (out < vec_i) { \
out = vec_i; \
}

REDUCTION_BODY(INIT_ABS_FIRST, LINF_NORM_OPERATION, MAX_RED_OPERATION);
Expand All @@ -218,7 +218,7 @@
const VTYPE * vec,
RETTYPE * vecReduction) {
#define MAX_OPERATION(out, idx) \
const VTYPE vec_i = vec[idx]; \
const RETTYPE vec_i = vec[idx]; \
out = out > vec_i ? out : vec_i
REDUCTION_BODY(INIT_FIRST, MAX_OPERATION, MAX_RED_OPERATION);
}
Expand All @@ -227,7 +227,7 @@
const VTYPE * vec,
RETTYPE * vecReduction) {
#define MIN_OPERATION(out, idx) \
const VTYPE vec_i = vec[idx]; \
const RETTYPE vec_i = vec[idx]; \
out = out < vec_i ? out : vec_i
REDUCTION_BODY(INIT_FIRST, MIN_OPERATION, MIN_RED_OPERATION);
}
Expand Down
25 changes: 11 additions & 14 deletions include/occa/array/linalg.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,24 +194,13 @@ namespace occa {
//---[ Linear Algebra ]-------------
template <class TM>
TM *hostReductionBuffer(const int size) {
std::map<int, TM*> &bufferMap = hostBufferMap<TM>();
TM *&buffer = bufferMap[size];
if (!buffer) {
buffer = new TM[size];
}
return buffer;
return new TM[size];
}

template <class TM>
occa::memory deviceReductionBuffer(occa::device device,
const int size) {

hashedMemoryMap &bufferMap = deviceBufferMap<TM>();
occa::memory &buffer = bufferMap[hash(device) ^ size];
if (!buffer.isInitialized()) {
buffer = device.malloc(size * sizeof(TM));
}
return buffer;
return device.malloc(size * sizeof(TM));
}

template <class VTYPE, class RETTYPE>
Expand All @@ -226,7 +215,6 @@ namespace occa {
builder.build(dev)(entries,
vec,
deviceBuffer);
dev.finish();
deviceBuffer.copyTo(hostBuffer);
return hostBuffer;
}
Expand All @@ -241,6 +229,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += partialReduction[i];
}
delete partialReduction;
return ret;
}

Expand All @@ -254,6 +243,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += partialReduction[i];
}
delete partialReduction;
return sqrt(ret);
}

Expand All @@ -279,6 +269,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += hostBuffer[i];
}
delete hostBuffer;
return pow(ret, 1.0/p);
}

Expand All @@ -295,6 +286,7 @@ namespace occa {
ret = abs_i;
}
}
delete partialReduction;
return ret;
}

Expand All @@ -310,6 +302,7 @@ namespace occa {
ret = partialReduction[i];
}
}
delete partialReduction;
return ret;
}

Expand All @@ -325,6 +318,7 @@ namespace occa {
ret = partialReduction[i];
}
}
delete partialReduction;
return ret;
}

Expand All @@ -351,6 +345,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += hostBuffer[i];
}
delete hostBuffer;
return ret;
}

Expand All @@ -377,6 +372,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += hostBuffer[i];
}
delete hostBuffer;
return sqrt(ret);
}

Expand All @@ -390,6 +386,7 @@ namespace occa {
for (int i = 0; i < 1024; ++i) {
ret += partialReduction[i];
}
delete partialReduction;
return ret;
}

Expand Down
24 changes: 12 additions & 12 deletions src/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,18 @@ namespace occa {
for (int i = 0; i < argCount; ++i) {
occa::memory_v *mHandle = args[i].mHandle;

if (mHandle &&
mHandle->isManaged() &&
mHandle->dHandle->hasSeparateMemorySpace()) {

if (!mHandle->inDevice()) {
mHandle->copyFrom(mHandle->uvaPtr, mHandle->size);
mHandle->memInfo |= uvaFlag::inDevice;
}
if (!isConst && !mHandle->isStale()) {
uvaStaleMemory.push_back(mHandle);
mHandle->memInfo |= uvaFlag::isStale;
}
if (!mHandle ||
!mHandle->isManaged() ||
!mHandle->dHandle->hasSeparateMemorySpace()) {
continue;
}
if (!mHandle->inDevice()) {
mHandle->copyFrom(mHandle->uvaPtr, mHandle->size);
mHandle->memInfo |= uvaFlag::inDevice;
}
if (!isConst && !mHandle->isStale()) {
uvaStaleMemory.push_back(mHandle);
mHandle->memInfo |= uvaFlag::isStale;
}
}
}
Expand Down

0 comments on commit 4bee423

Please sign in to comment.