Skip to content

Commit

Permalink
Add (unfinished) PCADemo
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanhahmann committed Nov 19, 2024
1 parent e0fb205 commit e9c2988
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@
<version>v2.6.4</version>
</dependency>

<!-- PCA (principal component analysis) via Spark MLlib; its classes are imported by PCADemo -->
<!-- Example: https://spark.apache.org/docs/latest/mllib-dimensionality-reduction#principal-component-analysis-pca -->
<!-- NOTE(review): the "_2.13" artifact suffix is presumably the Scala binary version; confirm it is compatible with the other dependencies -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.13</artifactId>
<version>3.5.3</version>
</dependency>

<!-- Standardization for UMAP preprocessing -->
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.mastodon.mamut.feature.dimensionalityreduction.pca;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;

/**
 * Demo that runs a principal component analysis (PCA) with Spark MLlib on generated sample data
 * and hands the input together with the projected result to {@link PlotPoints}.
 * <p>
 * This demo is unfinished, see the TODO on {@link #main(String[])}.
 */
public class PCADemo
{
	/** Number of principal components computed by {@link #setUpPCA(double[][])}. */
	private static final int DEFAULT_NUMBER_OF_COMPONENTS = 2;

	/**
	 * Generates sample data, projects it onto its principal components and plots input and result.
	 *
	 * @param args
	 *            command line arguments (not used).
	 */
	// TODO: check https://stackoverflow.com/questions/10604507/pca-implementation-in-java
	public static void main( String[] args )
	{
		double[][] inputData = RandomDataTools.generateSampleData();
		double[][] result = setUpPCA( inputData );
		// The predicate selects result rows whose first component is greater than 10.
		// NOTE(review): how PlotPoints uses the predicate is not visible from here - confirm.
		PlotPoints.plot( inputData, result, resultValues -> resultValues[ 0 ] > 10 );
	}

	/**
	 * Projects the given data onto its top {@value #DEFAULT_NUMBER_OF_COMPONENTS} principal components.
	 *
	 * @param inputData
	 *            the data to reduce, one observation per row.
	 * @return the projected rows.
	 * @throws NullPointerException
	 *             if {@code inputData} is {@code null}.
	 */
	static double[][] setUpPCA( double[][] inputData )
	{
		return setUpPCA( inputData, DEFAULT_NUMBER_OF_COMPONENTS );
	}

	/**
	 * Projects the given data onto its top {@code numberOfComponents} principal components, using a
	 * local Spark context that is closed again before this method returns.
	 *
	 * @param inputData
	 *            the data to reduce, one observation per row.
	 * @param numberOfComponents
	 *            the number of principal components to compute; must be at least 1. An upper bound
	 *            is not checked here but left to Spark.
	 * @return the projected rows as collected from Spark; an empty array if {@code inputData} has no rows.
	 * @throws NullPointerException
	 *             if {@code inputData} is {@code null}.
	 * @throws IllegalArgumentException
	 *             if {@code numberOfComponents} is less than 1.
	 */
	static double[][] setUpPCA( double[][] inputData, int numberOfComponents )
	{
		// Validate before starting Spark, so that invalid calls fail fast and cheaply.
		if ( inputData == null )
			throw new NullPointerException( "inputData must not be null" );
		if ( numberOfComponents < 1 )
			throw new IllegalArgumentException( "numberOfComponents must be at least 1, but was " + numberOfComponents );
		// There are no rows to project, so there is no need to start a Spark context at all.
		if ( inputData.length == 0 )
			return new double[ 0 ][];

		// Master "local" and application name "PCA"; try-with-resources closes the context.
		try (JavaSparkContext jsc = new JavaSparkContext( "local", "PCA" ))
		{
			final List< Vector > data = new ArrayList<>( inputData.length );
			for ( final double[] row : inputData )
				data.add( Vectors.dense( row ) );
			final JavaRDD< Vector > rows = jsc.parallelize( data );
			// Create a RowMatrix from JavaRDD<Vector>.
			final RowMatrix rowMatrix = new RowMatrix( rows.rdd() );
			// Compute the top principal components.
			// Principal components are stored in a local dense matrix.
			final Matrix pc = rowMatrix.computePrincipalComponents( numberOfComponents );
			// Project the rows to the linear space spanned by the computed principal components.
			// NOTE(review): the rows are multiplied as given, without subtracting column means
			// beforehand - confirm whether centering is wanted for this use case.
			final RowMatrix projected = rowMatrix.multiply( pc );
			return projected.rows().toJavaRDD().collect().stream().map( Vector::toArray ).toArray( double[][]::new );
		}
	}
}

0 comments on commit e9c2988

Please sign in to comment.