Skip to content

Commit

Permalink
native match
Browse files Browse the repository at this point in the history
  • Loading branch information
raver119 committed Aug 6, 2016
1 parent 73e5f6b commit a925415
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.instrumentation.Instrumentation;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.iter.FirstAxisIterator;
import org.nd4j.linalg.api.ops.impl.accum.Max;
import org.nd4j.linalg.api.ops.impl.accum.*;
Expand Down Expand Up @@ -2953,7 +2952,7 @@ public INDArray assign(Number value) {
*/
@Override
public INDArray assignIf(INDArray arr, Condition condition) {
BooleanIndexing.assignWhere(this, arr, condition);
BooleanIndexing.assignIf(this, arr, condition);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import org.nd4j.linalg.api.ops.Op;

/**
* Sum the components
* Absolute sum the components
*
* @author Adam Gibson
* @author [email protected]
*/
public class ASum extends BaseAccumulation {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/

package org.nd4j.linalg.api.ops.impl.accum;

import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;

/**
* Absolute sum the components
*
* @author [email protected]
*/
public class MatchCondition extends BaseAccumulation {

double compare;
double eps;
int mode;

public MatchCondition() {
}


public MatchCondition(INDArray x, Condition condition) {
this(x, Nd4j.EPS_THRESHOLD, condition);
}

public MatchCondition(INDArray x, double eps, Condition condition) {
super(x);
this.compare = condition.getValue();
this.mode = condition.condtionNum();
this.eps = eps;

this.extraArgs = new Object[] {compare, eps, (double) mode};
}

@Override
public int opNum() {
return 12;
}

@Override
public String name() {
return "match_condition";
}

@Override
public Op opForDimension(int index, int dimension) {
return null;
}

@Override
public Op opForDimension(int index, int... dimension) {
return null;
}

@Override
public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
return null;
}

@Override
public IComplexNumber op(IComplexNumber origin, float other) {
return null;
}

@Override
public IComplexNumber op(IComplexNumber origin, double other) {
return null;
}

@Override
public IComplexNumber op(IComplexNumber origin) {
return null;
}

@Override
public double update(double accum, double x) {
return 0;
}

@Override
public double update(double accum, double x, double y) {
return 0;
}

@Override
public float update(float accum, float x) {
return 0;
}

@Override
public float update(float accum, float x, float y) {
return 0;
}

@Override
public IComplexNumber update(IComplexNumber accum, double x) {
return null;
}

@Override
public IComplexNumber update(IComplexNumber accum, double x, double y) {
return null;
}

@Override
public IComplexNumber update(IComplexNumber accum, IComplexNumber x) {
return null;
}

@Override
public IComplexNumber update(IComplexNumber accum, IComplexNumber x, IComplexNumber y) {
return null;
}

@Override
public IComplexNumber update(IComplexNumber accum, IComplexNumber x, double y) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.nd4j.linalg.indexing.conditions.Condition;

/**
* Element-wise Compare-and-set implementation as Op
* Element-wise Compare-and-Replace implementation as Op
*
* @author [email protected]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public void process(int[]... coord) {
* @param from
* @param condition
*/
public static void assignWhere(@NonNull INDArray to,@NonNull INDArray from, @NonNull Condition condition) {
public static void assignIf(@NonNull INDArray to, @NonNull INDArray from, @NonNull Condition condition) {
if (!(condition instanceof BaseCondition))
throw new UnsupportedOperationException("Only static Conditions are supported");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
Expand Down Expand Up @@ -328,6 +329,34 @@ public void testCaRPairwiseTransform3() throws Exception {
assertEquals(comp, x);
}


@Test
public void testMatchConditionAllDimensions1() throws Exception {
INDArray array = Nd4j.create(new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.lessThan(5)), Integer.MAX_VALUE).getDouble(0);

assertEquals(5, val);
}

@Test
public void testMatchConditionAllDimensions2() throws Exception {
INDArray array = Nd4j.create(new double[]{0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9});

int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.isNan()), Integer.MAX_VALUE).getDouble(0);

assertEquals(1, val);
}

@Test
public void testMatchConditionAllDimensions3() throws Exception {
INDArray array = Nd4j.create(new double[]{0, 1, 2, 3, Double.NEGATIVE_INFINITY, 5, 6, 7, 8, 9});

int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.isInfinite()), Integer.MAX_VALUE).getDouble(0);

assertEquals(1, val);
}

@Override
public char ordering() {
return 'c';
Expand Down

0 comments on commit a925415

Please sign in to comment.