Skip to content

Commit

Permalink
start improving the tiling strategy and usage
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 6, 2024
1 parent d6aed6f commit f51cf42
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@ public interface TensorSpec {
public Axes getAxesInfo();

public String getDataType();

public boolean isImage();
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,12 @@ public String getDataType() {
public int[] getHaloArr() {
return this.axes.getHaloArr();
}

@Override
public boolean isImage() {
if (axes.getAxesOrder().contains("x") && axes.getAxesOrder().contains("y"))
return true;
else
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,12 @@ public String getDataType() {
public int[] getHaloArr() {
return this.axes.getHaloArr();
}

@Override
public boolean isImage() {
if (axes.getAxesOrder().contains("x") && axes.getAxesOrder().contains("y"))
return true;
else
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ public void checkPatchSpecs(LinkedHashMap<String, PatchSpec> patchSpecs) throws
for (Entry<String, PatchSpec> spec : patchSpecs.entrySet()) {
int[] nGrid = spec.getValue().getTileGrid();
TensorSpec tt = this.descriptor.findInputTensor(spec.getKey());
if (grid == null && tt.getTiling()) {
if (grid == null && this.descriptor.isTilingAllowed()) {
grid = nGrid;
firstName = spec.getKey();
}
if (tt.getTiling() && !compareTwoArrays(nGrid, grid)){
if (this.descriptor.isTilingAllowed() && !compareTwoArrays(nGrid, grid)){
throw new IllegalArgumentException("All the input images must be processed with the same number of patches.\n"
+ "The relationship between the patch size and image size should be the same for every input that allows patching/tiling.\n"
+ "Tensors '" + firstName + "' and '" + spec.getKey() + "' need different number of patches to "
Expand Down Expand Up @@ -247,7 +247,7 @@ private List<TensorSpec> findOutputImageTensorSpec()
private LinkedHashMap<String, PatchSpec> computePatchSpecsForEveryTensor(List<TensorSpec> tensors, List<Tensor<T>> images){
LinkedHashMap<String, PatchSpec> patchInfoList = new LinkedHashMap<String, PatchSpec>();
for (int i = 0; i < tensors.size(); i ++)
patchInfoList.put(tensors.get(i).getName(), computePatchSpecs(tensors.get(i), images.get(i)));
patchInfoList.put(tensors.get(i).getTensorID(), computePatchSpecs(tensors.get(i), images.get(i)));
return patchInfoList;
}

Expand All @@ -272,7 +272,7 @@ public LinkedHashMap<String, PatchSpec> getOutputTensorsTileSpecs() throws Illeg
for (int i = 0; i < outTensors.size(); i ++) {
String refTensor = outTensors.get(i).getShape().getReferenceInput();
PatchSpec refSpec = refTensor == null ? inputTilesSpecs.values().stream().findFirst().get() : inputTilesSpecs.get(refTensor);
patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec));
patchInfoList.put(outTensors.get(i).getTensorID(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec));
}
outputTilesSpecs = patchInfoList;
return outputTilesSpecs;
Expand Down Expand Up @@ -331,8 +331,8 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval<T>
int[][] paddingSize = new int[2][tileSize.length];
// REgard that the input halo represents the output halo + offset
// and must be divisible by 0.5.
float[] halo = spec.getHalo();
if (!descriptor.isPyramidal() && spec.getTiling()) {
int[] halo = spec.getHaloArr();
if (!descriptor.isPyramidal() && this.descriptor.isTilingAllowed()) {
// In the case that padding is asymmetrical, the left upper padding has the extra pixel
for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);}
// In the case that padding is asymmetrical, the right bottom padding has one pixel less
Expand Down Expand Up @@ -373,7 +373,7 @@ private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchS
paddingSize[1] = arrayToWantedAxesOrderAddZeros(paddingSize[1], ogAxes, tensorSpec.getAxesOrder());
long[] tileSize;
long[] shapeLong;
if (tensorSpec.getShape().getReferenceInput() == null && !tensorSpec.getTiling()) {
if (tensorSpec.getShape().getReferenceInput() == null && !this.descriptor.isTilingAllowed()) {
shapeLong = Arrays.stream(tensorSpec.getTileSize()).mapToLong(i -> i).toArray();
tileSize = shapeLong;
} else if (tensorSpec.getShape().getReferenceInput() == null) {
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.bioimage.modelrunner.tiling;

import java.util.List;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;

public class TileCalculator {

private final List<TileInfo> tileInfoList;

private final ModelDescriptor descriptor;

private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.tileInfoList = tileInfoList;
}

}
18 changes: 0 additions & 18 deletions src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,4 @@ public List<long[]> getRoiPostionsInImage() {
return this.roiPositionsInImage;
}

@Override
public String toString()
{
return "";
/*
String[] paddingStrArr = new String[patchPaddingSize[0].length];
for (int i = 0; i < paddingStrArr.length; i ++)
paddingStrArr[i] = patchPaddingSize[0][i] + "," + patchPaddingSize[1][i];
StringBuilder builder = new StringBuilder();
builder.append("PatchSpec of '" + tensorName + "'"
+ "[patchInputSize=").append(Arrays.toString(this.tileSize))
.append(", patchGridSize=").append(Arrays.toString(patchGridSize))
.append(", patchPaddingSize=").append(Arrays.toString(paddingStrArr))
.append("]");
return builder.toString();
*/
}

}
61 changes: 61 additions & 0 deletions src/main/java/io/bioimage/modelrunner/tiling/TileInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.bioimage.modelrunner.tiling;

public class TileInfo {


private final String name;

private final long[] imDims;

private final long[] proposedTileDims;

private final String imAxesOrder;

private final String tileAxesOrder;

private TileInfo(String tensorName, long[] imDims, String imAxesOrder, long[] proposedTileDims, String tileAxesOrder) {
this.name = tensorName;
this.imAxesOrder = imAxesOrder;
this.imDims = imDims;
this.proposedTileDims = proposedTileDims;
this.tileAxesOrder = tileAxesOrder;
}

/**
* @return the name
*/
public String getName() {
return name;
}

/**
* @return the imDims
*/
public long[] getImageDimensions() {
return imDims;
}

/**
* @return the proposedTileDims
*/
public long[] getProposedTileDimensions() {
return proposedTileDims;
}

/**
* @return the imAxesOrder
*/
public String getImageAxesOrder() {
return imAxesOrder;
}

/**
* @return the tileAxesOrder
*/
public String getTileAxesOrder() {
return tileAxesOrder;
}



}

0 comments on commit f51cf42

Please sign in to comment.