Skip to content

Commit

Permalink
keep working on the tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 20, 2024
1 parent 34f7942 commit 2b5232a
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package io.bioimage.modelrunner.tiling;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.bioimage.modelrunner.bioimageio.TileFactory;
import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tensor.Tensor;
import net.imglib2.RandomAccessibleInterval;

public class TileCalculator {

Expand All @@ -17,6 +21,10 @@ public class TileCalculator {

private final TileFactory factory;

private final LinkedHashMap<String, PatchSpec> input = new LinkedHashMap<String, PatchSpec>();

private final LinkedHashMap<String, PatchSpec> output = new LinkedHashMap<String, PatchSpec>();

private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.tileInfoList = tileInfoList;
Expand All @@ -28,14 +36,13 @@ public static TileCalculator build(ModelDescriptor descriptor, List<TileInfo> ti
return new TileCalculator(descriptor, tileInfoList);
}

private boolean validate() {
private void validate() {
checkAllTensorsDefined();
validateTileVsImageSize();
validateStepMin();
validateTileVsHalo();
validateTileVsImageChannel();
checkTilesCombine();
return false;
}

private void validateTileVsHalo() {
Expand Down Expand Up @@ -163,7 +170,64 @@ private void checkAllTensorsDefined() {
}

private void calculate() {
for (TensorSpec tt : this.descriptor.getInputTensors()) {
TileInfo tile = tileInfoList.stream()
.filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null);
input.put(tt.getTensorID(), computePatchSpecs(tt, tile));
}
}

/**
* Compute the patch details needed to perform the tiling strategy. The calculations
* obtain the input patch, the padding needed at each side and the number of patches
* needed for every tensor.
*
* @param spec
* specs of the tensor
* @param rai
* ImgLib2 rai, backend of a tensor, that is going to be tiled
* @param tileSize
* the size of the tile selected to process the image
*
* @return an object containing the specs needed to perform patching for the particular tensor
*/
private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile)
{
long[] imSize = arrayToWantedAxesOrderAddOnes(tile.getImageDimensions(),
tile.getImageAxesOrder(), spec.getAxesInfo().getAxesOrder());
long[] tileSize = arrayToWantedAxesOrderAddOnes(tile.getProposedTileDimensions(),
tile.getTileAxesOrder(), spec.getAxesInfo().getAxesOrder());
int[][] paddingSize = new int[2][tileSize.length];
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
int[] halo = spec.getHaloArr();
if (!descriptor.isPyramidal() && this.descriptor.isTilingAllowed()) {
// In the case that padding is asymmetrical, the left upper padding has the extra pixel
for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);}
// In the case that padding is asymmetrical, the right bottom padding has one pixel less
for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);}

}
int[] patchGridSize = new int[imSize.length];
for (int i = 0; i < patchGridSize.length; i ++) patchGridSize[i] = 1;
if (descriptor.isTilingAllowed()) {
patchGridSize = IntStream.range(0, tileSize.length)
.map(i -> (int) Math.ceil((double) imSize[i] / ((double) tileSize[i] - halo[i] * 2)))
.toArray();
}
// For the cases when the patch is bigger than the image size, share the
// padding between both sides of the image
paddingSize[0] = IntStream.range(0, tileSize.length)
.map(i ->
(int) Math.max(paddingSize[0][i],
Math.ceil( (double) (tileSize[i] - imSize[i]) / 2))
).toArray();
paddingSize[1] = IntStream.range(0, tileSize.length)
.map(i -> (int) Math.max( paddingSize[1][i],
tileSize[i] - imSize[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getTensorID(), tileSize, patchGridSize, paddingSize, imSize);
}

public void getTileList() {

Expand Down

0 comments on commit 2b5232a

Please sign in to comment.