Skip to content

Commit

Permalink
more tests for multidim symmetric shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
raver119 committed Aug 5, 2016
1 parent 4865709 commit 05f9812
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,27 @@ public void shuffle(long seed) {
Nd4j.shuffle(Arrays.asList(getFeatures(), getLabels()), ArrayUtil.range(1,getFeatures().rank()));
}

List<INDArray> arrays = new ArrayList<>();
List<int[]> dimensions = new ArrayList<>();

arrays.add(getFeatures());
dimensions.add(ArrayUtil.range(1,getFeatures().rank()));

arrays.add(getLabels());
dimensions.add(ArrayUtil.range(1,getLabels().rank()));

if (featuresMask != null) {
arrays.add(getFeaturesMaskArray());
dimensions.add(ArrayUtil.range(1,getFeaturesMaskArray().rank()));
}

if (featuresMask != null) {
arrays.add(getLabelsMaskArray());
dimensions.add(ArrayUtil.range(1,getLabelsMaskArray().rank()));
}

Nd4j.shuffle(arrays, dimensions);

/*
int[] nonzeroDimsFeat = ArrayUtil.range(1,getFeatures().rank());
int[] nonzeroDimsLab = ArrayUtil.range(1,getLabels().rank());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,24 @@ public interface NDArrayFactory {

/**
* Symmetric in place shuffle of an ndarray
* along a specified set of dimensions. All arrays
* along a specified set of dimensions. All arrays should have equal shapes.
*
* @param array the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
void shuffle(Collection<INDArray> array, int... dimension);

/**
* Symmetric in place shuffle of an ndarray
* along a specified set of dimensions. Each array in list should have it's own dimension at the same index of dimensions array
*
* @param array the ndarray to shuffle
* @param dimensions the dimensions to do the shuffle
* @return
*/
void shuffle(List<INDArray> array, List<int[]> dimensions);

/**
* This method averages input arrays, and returns averaged array
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,21 @@ public static void shuffle(Collection<INDArray> toShuffle,int...dimension) {
INSTANCE.shuffle(toShuffle, dimension);
}

/**
* Symmetric in place shuffle of an ndarray
* along a variable dimensions
*
* @param toShuffle the ndarray to shuffle
* @param dimensions the dimension to do the shuffle. Please note - order matters here.
* @return
*/
public static void shuffle(List<INDArray> toShuffle, List<int[]> dimensions) {

INSTANCE.shuffle(toShuffle, dimensions);
}



