diff --git a/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java b/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java index e24acb4a..5e33247d 100644 --- a/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java +++ b/src/main/java/io/bioimage/modelrunner/transformations/ScaleRangeTransformation.java @@ -32,7 +32,16 @@ import net.imglib2.loops.LoopBuilder; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.IntType; +import net.imglib2.type.numeric.integer.LongType; +import net.imglib2.type.numeric.integer.ShortType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.integer.UnsignedIntType; +import net.imglib2.type.numeric.integer.UnsignedShortType; +import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Util; import net.imglib2.view.IntervalView; import net.imglib2.view.Views; @@ -115,7 +124,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply( } @Override - public void applyInPlace(Tensor input) { + public < R extends RealType< R > & NativeType< R > > void applyInPlace(Tensor input) { String selectedAxes = ""; for (String ax : input.getAxesOrderString().split("")) { if (axes != null && !axes.toLowerCase().contains(ax.toLowerCase()) @@ -131,34 +140,95 @@ public void applyInPlace(Tensor input) { } - private void globalScale( final Tensor< FloatType > output ) { - float minPercentileVal = findPercentileValue(output.getData(), minPercentile); - float maxPercentileVal = findPercentileValue(output.getData(), maxPercentile); - LoopBuilder.setImages( output.getData() ) - .multiThreaded() - .forEachPixel( i -> i.set( ( i.get() - minPercentileVal ) / ( maxPercentileVal - minPercentileVal + eps ) ) ); + private < R extends RealType< R > & NativeType< R > > void globalScale( final Tensor< R > output ) { + double minPercentileVal = findPercentileValue(output.getData(), minPercentile); + double maxPercentileVal = findPercentileValue(output.getData(), maxPercentile); + scaleRange(output.getData(), maxPercentileVal, minPercentileVal); } - private float findPercentileValue(RandomAccessibleInterval rai, double percentile) { - final IterableInterval flatImage = Views.iterable(rai); - final Cursor cursor = flatImage.cursor(); + @SuppressWarnings("unchecked") + private < R extends RealType< R > & NativeType< R > > + double findPercentileValue(RandomAccessibleInterval rai, double percentile) { + final IterableInterval flatImage = Views.iterable(rai); long flatSize = Arrays.stream(flatImage.dimensionsAsLongArray()).reduce(1, (a, b) -> a * b); double[] flatArr = new double[(int) flatSize]; int count = 0; - while ( cursor.hasNext() ) - { - cursor.next(); - flatArr[count ++] = cursor.get().get(); + if (rai.getAt(0) instanceof ByteType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof UnsignedByteType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof ShortType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof UnsignedShortType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof IntType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof UnsignedIntType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof LongType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof FloatType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else if (rai.getAt(0) instanceof DoubleType) { + final Cursor cursor = (Cursor) flatImage.cursor(); + while ( cursor.hasNext() ) + { + cursor.next(); + flatArr[count ++] = (double) cursor.get().get(); + } + } else { + throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); } Arrays.sort(flatArr); int percentilePos = (int) (flatSize * percentile); percentilePos = percentilePos >= flatArr.length ? flatArr.length - 1 : percentilePos; - return (float) flatArr[percentilePos]; + return flatArr[percentilePos]; } - private void axesScale( final Tensor< FloatType > output, String axesOfInterest) { + private < R extends RealType< R > & NativeType< R > > + void axesScale( final Tensor< R > output, String axesOfInterest) { long[] start = new long[output.getData().numDimensions()]; long[] dims = output.getData().dimensionsAsLongArray(); long[] indOfDims = new long[axesOfInterest.length()]; @@ -171,7 +241,6 @@ private void axesScale( final Tensor< FloatType > output, String axesOfInterest) } long[][] points = getAllCombinations(sizeOfDims); - int c = 0; for (long[] pp : points) { for (int i = 0; i < pp.length; i ++) { start[(int) indOfDims[i]] = pp[i]; @@ -180,12 +249,10 @@ private void axesScale( final Tensor< FloatType > output, String axesOfInterest) // Define the view by defining the length per axis long[] end = new long[dims.length]; for (int i = 0; i < dims.length; i ++) end[i] = dims[i] - start[i]; - IntervalView plane = Views.offsetInterval( output.getData(), start, end ); - float minPercentileVal = findPercentileValue(plane, minPercentile); - float maxPercentileVal = findPercentileValue(plane, maxPercentile); - LoopBuilder.setImages( plane ) - .multiThreaded() - .forEachPixel( i -> i.set( ( i.get() - minPercentileVal ) / ( maxPercentileVal - minPercentileVal + eps) ) ); + IntervalView plane = Views.offsetInterval( output.getData(), start, end ); + double minPercentileVal = findPercentileValue(plane, minPercentile); + double maxPercentileVal = findPercentileValue(plane, maxPercentile); + scaleRange(plane, maxPercentileVal, minPercentileVal); } } @@ -238,4 +305,49 @@ public static void test2() { preprocessing.applyInPlace(tt); System.out.print(true); } + + @SuppressWarnings("unchecked") + public < R extends RealType< R > & NativeType< R > > + void scaleRange(RandomAccessibleInterval rai, double maxPercentileVal, double minPercentileVal) { + double diff = maxPercentileVal - minPercentileVal; + if (rai.getAt(0) instanceof ByteType) { + LoopBuilder.setImages( (RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((byte) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof UnsignedByteType) { + LoopBuilder.setImages( (RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof ShortType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((short) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof UnsignedShortType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof IntType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof UnsignedIntType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof LongType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else if (rai.getAt(0) instanceof FloatType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((float) ((i.get() - minPercentileVal) / (diff + eps))) ); + } else if (rai.getAt(0) instanceof DoubleType) { + LoopBuilder.setImages((RandomAccessibleInterval) rai ) + .multiThreaded() + .forEachPixel( i -> i.set((double) ((i.get() - minPercentileVal) / (diff + eps)) ) ); + } else { + throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai)); + } + } }