From 95c92a1cee513318e8bfb0a4ff30843f16eabb20 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 6 Aug 2016 20:53:00 +0300 Subject: [PATCH] tests are passing --- .../nd4j/linalg/indexing/BooleanIndexing.java | 62 +++++++++++++------ .../indexing/conditions/Conditions.java | 4 ++ .../indexing/conditions/EqualsCondition.java | 2 +- .../conditions/NotEqualsCondition.java | 60 ++++++++++++++++++ .../linalg/indexing/BooleanIndexingTest.java | 5 ++ 5 files changed, 112 insertions(+), 21 deletions(-) create mode 100644 nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/NotEqualsCondition.java diff --git a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java index a4ebf0694b..9d370fd1bd 100644 --- a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java +++ b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.complex.IComplexNDArray; import org.nd4j.linalg.api.complex.IComplexNumber; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.accum.MatchCondition; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.shape.Shape; @@ -31,6 +32,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.BaseCondition; import org.nd4j.linalg.indexing.conditions.Condition; +import org.nd4j.linalg.indexing.conditions.Conditions; import java.util.concurrent.atomic.AtomicBoolean; @@ -84,17 +86,27 @@ public static boolean or(IComplexNDArray n, Condition cond) { * condition false otherwise */ public static boolean and(final INDArray n, final Condition cond) { - boolean ret = true; - final AtomicBoolean a = new AtomicBoolean(ret); - Shape.iterate(n, new CoordinateFunction() { - @Override - public void process(int[]... coord) { - if (a.get()) - a.compareAndSet(true, a.get() && cond.apply(n.getFloat(coord[0]))); - } - }); + if (cond instanceof BaseCondition) { + long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0); + + if (val == n.lengthLong()) + return true; + else + return false; - return a.get(); + } else { + boolean ret = true; + final AtomicBoolean a = new AtomicBoolean(ret); + Shape.iterate(n, new CoordinateFunction() { + @Override + public void process(int[]... coord) { + if (a.get()) + a.compareAndSet(true, a.get() && cond.apply(n.getFloat(coord[0]))); + } + }); + + return a.get(); + } } /** @@ -105,17 +117,27 @@ public void process(int[]... coord) { * @return */ public static boolean or(final INDArray n, final Condition cond) { - boolean ret = false; - final AtomicBoolean a = new AtomicBoolean(ret); - Shape.iterate(n, new CoordinateFunction() { - @Override - public void process(int[]... coord) { - if (!a.get()) - a.compareAndSet(false, a.get() || cond.apply(n.getFloat(coord[0]))); - } - }); + if (cond instanceof BaseCondition) { + long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0); + + if (val > 0) + return true; + else + return false; - return a.get(); + } else { + boolean ret = false; + final AtomicBoolean a = new AtomicBoolean(ret); + Shape.iterate(n, new CoordinateFunction() { + @Override + public void process(int[]... coord) { + if (!a.get()) + a.compareAndSet(false, a.get() || cond.apply(n.getFloat(coord[0]))); + } + }); + + return a.get(); + } } /** diff --git a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java index 73b97b9757..d9ccb94367 100644 --- a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java +++ b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java @@ -60,6 +60,10 @@ public static Condition equals(Number value) { return new EqualsCondition(value); } + public static Condition notEquals(Number value) { + return new NotEqualsCondition(value); + } + public static Condition greaterThan(IComplexNumber value) { return new GreaterThan(value); } diff --git a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/EqualsCondition.java b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/EqualsCondition.java index cb8ab53c04..c866fbf155 100644 --- a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/EqualsCondition.java +++ b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/EqualsCondition.java @@ -42,7 +42,7 @@ public EqualsCondition(IComplexNumber complexNumber) { */ @Override public int condtionNum() { - return 0; + return 10; } @Override diff --git a/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/NotEqualsCondition.java b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/NotEqualsCondition.java new file mode 100644 index 0000000000..5979fb402d --- /dev/null +++ b/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/NotEqualsCondition.java @@ -0,0 +1,60 @@ +/* + * + * * Copyright 2015 Skymind,Inc. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + * + */ + +package org.nd4j.linalg.indexing.conditions; + +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.complex.IComplexNumber; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Created by agibsonccc on 10/8/14. + */ +public class NotEqualsCondition extends BaseCondition { + public NotEqualsCondition(Number value) { + super(value); + } + + public NotEqualsCondition(IComplexNumber complexNumber) { + super(complexNumber); + } + + /** + * Returns condition ID for native side + * + * @return + */ + @Override + public int condtionNum() { + return 11; + } + + @Override + public Boolean apply(Number input) { + if (Nd4j.dtype == DataBuffer.Type.DOUBLE) + return input.doubleValue() == value.doubleValue(); + else + return input.floatValue() == value.floatValue(); + } + + @Override + public Boolean apply(IComplexNumber input) { + return complexNumber.equals(input); + } +} diff --git a/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index a07d2a120f..2045d59fb8 100644 --- a/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -167,7 +167,12 @@ public void test2dAnd2() throws Exception { array.slice(4).putScalar(2, 1e-5f); + + System.out.println(array); + assertFalse(BooleanIndexing.and(array, Conditions.equals(0f))); + + } @Test