diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java index fefbde61..3bcac9e0 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java @@ -20,7 +20,7 @@ public class TileCalculator { private TileCalculator(ModelDescriptor descriptor, List tileInfoList) { this.descriptor = descriptor; this.tileInfoList = tileInfoList; - validateTileSize(); + validate(); this.factory = TileFactory.init(descriptor); } @@ -28,17 +28,51 @@ public static TileCalculator build(ModelDescriptor descriptor, List ti return new TileCalculator(descriptor, tileInfoList); } - private boolean validateTileSize() { + private boolean validate() { checkAllTensorsDefined(); - checkTileDims(tensor, tile); validateTileVsImageSize(); + validateStepMin(); + validateTileVsHalo(); + validateNoTiling(); + validateTileVsImageChannel(); checkTilesCombine(); return false; } - private void validate + private void validateTileVsHalo() { + + } + + private void validateNoTiling() { + + } private void validateStepMin() { + for (TileInfo tile : this.tileInfoList) { + TensorSpec tt = this.descriptor.findInputTensor(tile.getName()); + if (tt == null) continue; + String axesTile = tile.getTileAxesOrder(); + long[] tileDims = tile.getProposedTileDimensions(); + String axesTensor = tt.getAxesOrder(); + axesTile = addMissingAxes(axesTensor, axesTile); + axesTensor = addMissingAxes(axesTile, axesTensor); + tileDims = arrayToWantedAxesOrderAddOnes(tileDims, tile.getTileAxesOrder(), axesTile); + int[] min = arrayToWantedAxesOrderAddOnes(tt.getMinTileSizeArr(), tt.getAxesOrder(), axesTile); + int[] step = arrayToWantedAxesOrderAddZeros(tt.getTileStepArr(), tt.getAxesOrder(), axesTile); + + for (int i = 0; i < tileDims.length; i ++) { + if (tileDims[i] != min[i] && step[i] == 0) + throw new IllegalArgumentException("Invalid tile size for axis '" + axesTile.split("")[i].toUpperCase() + + "'. Only allowed tile size for this axis is: " + min[i]); + else if ((tileDims[i] - min[i]) % step[i] != 0) + throw new IllegalArgumentException("Invalid tile size for axis '" + axesTile.split("")[i].toUpperCase() + + "'. Tile size for this axis should satisfy: " + min[i] + " + n x " + step[i] + + " where n can be any positive integer."); + } + } + } + + private void validateTileVsImageChannel() { } @@ -94,7 +128,6 @@ private void checkAllTensorsDefined() { } private void calculate() { - this.validateTileSize(); } public void getTileList() { @@ -190,6 +223,33 @@ public static long[] arrayToWantedAxesOrderAddOnes(long[] size, String orginalAx return finalSize; } + /** + * Convert the array following given axes order into + * another int[] which follows the target axes order + * The newly added components will be ones. + * @param size + * original array following the original axes order + * @param orginalAxes + * axes order of the original array + * @param targetAxes + * axes order of the target array + * @return a size array in the order of the tensor of interest + */ + public static int[] arrayToWantedAxesOrderAddOnes(int[] size, String orginalAxes, String targetAxes) { + orginalAxes = orginalAxes.toLowerCase(); + String[] axesArr = targetAxes.toLowerCase().split(""); + int[] finalSize = new int[targetAxes.length()]; + for (int i = 0; i < finalSize.length; i ++) { + int ind = orginalAxes.indexOf(axesArr[i]); + if (ind == -1) { + finalSize[i] = 1; + } else { + finalSize[i] = size[ind]; + } + } + return finalSize; + } + /** * Convert the array following given axes order into * another float[] which follows the target axes order @@ -242,6 +302,31 @@ public static float[] arrayToWantedAxesOrderAddZeros(float[] size, String orgina return finalSize; } + /** + * Convert the array following given axes order into + * another int[] which follows the target axes order. + * The newly added components will be zeros. + * @param size + * original array following the original axes order + * @param orginalAxes + * axes order of the original array + * @param targetAxes + * axes order of the target array + * @return a size array in the order of the tensor of interest + */ + public static long[] arrayToWantedAxesOrderAddZeros(long[] size, String orginalAxes, String targetAxes) { + orginalAxes = orginalAxes.toLowerCase(); + String[] axesArr = targetAxes.toLowerCase().split(""); + long[] finalSize = new long[targetAxes.length()]; + for (int i = 0; i < finalSize.length; i ++) { + int ind = orginalAxes.indexOf(axesArr[i]); + if (ind == -1) + continue; + finalSize[i] = size[ind]; + } + return finalSize; + } + /** * Convert the array following given axes order into * another int[] which follows the target axes order. @@ -266,6 +351,37 @@ public static int[] arrayToWantedAxesOrderAddZeros(int[] size, String orginalAxe } return finalSize; } + + /** + * Compare two axes order strings and adds the dimensions of axes1 that are not present in + * axes2 to axes2. Special rules are applied for axes 'b' and 't': + * - If 'b' is present in axes1 but 't' is in axes2, 'b' is skipped. + * - If 't' is present in axes1 but 'b' is in axes2, 't' is added to axes2. + * + * For example: + *
+	 *     String result1 = addMissingAxes("xyz", "xz");
+	 *     // result1 will be "xyz" since 'y' is added to "xz"
+	 *     
+	 *     String result2 = addMissingAxes("xyz", "xyc");
+	 *     // result2 will be "xycz" 
+	 * 
+ * + * @param axes1 The source axes order string from which missing axes are added. + * @param axes2 The target axes order string where missing axes are added. + * @return The modified axes2 string including missing axes from axes1. + */ + public static String addMissingAxes(String axes1, String axes2) { + for (String ax : axes1.split("")) { + if (ax.equals("b") && axes2.indexOf(ax) == -1 && axes2.indexOf("t") != -1) + continue; + else if (ax.equals("t") && axes2.indexOf(ax) == -1 && axes2.indexOf("b") != -1) + axes2 += ax; + else if (axes2.indexOf(ax) == -1) + axes2 += ax; + } + return axes2; + }