From f570c3e484ae584f93772e96906d2b0c613dba04 Mon Sep 17 00:00:00 2001
From: JasonChen86899 <15952011990@163.com>
Date: Sat, 18 Nov 2017 00:43:54 +0800
Subject: [PATCH 1/2] added class CommonBlas which maintains inner product and
scalar matrix
---
src/main/java/org/jblas/CommonBlas.java | 88 +++++++++++++++++++++++
src/main/java/org/jblas/DoubleMatrix.java | 7 ++
src/main/java/org/jblas/FloatMatrix.java | 7 ++
3 files changed, 102 insertions(+)
create mode 100644 src/main/java/org/jblas/CommonBlas.java
diff --git a/src/main/java/org/jblas/CommonBlas.java b/src/main/java/org/jblas/CommonBlas.java
new file mode 100644
index 00000000..d7ad2914
--- /dev/null
+++ b/src/main/java/org/jblas/CommonBlas.java
@@ -0,0 +1,88 @@
+package org.jblas;
+
+import org.jblas.exceptions.SizeException;
+
+/**
+ * This class provides some common function interfaces of m*n matrices.
+ * This class is distinguished from the SimpleBlas which
+ * provides an interface for the computation of one row or column matrix dot product.
+ * For instance,it provides an interface of the m*n and n*p matrix inner product.
+ *
+ * For example, you can do any legitimate m*n matrix operation
+ *
+ * Currently, only implements inner,scalar matrix
+ */
+public class CommonBlas {
+
+ /**
+ * inner product of x with y <- x.y
+ * @param x
+ * @param y
+ * @return result of compution
+ * @throws SizeException x colums is inconsistent with y rows
+ */
+ public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeException {
+ FloatMatrix x_t = x.transpose();
+ if (x_t.rows == y.rows) {
+ FloatMatrix result = new FloatMatrix(x.columns,y.columns);
+ for (int i = 0; i < x.columns; i++) {
+ for (int j = 0; j < y.columns; j++) {
+ float value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j));
+ result.put(i, j, value_ij);
+ }
+ }
+ return result;
+ }else
+ throw new SizeException("x colums is inconsistent with y rows");
+ }
+
+ /**
+ * inner product of x with y <- x.y
+ * @param x
+ * @param y
+ * @return result of compution
+ * @throws SizeException x colums is inconsistent with y rows
+ */
+ public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeException {
+ DoubleMatrix x_t = x.transpose();
+ if (x_t.rows == y.rows) {
+ DoubleMatrix result = new DoubleMatrix(x.columns,y.columns);
+ for (int i = 0; i < x.columns; i++) {
+ for (int j = 0; j < y.columns; j++) {
+ double value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j));
+ result.put(i, j, value_ij);
+ }
+ }
+ return result;
+ }else
+ throw new SizeException("x colums is inconsistent with y rows");
+ }
+
+ /**
+ * provide a instance of n*n float scalar matrix which value is n
+ * @param rows
+ * @param n
+ * @return result
+ */
+ public static FloatMatrix floatScalarMatrixInstance(int rows,float n){
+ FloatMatrix result = new FloatMatrix(rows,rows);
+ for(int i=0;i
Date: Sat, 18 Nov 2017 16:46:55 +0800
Subject: [PATCH 2/2] modified CommonBlas class innner function and added test
case
---
src/main/java/org/jblas/CommonBlas.java | 26 +++++++--------
src/test/java/org/jblas/CommonBlasTest.java | 35 +++++++++++++++++++++
2 files changed, 47 insertions(+), 14 deletions(-)
create mode 100644 src/test/java/org/jblas/CommonBlasTest.java
diff --git a/src/main/java/org/jblas/CommonBlas.java b/src/main/java/org/jblas/CommonBlas.java
index d7ad2914..8aaf8f1c 100644
--- a/src/main/java/org/jblas/CommonBlas.java
+++ b/src/main/java/org/jblas/CommonBlas.java
@@ -16,18 +16,17 @@ public class CommonBlas {
/**
* inner product of x with y <- x.y
- * @param x
- * @param y
+ *
* @return result of compution
* @throws SizeException x colums is inconsistent with y rows
*/
public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeException {
FloatMatrix x_t = x.transpose();
if (x_t.rows == y.rows) {
- FloatMatrix result = new FloatMatrix(x.columns,y.columns);
- for (int i = 0; i < x.columns; i++) {
+ FloatMatrix result = new FloatMatrix(x_t.columns,y.columns);
+ for (int i = 0; i < x_t.columns; i++) {
for (int j = 0; j < y.columns; j++) {
- float value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j));
+ float value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j));
result.put(i, j, value_ij);
}
}
@@ -38,18 +37,17 @@ public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeExceptio
/**
* inner product of x with y <- x.y
- * @param x
- * @param y
+ *
* @return result of compution
* @throws SizeException x colums is inconsistent with y rows
*/
public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeException {
DoubleMatrix x_t = x.transpose();
if (x_t.rows == y.rows) {
- DoubleMatrix result = new DoubleMatrix(x.columns,y.columns);
- for (int i = 0; i < x.columns; i++) {
+ DoubleMatrix result = new DoubleMatrix(x_t.columns,y.columns);
+ for (int i = 0; i < x_t.columns; i++) {
for (int j = 0; j < y.columns; j++) {
- double value_ij = SimpleBlas.dot(x.getColumn(i), y.getColumn(j));
+ double value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j));
result.put(i, j, value_ij);
}
}
@@ -60,8 +58,8 @@ public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeExcep
/**
* provide a instance of n*n float scalar matrix which value is n
- * @param rows
- * @param n
+ * @param rows rows=column
+ * @param n value
* @return result
*/
public static FloatMatrix floatScalarMatrixInstance(int rows,float n){
@@ -74,8 +72,8 @@ public static FloatMatrix floatScalarMatrixInstance(int rows,float n){
/**
* provide a instance of n*n double scalar matrix which value is n
- * @param rows
- * @param n
+ * @param rows rows=column
+ * @param n value
* @return result
*/
public static DoubleMatrix doubleScalarMatrixInstance(int rows,double n){
diff --git a/src/test/java/org/jblas/CommonBlasTest.java b/src/test/java/org/jblas/CommonBlasTest.java
new file mode 100644
index 00000000..fa22a73d
--- /dev/null
+++ b/src/test/java/org/jblas/CommonBlasTest.java
@@ -0,0 +1,35 @@
+package org.jblas;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Some test for class CommonBlas
+ *
+ * @author Jason Chen
+ */
+public class CommonBlasTest {
+ @Test
+ public void testInner() {
+ DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0);
+ DoubleMatrix B = new DoubleMatrix(2,2,1.0,-1.0,2.0,4.0);
+ DoubleMatrix C = new DoubleMatrix(2,2,2.0,-4.0,10.0,-2.0);
+
+ FloatMatrix D = new FloatMatrix(2, 2, 3.0f, -3.0f, 1.0f, 1.0f);
+ FloatMatrix E = new FloatMatrix(2,2,1.0f,-1.0f,2.0f,4.0f);
+ FloatMatrix F = new FloatMatrix(2,2,2.0f,-4.0f,10.0f,-2.0f);
+
+ assertTrue(CommonBlas.inner(A,B).equals(C));
+ assertTrue(CommonBlas.inner(D,E).equals(F));
+ }
+
+ @Test
+ public void testScalar(){
+ DoubleMatrix scalarDoubleMatrix = CommonBlas.doubleScalarMatrixInstance(2,-9.0);
+ FloatMatrix scalarFloatMatrix = CommonBlas.floatScalarMatrixInstance(2,-8.0f);
+
+ assertTrue(new DoubleMatrix(2,2,-9.0,0.0,0.0,-9.0).equals(scalarDoubleMatrix));
+ assertTrue(new FloatMatrix(2,2,-8.0f,0,0,-8.0f).equals(scalarFloatMatrix));
+ }
+}