Skip to content

Commit

Permalink
cuda shuffle prototype 1
Browse files Browse the repository at this point in the history
  • Loading branch information
raver119 committed Aug 4, 2016
1 parent 5e7eade commit 6c91cc6
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 250 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,15 @@ public interface NDArrayFactory {
*/
INDArray pullRows(INDArray source, int sourceDimension, int[] indexes);

/**
* In place shuffle of an ndarray
* along a specified set of dimensions
* @param array the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
void shuffle(INDArray array, int... dimension);

/**
* This method averages input arrays, and returns averaged array
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ public static void shuffle(INDArray toShuffle,Random random,int...dimension) {
* @return
*/
public static void shuffle(INDArray toShuffle,int...dimension) {
shuffle(toShuffle, new Random(), dimension);
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, dimension);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1183,4 +1183,12 @@ public native void concatHalf(
///////////////////////

public native void enableP2P(boolean reallyEnable);

//

public native void shuffleDouble(PointerPointer extraPointers, Pointer x, Pointer xShapeInfo, Pointer z, Pointer zShapeInfo, Pointer shuffleMap, Pointer tadShapeInfo, Pointer tadOffsets);

public native void shuffleFloat(PointerPointer extraPointers, Pointer x, Pointer xShapeInfo, Pointer z, Pointer zShapeInfo, Pointer shuffleMap, Pointer tadShapeInfo, Pointer tadOffsets);

public native void shuffleHalf(PointerPointer extraPointers, Pointer x, Pointer xShapeInfo, Pointer z, Pointer zShapeInfo, Pointer shuffleMap, Pointer tadShapeInfo, Pointer tadOffsets);
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.complex.ComplexFloat;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;
Expand Down Expand Up @@ -847,4 +848,92 @@ public INDArray average(INDArray[] arrays) {
public INDArray average(INDArray target, Collection<INDArray> arrays) {
return average(target, arrays.toArray(new INDArray[0]));
}

/**
* In place shuffle of an ndarray
* along a specified set of dimensions
*
* @param array the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
@Override
public void shuffle(INDArray array, int... dimension) {
// no dimension - no shuffle
if (dimension == null || dimension.length == 0)
return;

// first we build TAD for input array and dimensions

AtomicAllocator allocator = AtomicAllocator.getInstance();

CudaContext context = allocator.getFlowController().prepareAction(array);

PointerPointer extras = new PointerPointer(
null, // not used
context.getOldStream(),
allocator.getDeviceIdPointer()
);

Pointer x = AtomicAllocator.getInstance().getPointer(array, context);
Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer(), context);


TADManager tadManager = ((JCudaExecutioner) Nd4j.getExecutioner()).getTadManager();

Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(array, dimension);

Pointer tadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

DataBuffer offsets = tadBuffers.getSecond();
Pointer tadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);

// FIXME: temporal map impl
int[] map = new int[10];
for (int i = 0; i < 10; i++) {
System.out.println("ShuffleMap: I: "+ i + ", V: "+ (9 - i));
map[i] = 9 - i;
}
CudaIntDataBuffer shuffle = new CudaIntDataBuffer(map);

Pointer shuffleMap = allocator.getPointer(shuffle, context);

if (array.data().dataType() == DataBuffer.Type.DOUBLE) {
nativeOps.shuffleDouble(
extras,
x,
xShapeInfo,
x,
xShapeInfo,
shuffleMap,
tadShapeInfo,
tadOffsets
);
} else if (array.data().dataType() == DataBuffer.Type.FLOAT) {
nativeOps.shuffleFloat(
extras,
x,
xShapeInfo,
x,
xShapeInfo,
shuffleMap,
tadShapeInfo,
tadOffsets
);
} else {
// HALFs
nativeOps.shuffleHalf(
extras,
x,
xShapeInfo,
x,
xShapeInfo,
shuffleMap,
tadShapeInfo,
tadOffsets
);
}

allocator.getFlowController().registerAction(context, array);
}
}
Loading

0 comments on commit 6c91cc6

Please sign in to comment.