From f570c3e484ae584f93772e96906d2b0c613dba04 Mon Sep 17 00:00:00 2001 From: JasonChen86899 <15952011990@163.com> Date: Sat, 18 Nov 2017 00:43:54 +0800 Subject: [PATCH 1/2] added class CommonBlas which maintains inner product and scalar matrix --- src/main/java/org/jblas/CommonBlas.java | 88 +++++++++++++++++++++++ src/main/java/org/jblas/DoubleMatrix.java | 7 ++ src/main/java/org/jblas/FloatMatrix.java | 7 ++ 3 files changed, 102 insertions(+) create mode 100644 src/main/java/org/jblas/CommonBlas.java diff --git a/src/main/java/org/jblas/CommonBlas.java b/src/main/java/org/jblas/CommonBlas.java new file mode 100644 index 00000000..d7ad2914 --- /dev/null +++ b/src/main/java/org/jblas/CommonBlas.java @@ -0,0 +1,88 @@ +package org.jblas; + +import org.jblas.exceptions.SizeException; + +/** + * This class provides some common function interfaces of m*n matrices. + * This class is distinguished from the SimpleBlas which + * provides an interface for the computation of one row or column matrix dot product. + * For instance,it provides an interface of the m*n and n*p matrix inner product. + *

+ * For example, you can do any legitimate m*n matrix operation + *

+ * Currently, only implements inner,scalar matrix + */ +public class CommonBlas { + + /** + * inner product of x with y <- x.y + * @param x + * @param y + * @return result of compution + * @throws SizeException x colums is inconsistent with y rows + */ + public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeException { + FloatMatrix x_t = x.transpose(); + if (x_t.rows == y.rows) { + FloatMatrix result = new FloatMatrix(x.columns,y.columns); + for (int i = 0; i < x.columns; i++) { + for (int j = 0; j < y.columns; j++) { + float value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j)); + result.put(i, j, value_ij); + } + } + return result; + }else + throw new SizeException("x colums is inconsistent with y rows"); + } + + /** + * inner product of x with y <- x.y + * @param x + * @param y + * @return result of compution + * @throws SizeException x colums is inconsistent with y rows + */ + public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeException { + DoubleMatrix x_t = x.transpose(); + if (x_t.rows == y.rows) { + DoubleMatrix result = new DoubleMatrix(x.columns,y.columns); + for (int i = 0; i < x.columns; i++) { + for (int j = 0; j < y.columns; j++) { + double value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j)); + result.put(i, j, value_ij); + } + } + return result; + }else + throw new SizeException("x colums is inconsistent with y rows"); + } + + /** + * provide a instance of n*n float scalar matrix which value is n + * @param rows + * @param n + * @return result + */ + public static FloatMatrix floatScalarMatrixInstance(int rows,float n){ + FloatMatrix result = new FloatMatrix(rows,rows); + for(int i=0;i Date: Sat, 18 Nov 2017 16:46:55 +0800 Subject: [PATCH 2/2] modified CommonBlas class innner function and added test case --- src/main/java/org/jblas/CommonBlas.java | 26 +++++++-------- src/test/java/org/jblas/CommonBlasTest.java | 35 +++++++++++++++++++++ 2 files changed, 47 insertions(+), 14 deletions(-) create mode 100644 src/test/java/org/jblas/CommonBlasTest.java diff --git a/src/main/java/org/jblas/CommonBlas.java b/src/main/java/org/jblas/CommonBlas.java index d7ad2914..8aaf8f1c 100644 --- a/src/main/java/org/jblas/CommonBlas.java +++ b/src/main/java/org/jblas/CommonBlas.java @@ -16,18 +16,17 @@ public class CommonBlas { /** * inner product of x with y <- x.y - * @param x - * @param y + * * @return result of compution * @throws SizeException x colums is inconsistent with y rows */ public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeException { FloatMatrix x_t = x.transpose(); if (x_t.rows == y.rows) { - FloatMatrix result = new FloatMatrix(x.columns,y.columns); - for (int i = 0; i < x.columns; i++) { + FloatMatrix result = new FloatMatrix(x_t.columns,y.columns); + for (int i = 0; i < x_t.columns; i++) { for (int j = 0; j < y.columns; j++) { - float value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j)); + float value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j)); result.put(i, j, value_ij); } } @@ -38,18 +37,17 @@ public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeExceptio /** * inner product of x with y <- x.y - * @param x - * @param y + * * @return result of compution * @throws SizeException x colums is inconsistent with y rows */ public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeException { DoubleMatrix x_t = x.transpose(); if (x_t.rows == y.rows) { - DoubleMatrix result = new DoubleMatrix(x.columns,y.columns); - for (int i = 0; i < x.columns; i++) { + DoubleMatrix result = new DoubleMatrix(x_t.columns,y.columns); + for (int i = 0; i < x_t.columns; i++) { for (int j = 0; j < y.columns; j++) { - double value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j)); + double value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j)); result.put(i, j, value_ij); } } @@ -60,8 +58,8 @@ public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeExcep /** * provide a instance of n*n float scalar matrix which value is n - * @param rows - * @param n + * @param rows rows=column + * @param n value * @return result */ public static FloatMatrix floatScalarMatrixInstance(int rows,float n){ @@ -74,8 +72,8 @@ public static FloatMatrix floatScalarMatrixInstance(int rows,float n){ /** * provide a instance of n*n double scalar matrix which value is n - * @param rows - * @param n + * @param rows rows=column + * @param n value * @return result */ public static DoubleMatrix doubleScalarMatrixInstance(int rows,double n){ diff --git a/src/test/java/org/jblas/CommonBlasTest.java b/src/test/java/org/jblas/CommonBlasTest.java new file mode 100644 index 00000000..fa22a73d --- /dev/null +++ b/src/test/java/org/jblas/CommonBlasTest.java @@ -0,0 +1,35 @@ +package org.jblas; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +/** + * Some test for class CommonBlas + * + * @author Jason Chen + */ +public class CommonBlasTest { + @Test + public void testInner() { + DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0); + DoubleMatrix B = new DoubleMatrix(2,2,1.0,-1.0,2.0,4.0); + DoubleMatrix C = new DoubleMatrix(2,2,2.0,-4.0,10.0,-2.0); + + FloatMatrix D = new FloatMatrix(2, 2, 3.0f, -3.0f, 1.0f, 1.0f); + FloatMatrix E = new FloatMatrix(2,2,1.0f,-1.0f,2.0f,4.0f); + FloatMatrix F = new FloatMatrix(2,2,2.0f,-4.0f,10.0f,-2.0f); + + assertTrue(CommonBlas.inner(A,B).equals(C)); + assertTrue(CommonBlas.inner(D,E).equals(F)); + } + + @Test + public void testScalar(){ + DoubleMatrix scalarDoubleMatrix = CommonBlas.doubleScalarMatrixInstance(2,-9.0); + FloatMatrix scalarFloatMatrix = CommonBlas.floatScalarMatrixInstance(2,-8.0f); + + assertTrue(new DoubleMatrix(2,2,-9.0,0.0,0.0,-9.0).equals(scalarDoubleMatrix)); + assertTrue(new FloatMatrix(2,2,-8.0f,0,0,-8.0f).equals(scalarFloatMatrix)); + } +}