Skip to content

Commit

Permalink
add method to validate min step constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 19, 2024
1 parent e80e898 commit d8f7e2d
Showing 1 changed file with 121 additions and 5 deletions.
126 changes: 121 additions & 5 deletions src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,59 @@ public class TileCalculator {
private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.tileInfoList = tileInfoList;
validateTileSize();
validate();
this.factory = TileFactory.init(descriptor);
}

public static TileCalculator build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
return new TileCalculator(descriptor, tileInfoList);
}

private boolean validateTileSize() {
private boolean validate() {
checkAllTensorsDefined();
checkTileDims(tensor, tile);
validateTileVsImageSize();
validateStepMin();
validateTileVsHalo();
validateNoTiling();
validateTileVsImageChannel();
checkTilesCombine();
return false;
}

private void validate
private void validateTileVsHalo() {

}

private void validateNoTiling() {

}

private void validateStepMin() {
for (TileInfo tile : this.tileInfoList) {
TensorSpec tt = this.descriptor.findInputTensor(tile.getName());
if (tt == null) continue;
String axesTile = tile.getTileAxesOrder();
long[] tileDims = tile.getProposedTileDimensions();
String axesTensor = tt.getAxesOrder();
axesTile = addMissingAxes(axesTensor, axesTile);
axesTensor = addMissingAxes(axesTile, axesTensor);
tileDims = arrayToWantedAxesOrderAddOnes(tileDims, tile.getTileAxesOrder(), axesTile);
int[] min = arrayToWantedAxesOrderAddOnes(tt.getMinTileSizeArr(), tt.getAxesOrder(), axesTile);
int[] step = arrayToWantedAxesOrderAddZeros(tt.getTileStepArr(), tt.getAxesOrder(), axesTile);

for (int i = 0; i < tileDims.length; i ++) {
if (tileDims[i] != min[i] && step[i] == 0)
throw new IllegalArgumentException("Invalid tile size for axis '" + axesTile.split("")[i].toUpperCase()
+ "'. Only allowed tile size for this axis is: " + min[i]);
else if ((tileDims[i] - min[i]) % step[i] != 0)
throw new IllegalArgumentException("Invalid tile size for axis '" + axesTile.split("")[i].toUpperCase()
+ "'. Tile size for this axis should satisfy: " + min[i] + " + n x " + step[i]
+ " where n can be any positive integer.");
}
}
}

private void validateTileVsImageChannel() {

}

Expand Down Expand Up @@ -94,7 +128,6 @@ private void checkAllTensorsDefined() {
}

private void calculate() {
this.validateTileSize();
}

public void getTileList() {
Expand Down Expand Up @@ -190,6 +223,33 @@ public static long[] arrayToWantedAxesOrderAddOnes(long[] size, String orginalAx
return finalSize;
}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order
* The newly added components will be ones.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static int[] arrayToWantedAxesOrderAddOnes(int[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
int[] finalSize = new int[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1) {
finalSize[i] = 1;
} else {
finalSize[i] = size[ind];
}
}
return finalSize;
}

/**
* Convert the array following given axes order into
* another float[] which follows the target axes order
Expand Down Expand Up @@ -242,6 +302,31 @@ public static float[] arrayToWantedAxesOrderAddZeros(float[] size, String orgina
return finalSize;
}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order.
* The newly added components will be zeros.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static long[] arrayToWantedAxesOrderAddZeros(long[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
long[] finalSize = new long[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1)
continue;
finalSize[i] = size[ind];
}
return finalSize;
}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order.
Expand All @@ -266,6 +351,37 @@ public static int[] arrayToWantedAxesOrderAddZeros(int[] size, String orginalAxe
}
return finalSize;
}

/**
* Compare two axes order strings and adds the dimensions of axes1 that are not present in
* axes2 to axes2. Special rules are applied for axes 'b' and 't':
* - If 'b' is present in axes1 but 't' is in axes2, 'b' is skipped.
* - If 't' is present in axes1 but 'b' is in axes2, 't' is added to axes2.
*
* For example:
* <pre>
* String result1 = addMissingAxes("xyz", "xz");
* // result1 will be "xyz" since 'y' is added to "xz"
*
* String result2 = addMissingAxes("xyz", "xyc");
* // result2 will be "xycz"
* </pre>
*
* @param axes1 The source axes order string from which missing axes are added.
* @param axes2 The target axes order string where missing axes are added.
* @return The modified axes2 string including missing axes from axes1.
*/
public static String addMissingAxes(String axes1, String axes2) {
for (String ax : axes1.split("")) {
if (ax.equals("b") && axes2.indexOf(ax) == -1 && axes2.indexOf("t") != -1)
continue;
else if (ax.equals("t") && axes2.indexOf(ax) == -1 && axes2.indexOf("b") != -1)
axes2 += ax;
else if (axes2.indexOf(ax) == -1)
axes2 += ax;
}
return axes2;
}



Expand Down

0 comments on commit d8f7e2d

Please sign in to comment.