Skip to content

Commit

Permalink
add pre and post processing to model run
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 26, 2024
1 parent cef3a32 commit 3a3c056
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
Expand Down Expand Up @@ -530,8 +531,14 @@ void runModel( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )
{
DeepLearningEngineInterface engineInstance = engineClassLoader.getEngineInstance();
engineClassLoader.setEngineClassLoader();
inTensors.stream().forEach( tt -> tt = (Tensor<T>) Tensor.createCopyOfTensorInWantedDataType( tt, new FloatType() ) );
engineInstance.run( inTensors, outTensors );
ArrayList<Tensor<FloatType>> inTensorsFloat = new ArrayList<Tensor<FloatType>>();
for (Tensor<T> tt : inTensors) {
if (tt.getData().getAt(0) instanceof FloatType)
inTensorsFloat.add((Tensor<FloatType>) tt);
else
inTensorsFloat.add(Tensor.createCopyOfTensorInWantedDataType( tt, new FloatType() ));
}
engineInstance.run( inTensorsFloat, outTensors );
engineClassLoader.setBaseClassLoader();
}

Expand Down Expand Up @@ -595,8 +602,8 @@ else if (descriptor == null)
.map(tt -> new ImageInfo(tt.getName(), tt.getAxesOrderString(), tt.getData().dimensionsAsLongArray()))
.collect(Collectors.toList());
List<TileInfo> inputTiles = calc.getOptimalTileSize(imageInfos);
TileMaker tiles = TileMaker.build(descriptor, inputTiles);
return runTiling(inputTensors, tiles, tileCounter);
TileMaker maker = TileMaker.build(descriptor, inputTiles);
return runBMZ(inputTensors, maker, tileCounter);
}

/**
Expand Down Expand Up @@ -663,8 +670,15 @@ else if (descriptor == null) {
}
}
TileMaker maker = TileMaker.build(descriptor, tiles);

return runTiling(inputTensors, maker, tileCounter);
return runBMZ(inputTensors, maker, tileCounter);
}

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runBMZ(List<Tensor<R>> inputTensors, TileMaker tiles, TilingConsumer tileCounter) throws RunModelException {
Processing processing = Processing.init(descriptor);
inputTensors = processing.preprocess(inputTensors, false);
List<Tensor<R>> outputTensors = runTiling(inputTensors, tiles, tileCounter);
return processing.postprocess(outputTensors, true);
}

@SuppressWarnings("unchecked")
Expand Down

0 comments on commit 3a3c056

Please sign in to comment.