Skip to content

Commit

Permalink
keep going
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 23, 2024
1 parent b6b4594 commit ca8e286
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;
import io.bioimage.modelrunner.tiling.TileInfo;

public class TileCalculator {

Expand Down Expand Up @@ -66,9 +64,9 @@ private long[] getOptimalTileSize(TensorSpec tensor, String inputAxesOrder, long
return patch;
}

public List<ImageInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
public List<TileInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
boolean tiling = this.descriptor.isTilingAllowed();
List<ImageInfo> firstIterationInputs = new ArrayList<ImageInfo>();
List<TileInfo> firstIterationInputs = new ArrayList<TileInfo>();
for (TensorSpec tt : this.descriptor.getInputTensors()) {
ImageInfo im = inputInfo.stream()
.filter(ii -> ii.getTensorName().equals(tt.getTensorID())).findFirst().orElse(null);
Expand All @@ -77,7 +75,7 @@ public List<ImageInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {

long[] tileSize = getOptimalTileSize(tt, im.getAxesOrder(), im.getDimensions());

firstIterationInputs.add(new ImageInfo(im.getTensorName(), im.getAxesOrder(), tileSize));
firstIterationInputs.add(TileInfo.build(im.getTensorName(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
}

if (!tiling)
Expand All @@ -90,9 +88,9 @@ public List<ImageInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
.findFirst().orElse(null) != null;
}).collect(Collectors.toList());

List<ImageInfo> secondIterationInputs = new ArrayList<ImageInfo>();
List<TileInfo> secondIterationInputs = new ArrayList<TileInfo>();
for (int i = 0; i < firstIterationInputs.size(); i ++) {
TensorSpec tensor = descriptor.findInputTensor(firstIterationInputs.get(i).getTensorName());
TensorSpec tensor = descriptor.findInputTensor(firstIterationInputs.get(i).getName());
if (Arrays.stream(tensor.getTileStepArr()).allMatch(ii -> ii == 0)) {
secondIterationInputs.add(firstIterationInputs.get(i));
}
Expand All @@ -106,8 +104,8 @@ public List<ImageInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
}


private List<ImageInfo> checkOutputSize(List<ImageInfo> inputs, List<TensorSpec> affected, List<Long> outByteSizes) {
List<Long> totInPixels = inputs.stream().map(in -> Arrays.stream(in.getDimensions()).reduce(1, (x, y) -> x * y)).collect(Collectors.toList());
private List<TileInfo> checkOutputSize(List<TileInfo> inputs, List<TensorSpec> affected, List<Long> outByteSizes) {
List<Long> totInPixels = inputs.stream().map(in -> Arrays.stream(in.getProposedTileDimensions()).reduce(1, (x, y) -> x * y)).collect(Collectors.toList());


if (totInPixels.stream().filter(oo -> oo > OPTIMAL_MAX_NUMBER_PIXELS).findFirst().orElse(null) == null
Expand All @@ -125,7 +123,7 @@ private List<ImageInfo> checkOutputSize(List<ImageInfo> inputs, List<TensorSpec>
.sorted(Comparator.comparing(inRatio::get))
.collect(Collectors.toList());
for (Integer ind : sortedIndices) {
TensorSpec tt = this.descriptor.findInputTensor(inputs.get(ind).getTensorName());
TensorSpec tt = this.descriptor.findInputTensor(inputs.get(ind).getName());
if (Arrays.stream(tt.getTileStepArr()).allMatch(ii -> ii == 0))
continue;
argmin = ind;
Expand All @@ -134,22 +132,21 @@ private List<ImageInfo> checkOutputSize(List<ImageInfo> inputs, List<TensorSpec>
if (argmin == null) break;

Double startingRatio = inRatio.get(argmin);
ImageInfo in = inputs.get(argmin);
long[] dims = in.getDimensions();
TensorSpec tt = this.descriptor.findInputTensor(in.getTensorName());
TileInfo in = inputs.get(argmin);
TensorSpec tt = this.descriptor.findInputTensor(in.getName());
int c = 0;
for (String ax : in.getAxesOrder().split("")) {
for (String ax : in.getTileAxesOrder().split("")) {
Axis axis = tt.getAxesInfo().getAxis(ax);
if (axis.getStep() == 0) continue;
long nTot = totInPixels.get(argmin) / in.getDimensions()[c];
if ((in.getDimensions()[c] * inRatio.get(argmin) < axis.getMin()) && (axis.getMin() > 1)) {
in.getDimensions()[c] = (int)Math.ceil((double) 100 / (double) axis.getStep()) * axis.getStep();
} else if (in.getDimensions()[c] * inRatio.get(argmin) < axis.getMin()) {
in.getDimensions()[c] = axis.getMin();
long nTot = totInPixels.get(argmin) / in.getProposedTileDimensions()[c];
if ((in.getProposedTileDimensions()[c] * inRatio.get(argmin) < axis.getMin()) && (axis.getMin() > 1)) {
in.getProposedTileDimensions()[c] = (int)Math.ceil((double) 100 / (double) axis.getStep()) * axis.getStep();
} else if (in.getProposedTileDimensions()[c] * inRatio.get(argmin) < axis.getMin()) {
in.getProposedTileDimensions()[c] = axis.getMin();
} else {
in.getDimensions()[c] = (long) (Math.floor((in.getDimensions()[c] * inRatio.get(argmin) - axis.getMin()) / axis.getStep()) * axis.getStep() + axis.getMin());
in.getProposedTileDimensions()[c] = (long) (Math.floor((in.getProposedTileDimensions()[c] * inRatio.get(argmin) - axis.getMin()) / axis.getStep()) * axis.getStep() + axis.getMin());
}
totInPixels.set(argmin, nTot * in.getDimensions()[c]);
totInPixels.set(argmin, nTot * in.getProposedTileDimensions()[c]);
inRatio = totInPixels.stream().map(ss -> (double) OPTIMAL_MAX_NUMBER_PIXELS / (double) ss).collect(Collectors.toList());

if (startingRatio == inRatio.get(argmin))
Expand All @@ -175,20 +172,20 @@ private List<ImageInfo> checkOutputSize(List<ImageInfo> inputs, List<TensorSpec>
if (ax.getReferenceTensor() == null)
continue;
TensorSpec inputT = this.descriptor.findInputTensor(ax.getReferenceTensor());
ImageInfo im = inputs.stream().filter(in -> in.getTensorName().equals(inputT.getTensorID())).findFirst().orElse(null);
TileInfo im = inputs.stream().filter(in -> in.getName().equals(inputT.getTensorID())).findFirst().orElse(null);
String refAxis = ax.getReferenceAxis();
int index = im.getAxesOrder().indexOf(refAxis);
int index = im.getTileAxesOrder().indexOf(refAxis);
Axis inAx = inputT.getAxesInfo().getAxis(refAxis);
long size = im.getDimensions()[index];
long size = im.getProposedTileDimensions()[index];

if ((size * outRatio.get(argmin) < inAx.getMin()) && (inAx.getMin() > 1)) {
im.getDimensions()[index] = (int)Math.ceil((double) 100 / (double) inAx.getStep()) * inAx.getStep();
im.getProposedTileDimensions()[index] = (int)Math.ceil((double) 100 / (double) inAx.getStep()) * inAx.getStep();
} else if (size * outRatio.get(argmin) < inAx.getMin()) {
im.getDimensions()[index] = inAx.getMin();
im.getProposedTileDimensions()[index] = inAx.getMin();
} else {
im.getDimensions()[index] = (long) (Math.floor((size * outRatio.get(argmin) - inAx.getMin()) / inAx.getStep()) * inAx.getStep() + inAx.getMin());
im.getProposedTileDimensions()[index] = (long) (Math.floor((size * outRatio.get(argmin) - inAx.getMin()) / inAx.getStep()) * inAx.getStep() + inAx.getMin());
}
double change = (size * ax.getScale() + 2 * ax.getOffset()) / (im.getDimensions()[index] * ax.getScale() + 2 * ax.getOffset());
double change = (size * ax.getScale() + 2 * ax.getOffset()) / (im.getProposedTileDimensions()[index] * ax.getScale() + 2 * ax.getOffset());
outRatio.set(argmin, outRatio.get(argmin) * change);
if (outRatio.get(argmin) > 1)
break;
Expand Down
12 changes: 9 additions & 3 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import io.bioimage.modelrunner.bioimageio.ImageInfo;
import io.bioimage.modelrunner.bioimageio.TileCalculator;
import io.bioimage.modelrunner.bioimageio.bioengine.BioEngineAvailableModels;
import io.bioimage.modelrunner.bioimageio.bioengine.BioengineInterface;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
Expand All @@ -51,7 +53,6 @@
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;
import io.bioimage.modelrunner.tiling.PatchSpec;
import io.bioimage.modelrunner.tiling.TileGrid;
import io.bioimage.modelrunner.tiling.TileMaker;
Expand Down Expand Up @@ -593,8 +594,13 @@ List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTenso
throw new IllegalArgumentException("Automatic tiling can only be done if the model contains a Bioiamge.io rdf.yaml specs file.");
else if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
PatchGridCalculator<R> tileGrid = PatchGridCalculator.build(descriptor, inputTensors);
return runTiling(inputTensors, tileGrid, tileCounter);
TileCalculator calc = TileCalculator.init(descriptor);
List<ImageInfo> imageInfos = inputTensors.stream()
.map(tt -> new ImageInfo(tt.getName(), tt.getAxesOrderString(), tt.getData().dimensionsAsLongArray()))
.collect(Collectors.toList());
List<TileInfo> inputTiles = calc.getOptimalTileSize(imageInfos);
TileMaker tiles = TileMaker.build(descriptor, inputTiles);
return runTiling(inputTensors, tiles, tileCounter);
}

/**
Expand Down
57 changes: 36 additions & 21 deletions src/main/java/io/bioimage/modelrunner/tiling/TileMaker.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.Views;

public class TileMaker {

Expand Down Expand Up @@ -257,7 +259,7 @@ private static void checkAxisSize(TileInfo tile) {
}

private void checkTilesCombine() {

// TODO
}

private void checkAllTensorsDefined() {
Expand Down Expand Up @@ -341,7 +343,7 @@ private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile)
}

public int getNumberOfTiles() {
return 0;
return inputGrid.get(this.descriptor.getInputTensors().get(0).getTensorID()).getRoiPostionsInImage().size();
}

public Map<String, Integer> getTilesPerAxis() {
Expand Down Expand Up @@ -450,40 +452,53 @@ public List<long[]> getTilePostionsOutputImage(String tensorID) {
return outputGrid.get(tensorID).getTilePostionsInImage();
}

public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileInput(String tensorID, int n, RandomAccessibleInterval<T> rai) {
public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileInput(String tensorID, RandomAccessibleInterval<T> rai, int n) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorID);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
throw new IllegalArgumentException("There are only " + tiles.size() + " tiles. Tile " + n
+ " is out of bounds.");
}
return null;
long[] minLim = tiles.get(n);
long[] size = this.getInputTileSize(tensorID);
long[] maxLim = new long[size.length];
for (int i = 0; i < size.length; i ++) maxLim[i] = minLim[i] + maxLim[i] - 1;
RandomAccessibleInterval<T> tileRai = Views.interval(
Views.extendMirrorDouble(rai), new FinalInterval( minLim, maxLim ));
return tileRai;
}

public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileOutput(String tensorID, int n, RandomAccessibleInterval<T> rai) {
public <T extends NativeType<T> & RealType<T>> RandomAccessibleInterval<T> getNthTileOutput(String tensorID, RandomAccessibleInterval<T> rai, int n) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorID);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
throw new IllegalArgumentException("There are only " + tiles.size() + " tiles. Tile " + n
+ " is out of bounds.");
}
return null;
long[] minLim = tiles.get(n);
long[] size = this.getOutputTileSize(tensorID);
long[] maxLim = new long[size.length];
for (int i = 0; i < size.length; i ++) maxLim[i] = minLim[i] + maxLim[i] - 1;
RandomAccessibleInterval<T> tileRai = Views.interval(
Views.extendMirrorDouble(rai), new FinalInterval( minLim, maxLim ));
return tileRai;
}

public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileInput(String tensorID, int n, Tensor<T> tensor) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorID);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
}
return null;
public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileInput(Tensor<T> tensor, int n) {
RandomAccessibleInterval<T> rai = getNthTileInput(tensor.getName(), tensor.getData(), n);
return Tensor.build(tensor.getName(), tensor.getAxesOrderString(), rai);
}

public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileOutput(String tensorID, int n, Tensor<T> tensor) {
List<long[]> tiles = this.getTilePostionsOutputImage(tensorID);
if (tiles.size() >= n) {
throw new IllegalArgumentException();
}
return null;
public <T extends NativeType<T> & RealType<T>> Tensor<T> getNthTileOutput(Tensor<T> tensor, int n) {
RandomAccessibleInterval<T> rai = getNthTileOutput(tensor.getName(), tensor.getData(), n);
return Tensor.build(tensor.getName(), tensor.getAxesOrderString(), rai);
}

public long[] getOutputImageSize(String tensorID) {
return null;
TileInfo tile = this.outputTileInfo.stream()
.filter(tt -> tt.getName().equals(tensorID)).findFirst().orElse(null);
if (tile == null)
throw new IllegalArgumentException("The tensor ID proposed does not correspond to an output tensor: "
+ "'" + tensorID + "'.");
return tile.getImageDimensions();
}

/**
Expand Down

0 comments on commit ca8e286

Please sign in to comment.