diff --git a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpec.java b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpec.java index 1bcd0efb..ba96d685 100644 --- a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpec.java +++ b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpec.java @@ -57,4 +57,6 @@ public interface TensorSpec { public Axes getAxesInfo(); public String getDataType(); + + public boolean isImage(); } diff --git a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV04.java b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV04.java index 03c1ebfd..1f04f20c 100644 --- a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV04.java +++ b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV04.java @@ -177,4 +177,12 @@ public String getDataType() { public int[] getHaloArr() { return this.axes.getHaloArr(); } + + @Override + public boolean isImage() { + if (axes.getAxesOrder().contains("x") && axes.getAxesOrder().contains("y")) + return true; + else + return false; + } } diff --git a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV05.java b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV05.java index b7e2058e..9df56279 100644 --- a/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV05.java +++ b/src/main/java/io/bioimage/modelrunner/bioimageio/description/TensorSpecV05.java @@ -180,4 +180,12 @@ public String getDataType() { public int[] getHaloArr() { return this.axes.getHaloArr(); } + + @Override + public boolean isImage() { + if (axes.getAxesOrder().contains("x") && axes.getAxesOrder().contains("y")) + return true; + else + return false; + } } diff --git a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java index 9666acd6..c9303f43 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/PatchGridCalculator.java @@ -186,11 +186,11 @@ public void checkPatchSpecs(LinkedHashMap patchSpecs) throws for (Entry spec : patchSpecs.entrySet()) { int[] nGrid = spec.getValue().getTileGrid(); TensorSpec tt = this.descriptor.findInputTensor(spec.getKey()); - if (grid == null && tt.getTiling()) { + if (grid == null && this.descriptor.isTilingAllowed()) { grid = nGrid; firstName = spec.getKey(); } - if (tt.getTiling() && !compareTwoArrays(nGrid, grid)){ + if (this.descriptor.isTilingAllowed() && !compareTwoArrays(nGrid, grid)){ throw new IllegalArgumentException("All the input images must be processed with the same number of patches.\n" + "The relationship between the patch size and image size should be the same for every input that allows patching/tiling.\n" + "Tensors '" + firstName + "' and '" + spec.getKey() + "' need different number of patches to " @@ -247,7 +247,7 @@ private List findOutputImageTensorSpec() private LinkedHashMap computePatchSpecsForEveryTensor(List tensors, List> images){ LinkedHashMap patchInfoList = new LinkedHashMap(); for (int i = 0; i < tensors.size(); i ++) - patchInfoList.put(tensors.get(i).getName(), computePatchSpecs(tensors.get(i), images.get(i))); + patchInfoList.put(tensors.get(i).getTensorID(), computePatchSpecs(tensors.get(i), images.get(i))); return patchInfoList; } @@ -272,7 +272,7 @@ public LinkedHashMap getOutputTensorsTileSpecs() throws Illeg for (int i = 0; i < outTensors.size(); i ++) { String refTensor = outTensors.get(i).getShape().getReferenceInput(); PatchSpec refSpec = refTensor == null ? inputTilesSpecs.values().stream().findFirst().get() : inputTilesSpecs.get(refTensor); - patchInfoList.put(outTensors.get(i).getName(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec)); + patchInfoList.put(outTensors.get(i).getTensorID(), computePatchSpecsForOutputTensor(outTensors.get(i), refSpec)); } outputTilesSpecs = patchInfoList; return outputTilesSpecs; @@ -331,8 +331,8 @@ private PatchSpec computePatchSpecs(TensorSpec spec, RandomAccessibleInterval int[][] paddingSize = new int[2][tileSize.length]; // REgard that the input halo represents the output halo + offset // and must be divisible by 0.5. - float[] halo = spec.getHalo(); - if (!descriptor.isPyramidal() && spec.getTiling()) { + int[] halo = spec.getHaloArr(); + if (!descriptor.isPyramidal() && this.descriptor.isTilingAllowed()) { // In the case that padding is asymmetrical, the left upper padding has the extra pixel for (int i = 0; i < halo.length; i ++) {paddingSize[0][i] = (int) Math.ceil(halo[i]);} // In the case that padding is asymmetrical, the right bottom padding has one pixel less @@ -373,7 +373,7 @@ private PatchSpec computePatchSpecsForOutputTensor(TensorSpec tensorSpec, PatchS paddingSize[1] = arrayToWantedAxesOrderAddZeros(paddingSize[1], ogAxes, tensorSpec.getAxesOrder()); long[] tileSize; long[] shapeLong; - if (tensorSpec.getShape().getReferenceInput() == null && !tensorSpec.getTiling()) { + if (tensorSpec.getShape().getReferenceInput() == null && !this.descriptor.isTilingAllowed()) { shapeLong = Arrays.stream(tensorSpec.getTileSize()).mapToLong(i -> i).toArray(); tileSize = shapeLong; } else if (tensorSpec.getShape().getReferenceInput() == null) { diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java new file mode 100644 index 00000000..093cb270 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileCalculator.java @@ -0,0 +1,18 @@ +package io.bioimage.modelrunner.tiling; + +import java.util.List; + +import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; + +public class TileCalculator { + + private final List tileInfoList; + + private final ModelDescriptor descriptor; + + private TileCalculator(ModelDescriptor descriptor, List tileInfoList) { + this.descriptor = descriptor; + this.tileInfoList = tileInfoList; + } + +} diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java index 524184fe..929dd128 100644 --- a/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileGrid.java @@ -160,22 +160,4 @@ public List getRoiPostionsInImage() { return this.roiPositionsInImage; } - @Override - public String toString() - { - return ""; - /* - String[] paddingStrArr = new String[patchPaddingSize[0].length]; - for (int i = 0; i < paddingStrArr.length; i ++) - paddingStrArr[i] = patchPaddingSize[0][i] + "," + patchPaddingSize[1][i]; - StringBuilder builder = new StringBuilder(); - builder.append("PatchSpec of '" + tensorName + "'" - + "[patchInputSize=").append(Arrays.toString(this.tileSize)) - .append(", patchGridSize=").append(Arrays.toString(patchGridSize)) - .append(", patchPaddingSize=").append(Arrays.toString(paddingStrArr)) - .append("]"); - return builder.toString(); - */ - } - } diff --git a/src/main/java/io/bioimage/modelrunner/tiling/TileInfo.java b/src/main/java/io/bioimage/modelrunner/tiling/TileInfo.java new file mode 100644 index 00000000..137009d8 --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/tiling/TileInfo.java @@ -0,0 +1,61 @@ +package io.bioimage.modelrunner.tiling; + +public class TileInfo { + + + private final String name; + + private final long[] imDims; + + private final long[] proposedTileDims; + + private final String imAxesOrder; + + private final String tileAxesOrder; + + private TileInfo(String tensorName, long[] imDims, String imAxesOrder, long[] proposedTileDims, String tileAxesOrder) { + this.name = tensorName; + this.imAxesOrder = imAxesOrder; + this.imDims = imDims; + this.proposedTileDims = proposedTileDims; + this.tileAxesOrder = tileAxesOrder; + } + + /** + * @return the name + */ + public String getName() { + return name; + } + + /** + * @return the imDims + */ + public long[] getImageDimensions() { + return imDims; + } + + /** + * @return the proposedTileDims + */ + public long[] getProposedTileDimensions() { + return proposedTileDims; + } + + /** + * @return the imAxesOrder + */ + public String getImageAxesOrder() { + return imAxesOrder; + } + + /** + * @return the tileAxesOrder + */ + public String getTileAxesOrder() { + return tileAxesOrder; + } + + + +}