Skip to content

Commit

Permalink
tests are passing
Browse files Browse the repository at this point in the history
  • Loading branch information
raver119 committed Aug 6, 2016
1 parent a925415 commit 95c92a1
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.BaseCondition;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

import java.util.concurrent.atomic.AtomicBoolean;

Expand Down Expand Up @@ -84,17 +86,27 @@ public static boolean or(IComplexNDArray n, Condition cond) {
* condition false otherwise
*/
public static boolean and(final INDArray n, final Condition cond) {
boolean ret = true;
final AtomicBoolean a = new AtomicBoolean(ret);
Shape.iterate(n, new CoordinateFunction() {
@Override
public void process(int[]... coord) {
if (a.get())
a.compareAndSet(true, a.get() && cond.apply(n.getFloat(coord[0])));
}
});
if (cond instanceof BaseCondition) {
long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0);

if (val == n.lengthLong())
return true;
else
return false;

return a.get();
} else {
boolean ret = true;
final AtomicBoolean a = new AtomicBoolean(ret);
Shape.iterate(n, new CoordinateFunction() {
@Override
public void process(int[]... coord) {
if (a.get())
a.compareAndSet(true, a.get() && cond.apply(n.getFloat(coord[0])));
}
});

return a.get();
}
}

/**
Expand All @@ -105,17 +117,27 @@ public void process(int[]... coord) {
* @return
*/
public static boolean or(final INDArray n, final Condition cond) {
boolean ret = false;
final AtomicBoolean a = new AtomicBoolean(ret);
Shape.iterate(n, new CoordinateFunction() {
@Override
public void process(int[]... coord) {
if (!a.get())
a.compareAndSet(false, a.get() || cond.apply(n.getFloat(coord[0])));
}
});
if (cond instanceof BaseCondition) {
long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0);

if (val > 0)
return true;
else
return false;

return a.get();
} else {
boolean ret = false;
final AtomicBoolean a = new AtomicBoolean(ret);
Shape.iterate(n, new CoordinateFunction() {
@Override
public void process(int[]... coord) {
if (!a.get())
a.compareAndSet(false, a.get() || cond.apply(n.getFloat(coord[0])));
}
});

return a.get();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public static Condition equals(Number value) {
return new EqualsCondition(value);
}

public static Condition notEquals(Number value) {
return new NotEqualsCondition(value);
}

public static Condition greaterThan(IComplexNumber value) {
return new GreaterThan(value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public EqualsCondition(IComplexNumber complexNumber) {
*/
@Override
public int condtionNum() {
return 0;
return 10;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/

package org.nd4j.linalg.indexing.conditions;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.factory.Nd4j;

/**
* Created by agibsonccc on 10/8/14.
*/
public class NotEqualsCondition extends BaseCondition {
public NotEqualsCondition(Number value) {
super(value);
}

public NotEqualsCondition(IComplexNumber complexNumber) {
super(complexNumber);
}

/**
* Returns condition ID for native side
*
* @return
*/
@Override
public int condtionNum() {
return 11;
}

@Override
public Boolean apply(Number input) {
if (Nd4j.dtype == DataBuffer.Type.DOUBLE)
return input.doubleValue() == value.doubleValue();
else
return input.floatValue() == value.floatValue();
}

@Override
public Boolean apply(IComplexNumber input) {
return complexNumber.equals(input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ public void test2dAnd2() throws Exception {

array.slice(4).putScalar(2, 1e-5f);


System.out.println(array);

assertFalse(BooleanIndexing.and(array, Conditions.equals(0f)));


}

@Test
Expand Down

0 comments on commit 95c92a1

Please sign in to comment.