/*
 * Provenance (recovered from the git patch this source was extracted from):
 *   commit:  e9c2988af29ea56d8b9c6488c72e7844180346ad
 *   author:  Stefan Hahmann
 *   date:    Tue, 19 Nov 2024 11:48:54 +0100
 *   subject: [PATCH] Add (unfinished) PCADemo
 *   file:    src/test/java/org/mastodon/mamut/feature/dimensionalityreduction/pca/PCADemo.java
 *
 * The same patch adds the following dependency to pom.xml:
 *   org.apache.spark : spark-mllib_2.13 : 3.5.3
 */
package org.mastodon.mamut.feature.dimensionalityreduction.pca;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;

/**
 * Unfinished demo that reduces sample data with a principal component analysis (PCA) computed by
 * Spark MLlib and hands the input together with the reduced data to {@link PlotPoints}.
 */
public class PCADemo
{
	/** Number of principal components used by {@link #setUpPCA(double[][])}. */
	private static final int DEFAULT_NUMBER_OF_COMPONENTS = 2;

	// TODO: check https://stackoverflow.com/questions/10604507/pca-implementation-in-java
	/**
	 * Generates sample data, projects it onto its top principal components and plots the outcome.
	 *
	 * @param args
	 *            ignored.
	 */
	public static void main( String[] args )
	{
		double[][] inputData = RandomDataTools.generateSampleData();
		double[][] result = setUpPCA( inputData );
		// The predicate tests the first value of a result row against 10; how PlotPoints uses it
		// (presumably to style or select points) should be confirmed against PlotPoints.plot.
		PlotPoints.plot( inputData, result, resultValues -> resultValues[ 0 ] > 10 );
	}

	/**
	 * Projects the given rows onto their top two principal components.
	 *
	 * @param inputData
	 *            the data to reduce, one observation per row.
	 * @return one row per input row, each holding the coordinates in the space spanned by the top
	 *         two principal components.
	 * @throws NullPointerException
	 *             if {@code inputData} is {@code null}.
	 */
	static double[][] setUpPCA( double[][] inputData )
	{
		return setUpPCA( inputData, DEFAULT_NUMBER_OF_COMPONENTS );
	}

	/**
	 * Projects the given rows onto their top {@code numberOfComponents} principal components.
	 * <p>
	 * The computation runs in a Spark context with master {@code "local"} and application name
	 * {@code "PCA"}, which is closed again before this method returns.
	 *
	 * @param inputData
	 *            the data to reduce, one observation per row.
	 * @param numberOfComponents
	 *            the number of principal components to project onto; must be at least 1. Whether
	 *            values larger than the number of input columns are accepted is decided by Spark.
	 * @return one row per input row, each holding {@code numberOfComponents} values; an empty array
	 *         if {@code inputData} has no rows.
	 * @throws IllegalArgumentException
	 *             if {@code numberOfComponents} is smaller than 1.
	 * @throws NullPointerException
	 *             if {@code inputData} is {@code null}.
	 */
	static double[][] setUpPCA( double[][] inputData, int numberOfComponents )
	{
		if ( numberOfComponents < 1 )
			throw new IllegalArgumentException( "numberOfComponents must be at least 1, but was " + numberOfComponents );
		// Nothing to project: answer directly instead of starting a Spark context for an empty data set.
		if ( inputData.length == 0 )
			return new double[ 0 ][];

		try (JavaSparkContext jsc = new JavaSparkContext( "local", "PCA" ))
		{
			final List< Vector > data = new ArrayList<>( inputData.length );
			for ( final double[] row : inputData )
				data.add( Vectors.dense( row ) );
			final JavaRDD< Vector > rows = jsc.parallelize( data );
			// Create a RowMatrix from JavaRDD.
			final RowMatrix rowMatrix = new RowMatrix( rows.rdd() );
			// Compute the top principal components.
			// Principal components are stored in a local dense matrix.
			final Matrix pc = rowMatrix.computePrincipalComponents( numberOfComponents );
			// Project the rows to the linear space spanned by the top principal components.
			// NOTE(review): the rows are projected as given, without subtracting the column means
			// first - confirm that this is intended.
			final RowMatrix projected = rowMatrix.multiply( pc );
			return projected.rows().toJavaRDD().collect().stream().map( Vector::toArray ).toArray( double[][]::new );
		}
	}
}