Skip to content

Commit

Permalink
keep unifying tile calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 18, 2024
1 parent 1e0330f commit e80e898
Showing 1 changed file with 152 additions and 11 deletions.
163 changes: 152 additions & 11 deletions src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.bioimage.modelrunner.tiling;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.bioimage.modelrunner.bioimageio.TileFactory;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
Expand All @@ -18,29 +20,64 @@ public class TileCalculator {
private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.tileInfoList = tileInfoList;
validateTileSize();
this.factory = TileFactory.init(descriptor);
}

public static TileCalculator build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
return new TileCalculator(descriptor, tileInfoList);
}

public boolean validateTileSize() {
private boolean validateTileSize() {
checkAllTensorsDefined();
for (TileInfo tile : tileInfoList) {
String id = tile.getName();
TensorSpec tensor = descriptor.getInputTensors().stream()
.filter(tt -> tt.getTensorID().equals(id)).findFirst().orElse(null);
if (tensor == null)
throw new IllegalArgumentException("Invalid tiling information: The input tensor named '"
+ id + "' does not exist in the model. Please check the model's input tensors "
+ "and provide tiling information for an existing tensor.");
checkTileDims(tensor, tile);
}
checkTileDims(tensor, tile);
validateTileVsImageSize();
checkTilesCombine();
return false;
}

private void validate

private void validateStepMin() {

}

private void validateTileVsImageSize() throws IllegalArgumentException {
for (TileInfo tile : this.tileInfoList) {
String axesTile = tile.getTileAxesOrder();
String axesImage = tile.getImageAxesOrder();
long[] tileDims = tile.getProposedTileDimensions();
checkAxisSize(tile);
long[] imDims = arrayToWantedAxesOrderAddOnes(tile.getImageDimensions(), axesImage, axesTile);
for (int i = 0; i < axesTile.length(); i ++) {
int indIm = axesImage.indexOf(axesTile.split("")[i]);
if (imDims[indIm] * 3 < tileDims[i])
throw new IllegalArgumentException("Error in the axes size selected. "
+ "The axes size introduced in any of the dimensions cannot "
+ "be bigger than 3 times the image size of that same axes. "
+ "The image selected has " + axesTile.split("")[i] + "-dimension of size "
+ imDims[indIm] + "and the tile is of size " + tileDims[i] + "."
+ " Maxmum tile size for " + axesTile.split("")[i] + "-axis in this image is "
+ imDims[indIm] * 3);
}
}
}

private static void checkAxisSize(TileInfo tile) {
String axesTile = tile.getTileAxesOrder();
long[] tileDims = tile.getProposedTileDimensions();
if (axesTile.length() != tileDims.length)
throw new IllegalArgumentException("The tile dimensions and tile axes should be of the same length:"
+ " " + axesTile + " (" + axesTile.length() + ") vs " + Arrays.toString(tileDims)
+ " (" + tileDims.length + ")");
String axesImage = tile.getImageAxesOrder();
long[] imDims = tile.getImageDimensions();
if (axesImage.length() != imDims.length)
throw new IllegalArgumentException("The image dimensions and image axes should be of the same length:"
+ " " + axesImage + " (" + axesImage.length() + ") vs " + Arrays.toString(imDims)
+ " (" + imDims.length + ")");
}

private void checkTilesCombine() {

}
Expand Down Expand Up @@ -125,6 +162,110 @@ public List<long[]> getTilePostionsOutputImage(String tensorId) {
private static void checkTileDims(TensorSpec tensor, TileInfo tile) {

}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order
* The newly added components will be ones.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static long[] arrayToWantedAxesOrderAddOnes(long[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
long[] finalSize = new long[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1) {
finalSize[i] = 1;
} else {
finalSize[i] = size[ind];
}
}
return finalSize;
}

/**
* Convert the array following given axes order into
* another float[] which follows the target axes order
* The newly added components will be ones.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static float[] arrayToWantedAxesOrderAddOnes(float[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
float[] finalSize = new float[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1) {
finalSize[i] = 1;
} else {
finalSize[i] = size[ind];
}
}
return finalSize;
}

/**
* Convert the array following given axes order into
* another float[] which follows the target axes order.
* The newly added components will be zeros.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static float[] arrayToWantedAxesOrderAddZeros(float[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
float[] finalSize = new float[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1)
continue;
finalSize[i] = size[ind];
}
return finalSize;
}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order.
* The newly added components will be zeros.
* @param size
* original array following the original axes order
* @param orginalAxes
* axes order of the original array
* @param targetAxes
* axes order of the target array
* @return a size array in the order of the tensor of interest
*/
public static int[] arrayToWantedAxesOrderAddZeros(int[] size, String orginalAxes, String targetAxes) {
orginalAxes = orginalAxes.toLowerCase();
String[] axesArr = targetAxes.toLowerCase().split("");
int[] finalSize = new int[targetAxes.length()];
for (int i = 0; i < finalSize.length; i ++) {
int ind = orginalAxes.indexOf(axesArr[i]);
if (ind == -1)
continue;
finalSize[i] = size[ind];
}
return finalSize;
}



Expand Down

0 comments on commit e80e898

Please sign in to comment.