diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java index 11350998..481a2c55 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java @@ -1,13 +1,17 @@ package io.bioimage.modelrunner.tiling; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; import io.bioimage.modelrunner.bioimageio.TileFactory; import io.bioimage.modelrunner.bioimageio.description.Axis; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; import io.bioimage.modelrunner.bioimageio.description.TensorSpec; +import io.bioimage.modelrunner.tensor.Tensor; +import net.imglib2.RandomAccessibleInterval; public class TileCalculator { @@ -17,6 +21,10 @@ public class TileCalculator { private final TileFactory factory; + private final LinkedHashMap input = new LinkedHashMap(); + + private final LinkedHashMap output = new LinkedHashMap(); + private TileCalculator(ModelDescriptor descriptor, List tileInfoList) { this.descriptor = descriptor; this.tileInfoList = tileInfoList; @@ -28,14 +36,13 @@ public static TileCalculator build(ModelDescriptor descriptor, List ti return new TileCalculator(descriptor, tileInfoList); } - private boolean validate() { + private void validate() { checkAllTensorsDefined(); validateTileVsImageSize(); validateStepMin(); validateTileVsHalo(); validateTileVsImageChannel(); checkTilesCombine(); - return false; } private void validateTileVsHalo() { @@ -163,7 +170,64 @@ private void checkAllTensorsDefined() { } private void calculate() { + for (TensorSpec tt : this.descriptor.getInputTensors()) { + TileInfo tile = tileInfoList.stream() + .filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null); + input.put(tt.getTensorID(), computePatchSpecs(tt, tile)); + } } + + /** + * Compute the patch details needed to perform the tiling strategy. The calculations + * obtain the input patch, the padding needed at each side and the number of patches + * needed for every tensor. + * + * @param spec + * specs of the tensor + * @param rai + * ImgLib2 rai, backend of a tensor, that is going to be tiled + * @param tileSize + * the size of the tile selected to process the image + * + * @return an object containing the specs needed to perform patching for the particular tensor + */ + private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile) + { + long[] imSize = arrayToWantedAxesOrderAddOnes(tile.getImageDimensions(), + tile.getImageAxesOrder(), spec.getAxesInfo().getAxesOrder()); + long[] tileSize = arrayToWantedAxesOrderAddOnes(tile.getProposedTileDimensions(), + tile.getTileAxesOrder(), spec.getAxesInfo().getAxesOrder()); + int[][] paddingSize = new int[2][tileSize.length]; + // REgard that the input halo represents the output halo + offset + // and must be divisible by 0.5. + int[] halo = spec.getHaloArr(); + if (!descriptor.isPyramidal() && this.descriptor.isTilingAllowed()) { + // In the case that padding is asymmetrical, the left upper padding has the extra pixel + for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);} + // In the case that padding is asymmetrical, the right bottom padding has one pixel less + for (int i = 0; i < halo.length; i ++) {paddingSize[1][i] = (int) Math.floor(halo[i]);} + + } + int[] patchGridSize = new int[imSize.length]; + for (int i = 0; i < patchGridSize.length; i ++) patchGridSize[i] = 1; + if (descriptor.isTilingAllowed()) { + patchGridSize = IntStream.range(0, tileSize.length) + .map(i -> (int) Math.ceil((double) imSize[i] / ((double) tileSize[i] - halo[i] * 2))) + .toArray(); + } + // For the cases when the patch is bigger than the image size, share the + // padding between both sides of the image + paddingSize[0] = IntStream.range(0, tileSize.length) + .map(i -> + (int) Math.max(paddingSize[0][i], + Math.ceil( (double) (tileSize[i] - imSize[i]) / 2)) + ).toArray(); + paddingSize[1] = IntStream.range(0, tileSize.length) + .map(i -> (int) Math.max( paddingSize[1][i], + tileSize[i] - imSize[i] - paddingSize[0][i])).toArray(); + + return PatchSpec.create(spec.getTensorID(), tileSize, patchGridSize, paddingSize, imSize); + } public void getTileList() {