Skip to content

Commit

Permalink
keep improving the tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 23, 2024
1 parent 60e8945 commit c32a330
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;

public class TileFactory {
public class TileCalculator {

private final ModelDescriptor descriptor;

private static final long OPTIMAL_MAX_NUMBER_PIXELS = 4096 * 4096 * 3;

private TileFactory(ModelDescriptor descriptor) {
private TileCalculator(ModelDescriptor descriptor) {
this.descriptor = descriptor;
}

public static TileFactory init(ModelDescriptor descriptor) {
return new TileFactory(descriptor);
public static TileCalculator init(ModelDescriptor descriptor) {
return new TileCalculator(descriptor);
}

private long[] getOptimalTileSize(TensorSpec tensor, String inputAxesOrder, long[] dims) {
Expand Down Expand Up @@ -202,11 +202,11 @@ private List<ImageInfo> checkOutputSize(List<ImageInfo> inputs, List<TensorSpec>
return null;
}

public void validateTileSize(String tensorName, long[] dims, String inputAxesOrder) {
public void getTilesForNPixels(String tensorName, long[] dims, String inputAxesOrder) {

}

public void getTileSizeForNTiles(int nTiles, String tensorName, long[] dims, String inputAxesOrder) {
public void getForNTiles(int nTiles, String tensorName, long[] dims, String inputAxesOrder) {

}

Expand Down
41 changes: 8 additions & 33 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import io.bioimage.modelrunner.tiling.PatchGridCalculator;
import io.bioimage.modelrunner.tiling.PatchSpec;
import io.bioimage.modelrunner.tiling.TileGrid;
import io.bioimage.modelrunner.tiling.TileMaker;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.FinalInterval;
Expand Down Expand Up @@ -616,8 +617,8 @@ else if (descriptor == null)
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors,
Map<String, int[]> tileMap) throws ModelSpecsException, RunModelException {
return runBioimageioModelOnImgLib2WithTiling(inputTensors, tileMap, null);
TileMaker tiles) throws ModelSpecsException, RunModelException {
return runBioimageioModelOnImgLib2WithTiling(inputTensors, tiles, null);
}

/**
Expand All @@ -644,39 +645,21 @@ List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTenso
*/
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors,
Map<String, int[]> tileMap, TilingConsumer tileCounter) throws ModelSpecsException, RunModelException {
TileMaker tiles, TilingConsumer tileCounter) throws ModelSpecsException, RunModelException {

if (!this.isLoaded())
throw new RunModelException("Please first load the model.");
if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
throw new IllegalArgumentException("Automatic tiling can only be done if the model contains a Bioiamge.io rdf.yaml specs file.");
else if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
for (TensorSpec t : descriptor.getInputTensors()) {
if (!t.isImage())
continue;
if (tileMap.get(t.getTensorID()) == null)
throw new RunModelException("Either provide the wanted tile size for every image tensor ("
+ "in this case tenso '" + t.getTensorID() + "' is missing) or let JDLL compute all "
+ "tile sizes automatically using runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTensors).");
if (Tensor.getTensorByNameFromList(inputTensors, t.getTensorID()) == null)
throw new RunModelException("Required tensor named '" + t.getTensorID() + "' is missing from the list of input tensors");
try {
t.setTileSizeForTensorAndImageSize(tileMap.get(t.getTensorID()), Tensor.getTensorByNameFromList(inputTensors, t.getTensorID()).getShape());
} catch (Exception e) {
throw new RunModelException(e.getMessage());
}
}
PatchGridCalculator<R> tileGrid = PatchGridCalculator.build(descriptor, inputTensors);

return runTiling(inputTensors, tileGrid, tileCounter);
}

@SuppressWarnings("unchecked")
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors,
PatchGridCalculator<R> tileGrid, TilingConsumer tileCounter) throws RunModelException {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors, TileMaker tiles, TilingConsumer tileCounter) throws RunModelException {
List<Tensor<T>> outputTensors = new ArrayList<Tensor<T>>();
for (TensorSpec tt : descriptor.getOutputTensors()) {
if (outTileSpecs.get(tt.getTensorID()) == null)
Expand All @@ -693,16 +676,8 @@ List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors,

private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors,
PatchGridCalculator<R> tileGrid, TilingConsumer tileCounter) throws RunModelException {
LinkedHashMap<String, PatchSpec> inTileSpecs = tileGrid.getInputTensorsTileSpecs();
LinkedHashMap<String, PatchSpec> outTileSpecs = tileGrid.getOutputTensorsTileSpecs();
Map<Object, TileGrid> inTileGrids = inTileSpecs.entrySet().stream()
.collect(Collectors.toMap(entry -> entry.getKey(), entry -> TileGrid.create(entry.getValue())));
Map<Object, TileGrid> outTileGrids = outTileSpecs.entrySet().stream()
.collect(Collectors.toMap(entry -> entry.getKey(), entry -> TileGrid.create(entry.getValue())));
int[] tilesPerAxis = inTileSpecs.values().stream().findFirst().get().getTileGrid();
int nTiles = 1;
for (int i : tilesPerAxis) nTiles *= i;
TileMaker tiles, TilingConsumer tileCounter) throws RunModelException {
int nTiles = tiles.getNumberOfTiles();
tileCounter.acceptTotal(Long.valueOf(nTiles));
for (int j = 0; j < nTiles; j ++) {
tileCounter.acceptProgress(Long.valueOf(j));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import io.bioimage.modelrunner.bioimageio.TileFactory;
import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import net.imglib2.RandomAccessibleInterval;

public class TileCalculator {
public class TileMaker {

private final List<TileInfo> inputTileInfo;

Expand All @@ -31,15 +29,15 @@ public class TileCalculator {

private final LinkedHashMap<String, TileGrid> outputGrid = new LinkedHashMap<String, TileGrid>();

private TileCalculator(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
private TileMaker(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
this.descriptor = descriptor;
this.inputTileInfo = tileInfoList;
validate();
calculate();
}

public static TileCalculator build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
return new TileCalculator(descriptor, tileInfoList);
public static TileMaker build(ModelDescriptor descriptor, List<TileInfo> tileInfoList) {
return new TileMaker(descriptor, tileInfoList);
}

private void validate() {
Expand Down Expand Up @@ -337,10 +335,14 @@ private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile)

return PatchSpec.create(spec.getTensorID(), tileSize, patchGridSize, paddingSize, imSize);
}

public void getTileList() {

}

public int getNumberOfTiles() {
return 0;
}

public Map<String, Integer> getTilesPerAxis() {
return null;
}

public void getInputInsertionPoints(String tensorId, int nTile, String axes) {
TileInfo tile = this.inputTileInfo.stream().filter(t -> t.getName().equals(tensorId)).findFirst().orElse(null);
Expand Down Expand Up @@ -444,6 +446,22 @@ public List<long[]> getTilePostionsOutputImage(String tensorId) {
return outputGrid.get(tensorId).getTilePostionsInImage();
}

public long[] getNthTileInput(String tensorId, int n) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorId);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
}
return tiles.get(n);
}

public long[] getNthTileOutput(String tensorId, int n) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorId);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
}
return tiles.get(n);
}

/**
* Convert the array following given axes order into
* another int[] which follows the target axes order
Expand Down

0 comments on commit c32a330

Please sign in to comment.