Skip to content

Commit

Permalink
try to generalize more the transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 26, 2024
1 parent 3a3c056 commit 8ad2c83
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@
package io.bioimage.modelrunner.transformations;

import io.bioimage.modelrunner.tensor.Tensor;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;

/**
Expand All @@ -38,6 +46,14 @@ public class AbstractTensorPixelTransformation extends AbstractTensorTransformat
{

private FloatUnaryOperator fun;
private DoubleUnaryOperator dun;
private ByteUnaryOperator bun;
private UByteUnaryOperator ubun;
private ShortUnaryOperator sun;
private UShortUnaryOperator usun;
private IntUnaryOperator iun;
private UIntUnaryOperator uiun;
private LongUnaryOperator lun;

protected AbstractTensorPixelTransformation( final String name)
{
Expand All @@ -52,25 +68,116 @@ protected void setFloatUnitaryOperator(final FloatUnaryOperator fun) {
public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply( final Tensor< R > input )
{
final Tensor< FloatType > output = makeOutput( input );
LoopBuilder
.setImages( input.getData(), output.getData() )
.multiThreaded()
.forEachPixel( ( i, o ) -> o.set( fun.applyAsFloat( i.getRealFloat() ) ) );
applyInPlace(output);
return output;
}

@SuppressWarnings("unchecked")
@Override
public void applyInPlace( final Tensor< FloatType > input )
public < R extends RealType< R > & NativeType< R > >
void applyInPlace( final Tensor< R > input )
{
LoopBuilder
.setImages( input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( fun.applyAsFloat( i.get() ) ) );
if (input.getData().getAt(0) instanceof FloatType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<FloatType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( fun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof DoubleType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<DoubleType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( dun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof ByteType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<ByteType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( bun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof UnsignedByteType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<UnsignedByteType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( ubun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof ShortType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<ShortType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( sun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof UnsignedShortType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<UnsignedShortType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( usun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof IntType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<IntType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( iun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof UnsignedIntType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<UnsignedIntType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( uiun.applyAs( i.get() ) ) );
} else if (input.getData().getAt(0) instanceof LongType) {
LoopBuilder
.setImages( (RandomAccessibleInterval<LongType>) input.getData() )
.multiThreaded()
.forEachPixel( i -> i.set( lun.applyAs( i.get() ) ) );
} else {
throw new IllegalArgumentException("Unsupported data type.");
}
}

@FunctionalInterface
public interface FloatUnaryOperator
{
float applyAsFloat( float in );
float applyAs( float in );
}

@FunctionalInterface
public interface DoubleUnaryOperator
{
float applyAs( double in );
}

@FunctionalInterface
public interface ByteUnaryOperator
{
byte applyAs( byte in );
}

@FunctionalInterface
public interface UByteUnaryOperator
{
int applyAs( int i );
}

@FunctionalInterface
public interface ShortUnaryOperator
{
short applyAs( short in );
}

@FunctionalInterface
public interface UShortUnaryOperator
{
int applyAs( int i );
}

@FunctionalInterface
public interface IntUnaryOperator
{
int applyAs( int in );
}

@FunctionalInterface
public interface UIntUnaryOperator
{
long applyAs( long in );
}

@FunctionalInterface
public interface LongUnaryOperator
{
long applyAs( long in );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class AbstractTensorTransformation implements TensorTransformati

private Mode mode = Mode.FIXED;

protected static String DEFAULT_MISSING_ARG_ERR = "Cannot execute Clip BioImage.io transformation because '%s' "
protected static String DEFAULT_MISSING_ARG_ERR = "Cannot execute %s BioImage.io transformation because '%s' "
+ "parameter was not set.";

protected AbstractTensorTransformation( final String name )
Expand Down

0 comments on commit 8ad2c83

Please sign in to comment.