Skip to content

Commit

Permalink
generalize
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 30, 2024
1 parent 9d2cdba commit ff8a30b
Showing 1 changed file with 135 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

Expand Down Expand Up @@ -115,7 +124,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply(
}

@Override
public void applyInPlace(Tensor<FloatType> input) {
public < R extends RealType< R > & NativeType< R > > void applyInPlace(Tensor<R> input) {
String selectedAxes = "";
for (String ax : input.getAxesOrderString().split("")) {
if (axes != null && !axes.toLowerCase().contains(ax.toLowerCase())
Expand All @@ -131,34 +140,95 @@ public void applyInPlace(Tensor<FloatType> input) {

}

private void globalScale( final Tensor< FloatType > output ) {
float minPercentileVal = findPercentileValue(output.getData(), minPercentile);
float maxPercentileVal = findPercentileValue(output.getData(), maxPercentile);
LoopBuilder.setImages( output.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - minPercentileVal ) / ( maxPercentileVal - minPercentileVal + eps ) ) );
private < R extends RealType< R > & NativeType< R > > void globalScale( final Tensor< R > output ) {
double minPercentileVal = findPercentileValue(output.getData(), minPercentile);
double maxPercentileVal = findPercentileValue(output.getData(), maxPercentile);
scaleRange(output.getData(), maxPercentileVal, minPercentileVal);
}

private float findPercentileValue(RandomAccessibleInterval<FloatType> rai, double percentile) {
final IterableInterval<FloatType> flatImage = Views.iterable(rai);
final Cursor<FloatType> cursor = flatImage.cursor();
@SuppressWarnings("unchecked")
private < R extends RealType< R > & NativeType< R > >
double findPercentileValue(RandomAccessibleInterval<R> rai, double percentile) {
final IterableInterval<R> flatImage = Views.iterable(rai);
long flatSize = Arrays.stream(flatImage.dimensionsAsLongArray()).reduce(1, (a, b) -> a * b);
double[] flatArr = new double[(int) flatSize];

int count = 0;
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = cursor.get().get();
if (rai.getAt(0) instanceof ByteType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedByteType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof ShortType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedShortType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof IntType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof UnsignedIntType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof LongType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof FloatType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else if (rai.getAt(0) instanceof DoubleType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = (double) cursor.get().get();
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
Arrays.sort(flatArr);

int percentilePos = (int) (flatSize * percentile);
percentilePos = percentilePos >= flatArr.length ? flatArr.length - 1 : percentilePos;
return (float) flatArr[percentilePos];
return flatArr[percentilePos];
}

private void axesScale( final Tensor< FloatType > output, String axesOfInterest) {
private < R extends RealType< R > & NativeType< R > >
void axesScale( final Tensor< R > output, String axesOfInterest) {
long[] start = new long[output.getData().numDimensions()];
long[] dims = output.getData().dimensionsAsLongArray();
long[] indOfDims = new long[axesOfInterest.length()];
Expand All @@ -171,7 +241,6 @@ private void axesScale( final Tensor< FloatType > output, String axesOfInterest)
}

long[][] points = getAllCombinations(sizeOfDims);
int c = 0;
for (long[] pp : points) {
for (int i = 0; i < pp.length; i ++) {
start[(int) indOfDims[i]] = pp[i];
Expand All @@ -180,12 +249,10 @@ private void axesScale( final Tensor< FloatType > output, String axesOfInterest)
// Define the view by defining the length per axis
long[] end = new long[dims.length];
for (int i = 0; i < dims.length; i ++) end[i] = dims[i] - start[i];
IntervalView<FloatType> plane = Views.offsetInterval( output.getData(), start, end );
float minPercentileVal = findPercentileValue(plane, minPercentile);
float maxPercentileVal = findPercentileValue(plane, maxPercentile);
LoopBuilder.setImages( plane )
.multiThreaded()
.forEachPixel( i -> i.set( ( i.get() - minPercentileVal ) / ( maxPercentileVal - minPercentileVal + eps) ) );
IntervalView<R> plane = Views.offsetInterval( output.getData(), start, end );
double minPercentileVal = findPercentileValue(plane, minPercentile);
double maxPercentileVal = findPercentileValue(plane, maxPercentile);
scaleRange(plane, maxPercentileVal, minPercentileVal);
}
}

Expand Down Expand Up @@ -238,4 +305,49 @@ public static void test2() {
preprocessing.applyInPlace(tt);
System.out.print(true);
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void scaleRange(RandomAccessibleInterval<R> rai, double maxPercentileVal, double minPercentileVal) {
double diff = maxPercentileVal - minPercentileVal;
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) ((i.get() - minPercentileVal) / (diff + eps))) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) ((i.get() - minPercentileVal) / (diff + eps)) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
}
}

0 comments on commit ff8a30b

Please sign in to comment.