Skip to content

Commit

Permalink
generalize
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 30, 2024
1 parent 63bf16f commit a91c77e
Showing 1 changed file with 69 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

Expand Down Expand Up @@ -139,11 +148,11 @@ else if (mode instanceof Mode)

public void checkRequiredArgs() {
if (this.mode == Mode.FIXED && this.meanArr == null && this.meanDouble == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "mean")
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "mean")
+ System.lineSeparator() + "If 'mode' parameter equals 'fixed', the 'mean' "
+ "argument should be provided too.");
} else if (this.mode == Mode.FIXED && this.stdArr == null && this.stdDouble == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "std")
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "std")
+ System.lineSeparator() + "If 'mode' parameter equals 'fixed', the 'std' "
+ "argument should be provided too.");
} else if (this.mode == Mode.FIXED && ((stdDouble == null && meanDouble != null)
Expand All @@ -166,7 +175,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply(
}

@Override
public void applyInPlace( final Tensor< FloatType > input )
public < R extends RealType< R > & NativeType< R > > void applyInPlace( final Tensor< R > input )
{
checkRequiredArgs();
String selectedAxes = "";
Expand Down Expand Up @@ -207,14 +216,11 @@ else if (meanDouble != null)
}
}

private void fixedModeGlobalMeanStd( final Tensor< FloatType > output ) {

LoopBuilder.setImages( output.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - meanDouble.floatValue() ) / ( stdDouble.floatValue() + eps ) ) );
private < R extends RealType< R > & NativeType< R > > void fixedModeGlobalMeanStd( final Tensor< R > output ) {
zeroMeanUnitVariance(output.getData(), meanDouble.doubleValue(), stdDouble.doubleValue());
}

private void notFixedAxesMeanStd( final Tensor< FloatType > output, String axesOfInterest) {
private < R extends RealType< R > & NativeType< R > > void notFixedAxesMeanStd( final Tensor< R > output, String axesOfInterest) {
long[] start = new long[output.getData().numDimensions()];
long[] dims = output.getData().dimensionsAsLongArray();
long[] indOfDims = new long[axesOfInterest.length()];
Expand All @@ -236,17 +242,15 @@ private void notFixedAxesMeanStd( final Tensor< FloatType > output, String axesO
// Define the view by defining the length per axis
long[] end = new long[dims.length];
for (int i = 0; i < dims.length; i ++) end[i] = dims[i] - start[i];
IntervalView<FloatType> plane = Views.offsetInterval( output.getData(), start, end );
IntervalView<R> plane = Views.offsetInterval( output.getData(), start, end );
final float[] meanStd = meanStd( plane );
final float mean = meanStd[ 0 ];
final float std = meanStd[ 1 ];
LoopBuilder.setImages( plane )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - mean ) / ( std + eps ) ) );
zeroMeanUnitVariance(output.getData(), mean, std);
}
}

private void fixedAxesMeanStd( final Tensor< FloatType > output, String axesOfInterest) {
private < R extends RealType< R > & NativeType< R > > void fixedAxesMeanStd( final Tensor< R > output, String axesOfInterest) {
long[] start = new long[output.getData().numDimensions()];
long[] dims = output.getData().dimensionsAsLongArray();
long[] indOfDims = new long[axesOfInterest.length()];
Expand All @@ -268,31 +272,27 @@ private void fixedAxesMeanStd( final Tensor< FloatType > output, String axesOfIn
// Define the view by defining the length per axis
long[] end = new long[dims.length];
for (int i = 0; i < dims.length; i ++) end[i] = dims[i] - start[i];
IntervalView<FloatType> plane = Views.offsetInterval( output.getData(), start, end );
IntervalView<R> plane = Views.offsetInterval( output.getData(), start, end );
final float mean = (float) this.meanArr[c];
final float std = (float) this.stdArr[c ++ ];
LoopBuilder.setImages( plane )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - mean ) / ( std + eps ) ) );
zeroMeanUnitVariance(plane, mean, std);
}
}

private void notFixedModeGlobalMeanStd( final Tensor< FloatType > output ) {
private < R extends RealType< R > & NativeType< R > > void notFixedModeGlobalMeanStd( final Tensor< R > output ) {

final float[] meanStd = meanStd( output.getData() );
final float mean = meanStd[ 0 ];
final float std = meanStd[ 1 ];
LoopBuilder.setImages( output.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - mean ) / ( std + eps ) ) );
zeroMeanUnitVariance(output.getData(), mean, std);
}

public static float[] meanStd( final RandomAccessibleInterval< FloatType > rai )
public static < R extends RealType< R > & NativeType< R > > float[] meanStd( final RandomAccessibleInterval< R > rai )
{
// Mean.
double sum = 0.;
long n = 0;
for ( final FloatType p : Views.iterable( rai ) )
for ( final R p : Views.iterable( rai ) )
{
sum += p.getRealDouble();
n++;
Expand All @@ -304,7 +304,7 @@ public static float[] meanStd( final RandomAccessibleInterval< FloatType > rai )

// Variance.
double sumdx2 = 0.;
for ( final FloatType p : Views.iterable( rai ) )
for ( final R p : Views.iterable( rai ) )
{
final double dx = p.getRealDouble() - mean;
sumdx2 += dx * dx;
Expand Down Expand Up @@ -383,4 +383,48 @@ public static void test3() {
preprocessing.applyInPlace(tt);
System.out.print(true);
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void zeroMeanUnitVariance(RandomAccessibleInterval<R> rai, double mean, double std) {
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - mean) / (std + eps)) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) ((i.get() - mean) / (std + eps))) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) ((i.get() - mean) / (std + eps)) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
}
}

0 comments on commit a91c77e

Please sign in to comment.