/**
* The reference queue used for cleaning up
* ndarrays
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,16 +865,16 @@ public void shuffle(INDArray array, int... dimension) {

/**
* Symmetric in place shuffle of an ndarray
* along a specified set of dimensions. All arrays
* along a specified set of dimensions. Each array in list should have it's own dimension at the same index of dimensions array
*
* @param sourceArrays the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @param sourceArrays the ndarray to shuffle
* @param dimensions the dimensions to do the shuffle
* @return
*/
@Override
public void shuffle(Collection<INDArray> sourceArrays, int... dimension) {
public void shuffle(List<INDArray> sourceArrays, List<int[]> dimensions) {
// no dimension - no shuffle
if (dimension == null || dimension.length == 0)
if (dimensions == null || dimensions.size() == 0)
throw new RuntimeException("Dimension can't be null or 0-length");

if (sourceArrays == null || sourceArrays.size() ==0)
Expand All @@ -893,8 +893,8 @@ public void shuffle(Collection<INDArray> sourceArrays, int... dimension) {
}

int tadLength = 1;
for (int i = 0; i < dimension.length; i++) {
tadLength *= arrays.get(0).shape()[dimension[i]];
for (int i = 0; i < dimensions.get(0).length; i++) {
tadLength *= arrays.get(0).shape()[dimensions.get(0)[i]];
}

int numTads = arrays.get(0).length() / tadLength;
Expand Down Expand Up @@ -926,6 +926,8 @@ public void shuffle(Collection<INDArray> sourceArrays, int... dimension) {

TADManager tadManager = ((JCudaExecutioner) Nd4j.getExecutioner()).getTadManager();

int[] dimension = dimensions.get(i);

Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(array, dimension);

Pointer tadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);
Expand Down Expand Up @@ -1004,4 +1006,17 @@ public void shuffle(Collection<INDArray> sourceArrays, int... dimension) {
tempOffsets.dataType();
tempTAD.dataType();
}

/**
* Symmetric in place shuffle of an ndarray
* along a specified set of dimensions. All arrays
*
* @param sourceArrays the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
@Override
public void shuffle(Collection<INDArray> sourceArrays, int... dimension) {
shuffle(new ArrayList<INDArray>(sourceArrays), Collections.singletonList(dimension));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,108 @@ public void testSymmetricShuffle1() {
}
}

@Test
public void testSymmetricShuffle2() throws Exception {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);

for (int x = 0; x < 10; x++) {
features.slice(x).assign(x);
labels.slice(x).assign(x);
}

System.out.println(features);

OrderScanner3D scannerFeatures = new OrderScanner3D(features);
OrderScanner3D scannerLabels = new OrderScanner3D(labels);

System.out.println();

List<INDArray> list = new ArrayList<>();
list.add(features);
list.add(labels);

Nd4j.shuffle(list, 1, 2);

System.out.println(features);

System.out.println("------------------");

System.out.println(labels);

assertTrue(scannerFeatures.compareSlice(features));
assertTrue(scannerLabels.compareSlice(labels));

for (int x = 0; x < 10; x++) {
double val = features.slice(x).getDouble(0);
INDArray row = labels.slice(x);

for (int y = 0; y < row.length(); y++ ) {
assertEquals(val, row.getDouble(y), 0.001);
}
}
}

@Test
public void testSymmetricShuffle3() throws Exception {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);
INDArray labelsMask = Nd4j.zeros(10, 3);

for (int x = 0; x < 10; x++) {
features.slice(x).assign(x);
featuresMask.slice(x).assign(x);
labels.slice(x).assign(x);
labelsMask.slice(x).assign(x);
}

OrderScanner3D scannerFeatures = new OrderScanner3D(features);
OrderScanner3D scannerLabels = new OrderScanner3D(labels);
OrderScanner3D scannerFeaturesMask = new OrderScanner3D(featuresMask);
OrderScanner3D scannerLabelsMask = new OrderScanner3D(labelsMask);


List<INDArray> arrays = new ArrayList<>();
arrays.add(features);
arrays.add(labels);
arrays.add(featuresMask);
arrays.add(labelsMask);

List<int[]> dimensions = new ArrayList<>();
dimensions.add(ArrayUtil.range(1,features.rank()));
dimensions.add(ArrayUtil.range(1,labels.rank()));
dimensions.add(ArrayUtil.range(1,featuresMask.rank()));
dimensions.add(ArrayUtil.range(1,labelsMask.rank()));

Nd4j.shuffle(arrays, dimensions);

assertTrue(scannerFeatures.compareSlice(features));
assertTrue(scannerLabels.compareSlice(labels));
assertTrue(scannerFeaturesMask.compareSlice(featuresMask));
assertTrue(scannerLabelsMask.compareSlice(labelsMask));


for (int x = 0; x < 10; x++) {
double val = features.slice(x).getDouble(0);
INDArray sliceLabels = labels.slice(x);
INDArray sliceLabelsMask = labelsMask.slice(x);
INDArray sliceFeaturesMask = featuresMask.slice(x);

for (int y = 0; y < sliceLabels.length(); y++ ) {
assertEquals(val, sliceLabels.getDouble(y), 0.001);
}

for (int y = 0; y < sliceLabelsMask.length(); y++ ) {
assertEquals(val, sliceLabelsMask.getDouble(y), 0.001);
}

for (int y = 0; y < sliceFeaturesMask.length(); y++ ) {
assertEquals(val, sliceFeaturesMask.getDouble(y), 0.001);
}
}
}


@Test
public void testHalfVectors() throws Exception {
Expand All @@ -119,6 +221,53 @@ public void testHalfVectors() throws Exception {
}


public static class OrderScanner3D {
private float[] map;
public OrderScanner3D(INDArray data) {
map = measureState(data);
}

public float[] measureState(INDArray data) {
// for 3D we save 0 element for each slice.
float[] result = new float[data.shape()[0]];

for (int x = 0; x < data.shape()[0]; x++) {
result[x] = data.slice(x).getFloat(0);
}

return result;
}

public boolean compareSlice(INDArray data) {
float[] newMap = measureState(data);

if (newMap.length != map.length) {
System.out.println("Different map lengths");
return false;
}

if (Arrays.equals(map, newMap)) {
System.out.println("Maps are equal");
return false;
}

for (int x = 0; x < data.shape()[0]; x++) {
INDArray slice = data.slice(x);

for (int y = 0; y < slice.length(); y++) {
if (Math.abs(slice.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) {
System.out.print("Different data in a row");
return false;
}
}
}


return true;
}
}


public static class OrderScanner2D {
private float[] map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,4 +780,17 @@ public void shuffle(INDArray array, int... dimension) {
public void shuffle(Collection<INDArray> array, int... dimension) {
// TODO: to be implemented
}

/**
* Symmetric in place shuffle of an ndarray
* along a specified set of dimensions. Each array in list should have it's own dimension at the same index of dimensions array
*
* @param array the ndarray to shuffle
* @param dimensions the dimensions to do the shuffle
* @return
*/
@Override
public void shuffle(List<INDArray> array, List<int[]> dimensions) {
// TODO: to be implemented
}
}

0 comments on commit 05f9812

Please sign in to comment.