Skip to content

Commit

Permalink
Replace t-SNE library and enable t-SNE unit tests
Browse files Browse the repository at this point in the history
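* The previously used library com.github.lejon.T-SNE-Java has classpath conflicts with ejml-0.25.jar, which ships with Fiji. It is replaced by smile.manifold.TSNE from smile-core (com.github.haifengl), which is already a dependency for PCA.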
  • Loading branch information
stefanhahmann committed Dec 3, 2024
1 parent f6597f3 commit 15a4613
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 73 deletions.
9 changes: 1 addition & 8 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,7 @@
<version>1.0</version>
</dependency>

<!-- t-SNE -->
<dependency>
<groupId>com.github.lejon.T-SNE-Java</groupId>
<artifactId>tsne</artifactId>
<version>v2.6.4</version>
</dependency>

<!-- PCA -->
<!-- PCA, t-SNE -->
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@
import java.util.List;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import com.jujutsu.tsne.TSneConfiguration;
import com.jujutsu.tsne.barneshut.BarnesHutTSne;
import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.utils.TSneUtils;

import org.mastodon.RefPool;
import org.mastodon.graph.Edge;
import org.mastodon.graph.ReadOnlyGraph;
Expand All @@ -51,6 +46,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import smile.manifold.TSNE;

/**
* Abstract class for computing t-SNE features in the Mastodon project.
* <br>
Expand Down Expand Up @@ -98,14 +95,10 @@ protected void computeAlgorithm( double[][] dataMatrix )
+ ") requires the perplexity (" + tSneSettings.getPerplexity() + ") to not be higher than ("
+ tSneSettings.getMaxValidPerplexity( rows ) + ")." );
}
TSneConfiguration tSneConfig =
TSneUtils.buildConfig( dataMatrix, settings.getNumberOfOutputDimensions(), TSneSettings.INITIAL_DIMENSIONS,
tSneSettings.getPerplexity(),
tSneSettings.getMaxIterations(), TSneSettings.USE_PCA, TSneSettings.THETA, false, true );

BarnesHutTSne tsne = new ParallelBHTsne();
logger.info( "Computing t-SNE. Data matrix has {} rows x {} columns.", dataMatrix.length, dataMatrix[ 0 ].length );
tSneResult = tsne.tsne( tSneConfig );
TSNE tsne = new smile.manifold.TSNE( dataMatrix, settings.getNumberOfOutputDimensions(), tSneSettings.getPerplexity(), 200,
tSneSettings.getMaxIterations() );
tSneResult = tsne.coordinates;
logger.info( "Finished computing t-SNE. Results has {} rows x {} columns.", tSneResult.length,
tSneResult.length > 0 ? tSneResult[ 0 ].length : 0 );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import net.imglib2.util.Cast;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mastodon.feature.Dimension;
import org.mastodon.feature.FeatureModel;
Expand All @@ -64,7 +63,6 @@
import org.mastodon.mamut.model.branch.BranchSpot;
import org.scijava.Context;

@Disabled( "mvn test takes too long" )
public class BranchTSneFeatureTest extends AbstractFeatureTest< BranchSpot >
{
private BranchTSneFeature tSneFeature;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,23 @@
*/
package org.mastodon.mamut.feature.dimensionalityreduction.tsne;

import com.jujutsu.tsne.TSneConfiguration;
import com.jujutsu.tsne.barneshut.BarnesHutTSne;
import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.utils.TSneUtils;

import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;

import smile.manifold.TSNE;

public class TSneDemo
{
public static void main( final String[] args )
{
double[][] inputData = RandomDataTools.generateSampleData();
TSneConfiguration config = setUpTSne( inputData );
BarnesHutTSne tsne = new ParallelBHTsne(); // according to https://github.com/lejon/T-SNE-Java/ the parallel version is faster at same accuracy
double[][] tsneResult = tsne.tsne( config );
PlotPoints.plot( inputData, tsneResult, resultValues -> resultValues[ 0 ] > 18 );
double[][] tsneResult = setUpTSne( inputData );
PlotPoints.plot( inputData, tsneResult, resultValues -> resultValues[ 1 ] > 0 );
}

static TSneConfiguration setUpTSne( double[][] inputData )
static double[][] setUpTSne( double[][] inputData )
{
// Recommendations for t-SNE defaults: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
int initialDimensions = 50; // used if PCA is true and dimensions of the input data are greater than this value
double perplexity = 30d; // recommended value is between 5 and 50
int maxIterations = 1000; // should be at least 250

return TSneUtils.buildConfig( inputData, 2, initialDimensions, perplexity, maxIterations, true, 0.5d, false, true );
TSNE tsne = new smile.manifold.TSNE( inputData, 2, 30d, 200, 1000 );
return tsne.coordinates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
import java.util.Arrays;
import java.util.List;

import com.jujutsu.tsne.TSneConfiguration;
import com.jujutsu.tsne.barneshut.BarnesHutTSne;
import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.utils.TSneUtils;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvValidationException;

import org.mastodon.mamut.feature.dimensionalityreduction.PlotPoints;
import org.mastodon.mamut.feature.dimensionalityreduction.util.StandardScaler;

import smile.manifold.TSNE;

public class TSneDemoTgmmMini
{
public static void main( String[] args ) throws IOException, CsvValidationException
Expand All @@ -34,21 +32,15 @@ public static void main( String[] args ) throws IOException, CsvValidationExcept
}
double[][] inputData = data.toArray( new double[ data.size() ][ data.get( 0 ).length ] );
StandardScaler.standardizeColumns( inputData );
TSneConfiguration config = setUpTSne( inputData );
BarnesHutTSne tsne = new ParallelBHTsne(); // according to https://github.com/lejon/T-SNE-Java/ the parallel version is faster at same accuracy
double[][] result = tsne.tsne( config );
result = Arrays.stream( result ).map( row -> Arrays.stream( row ).map( value -> value * 5d ).toArray() ) // scale up
double[][] result = setUpTSne( inputData );
result = Arrays.stream( result ).map( row -> Arrays.stream( row ).map( value -> value * 2d ).toArray() ) // scale up
.toArray( double[][]::new );
PlotPoints.plot( null, result, null );
}

static TSneConfiguration setUpTSne( double[][] inputData )
static double[][] setUpTSne( double[][] inputData )
{
// Recommendations for t-SNE defaults: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
int initialDimensions = 50; // used if PCA is true and dimensions of the input data are greater than this value
double perplexity = 30d; // recommended value is between 5 and 50
int maxIterations = 1000; // should be at least 250

return TSneUtils.buildConfig( inputData, 2, initialDimensions, perplexity, maxIterations, true, 0.5d, false, true );
TSNE tsne = new smile.manifold.TSNE( inputData, 2, 30d, 200, 1000 );
return tsne.coordinates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,13 @@

import java.lang.invoke.MethodHandles;

import com.jujutsu.tsne.TSneConfiguration;
import com.jujutsu.tsne.barneshut.BarnesHutTSne;
import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.utils.TSneUtils;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Disabled( "mvn test takes too long" )
import smile.manifold.TSNE;

class TSneTest
{
private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() );
Expand All @@ -58,22 +53,18 @@ void test()
logger.debug( "dimensions rows: {}, columns:{}", inputData.length, inputData[ 0 ].length );

// Recommendations for t-SNE defaults: https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
int initialDimensions = 50; // used if PCA is true and dimensions of the input data are greater than this value
double perplexity = 30d; // recommended value is between 5 and 50
int maxIterations = 1000; // should be at least 250

TSneConfiguration tSneConfig =
TSneUtils.buildConfig( inputData, 2, initialDimensions, perplexity, maxIterations, true, 0.5d, false, true );

BarnesHutTSne tsne = new ParallelBHTsne();
double[][] tsneResult = tsne.tsne( tSneConfig );
TSNE tsne = new smile.manifold.TSNE( inputData, 2, perplexity, 200, maxIterations );
double[][] tsneResult = tsne.coordinates;

assertEquals( tsneResult.length, inputData.length );
assertEquals( 2, tsneResult[ 0 ].length );

for ( int i = 0; i < numCluster1; i++ )
assertTrue( tsneResult[ i ][ 0 ] > 18 );
assertTrue( tsneResult[ i ][ 1 ] < 0 );
for ( int i = numCluster1; i < numCluster1 + numCluster2; i++ )
assertTrue( tsneResult[ i ][ 0 ] < 18 );
assertTrue( tsneResult[ i ][ 1 ] > 0 );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
*/
package org.mastodon.mamut.feature.dimensionalityreduction.tsne.feature;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.List;
import java.util.function.Supplier;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mastodon.feature.FeatureModel;
import org.mastodon.mamut.feature.branch.exampleGraph.ExampleGraph2;
Expand All @@ -44,7 +44,6 @@
import org.mastodon.mamut.model.Spot;
import org.scijava.Context;

@Disabled( "mvn test takes too long" )
class AbstractTSneFeatureComputerTest
{

Expand All @@ -61,7 +60,7 @@ void testDataDrivenExceptions()
() -> InputDimension.getListFromFeatureModel( featureModel, Spot.class, Link.class );
assertThrows( IllegalArgumentException.class, () -> controller.computeFeature( inputDimensionsSupplier ) );
controller.getTSneSettings().setPerplexity( 2 );
assertThrows( ArrayIndexOutOfBoundsException.class, () -> controller.computeFeature( inputDimensionsSupplier ) );
assertDoesNotThrow( () -> controller.computeFeature( inputDimensionsSupplier ) );
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import java.util.function.Supplier;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mastodon.feature.Dimension;
import org.mastodon.feature.FeatureModel;
Expand All @@ -57,7 +56,6 @@
import org.mastodon.mamut.model.Spot;
import org.scijava.Context;

@Disabled( "mvn test takes too long" )
public class SpotTSneFeatureTest extends AbstractFeatureTest< Spot >
{
private SpotTSneFeature spotTSneFeature;
Expand Down

0 comments on commit 15a4613

Please sign in to comment.