Skip to content

Commit

Permalink
generalize the method
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 30, 2024
1 parent a91c77e commit 9d2cdba
Showing 1 changed file with 62 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@
import java.util.ArrayList;

import io.bioimage.modelrunner.tensor.Tensor;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

Expand Down Expand Up @@ -119,9 +128,9 @@ public void setAxes(Object axes) {

public void checkRequiredArgs() {
if (offsetDouble == null && offsetArr == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "offset"));
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "offset"));
} else if (gainDouble == null && gainArr == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "gain"));
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "gain"));
} else if ((offsetDouble == null && gainDouble != null)
|| (offsetDouble != null && gainDouble == null)) {
throw new IllegalArgumentException("Both arguments 'gain' and "
Expand All @@ -142,7 +151,7 @@ public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply(
}

@Override
public void applyInPlace(Tensor<FloatType> input) {
public < R extends RealType< R > & NativeType< R > > void applyInPlace(Tensor<R> input) {
checkRequiredArgs();
String selectedAxes = "";
for (String ax : input.getAxesOrderString().split("")) {
Expand All @@ -167,13 +176,11 @@ public void applyInPlace(Tensor<FloatType> input) {

}

private void globalScale( final Tensor< FloatType > output ) {
LoopBuilder.setImages( output.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( gainDouble.floatValue() * i.get() + offsetDouble.floatValue() ) );
private < R extends RealType< R > & NativeType< R > > void globalScale( final Tensor< R > output ) {
scaleLinear(output.getData(), gainDouble.doubleValue(), offsetDouble.doubleValue());
}

private void axesScale( final Tensor< FloatType > output, String axesOfInterest) {
private < R extends RealType< R > & NativeType< R > > void axesScale( final Tensor< R > output, String axesOfInterest) {
long[] start = new long[output.getData().numDimensions()];
long[] dims = output.getData().dimensionsAsLongArray();
long[] indOfDims = new long[axesOfInterest.length()];
Expand All @@ -195,12 +202,10 @@ private void axesScale( final Tensor< FloatType > output, String axesOfInterest)
// Define the view by defining the length per axis
long[] end = new long[dims.length];
for (int i = 0; i < dims.length; i ++) end[i] = dims[i] - start[i];
IntervalView<FloatType> plane = Views.offsetInterval( output.getData(), start, end );
IntervalView<R> plane = Views.offsetInterval( output.getData(), start, end );
final float gain = (float) this.gainArr[c];
final float offset = (float) this.offsetArr[c ++ ];
LoopBuilder.setImages( plane )
.multiThreaded()
.forEachPixel( i -> i.set( i.get() * gain + offset ) );
scaleLinear(plane, gain, offset);
}
}

Expand All @@ -221,4 +226,48 @@ private static long[][] getAllCombinations(long[] arr){
}
return allPoints;
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void scaleLinear(RandomAccessibleInterval<R> rai, double gain, double offset) {
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((byte) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<UnsignedByteType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof ShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<ShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((short) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedShortType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedShortType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof IntType) {
LoopBuilder.setImages((RandomAccessibleInterval<IntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((int) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof UnsignedIntType) {
LoopBuilder.setImages((RandomAccessibleInterval<UnsignedIntType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof LongType) {
LoopBuilder.setImages((RandomAccessibleInterval<LongType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((long) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof FloatType) {
LoopBuilder.setImages((RandomAccessibleInterval<FloatType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((float) (i.get() * gain + offset) ) );
} else if (rai.getAt(0) instanceof DoubleType) {
LoopBuilder.setImages((RandomAccessibleInterval<DoubleType>) rai )
.multiThreaded()
.forEachPixel( i -> i.set((double) (i.get() * gain + offset) ) );
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
}
}

0 comments on commit 9d2cdba

Please sign in to comment.