Skip to content

Commit

Permalink
Add input validation for row/column vector ops
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Aug 5, 2016
1 parent 42df7f6 commit 1f9190b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2010,6 +2010,11 @@ protected void assertColumnVector(INDArray column) {
* @return
*/
protected INDArray doColumnWise(INDArray columnVector, char operation) {
//Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0)
if(!columnVector.isColumnVector() || this.size(0) != columnVector.size(1)){
throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", row vector shape =" + Arrays.toString(columnVector.shape()) + ")");
}

if(columnVector.data().sameUnderlyingData(data()))
return doColumnWise(columnVector.dup(),operation);
if(isVector()) {
Expand Down Expand Up @@ -2082,6 +2087,11 @@ protected void assertRowVector(INDArray rowVector) {
* @return
*/
protected INDArray doRowWise(final INDArray rowVector, final char operation) {
//Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1)
if(!rowVector.isRowVector() || this.size(1) != rowVector.size(1)){
throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")");
}

if(rowVector.data().sameUnderlyingData(data()))
return doRowWise(rowVector.dup(),operation);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.nd4j.linalg;

import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
* Created by Alex on 05/08/2016.
*/
public class InputValidationTests {

@Test(expected=IllegalStateException.class)
public void testInvalidColVectorOp1(){
INDArray first = Nd4j.create(10,10);
INDArray col = Nd4j.create(5,1);
first.muliColumnVector(col);
}

@Test(expected=IllegalStateException.class)
public void testInvalidColVectorOp2(){
INDArray first = Nd4j.create(10,10);
INDArray col = Nd4j.create(5,1);
first.addColumnVector(col);
}

@Test(expected=IllegalStateException.class)
public void testInvalidRowVectorOp1(){
INDArray first = Nd4j.create(10,10);
INDArray row = Nd4j.create(1,5);
first.addiRowVector(row);
}

@Test(expected=IllegalStateException.class)
public void testInvalidRowVectorOp2(){
INDArray first = Nd4j.create(10,10);
INDArray row = Nd4j.create(1,5);
first.subRowVector(row);
}

}

0 comments on commit 1f9190b

Please sign in to comment.