Skip to content

Commit

Permalink
correct many bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 23, 2024
1 parent 3e4852f commit d58b82e
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.tiling.TileInfo;
import io.bioimage.modelrunner.utils.Constants;

public class TileCalculator {

Expand All @@ -28,6 +29,7 @@ public static TileCalculator init(ModelDescriptor descriptor) {
return new TileCalculator(descriptor);
}

// TODO what to do when the axes order do not coincide
private long[] getOptimalTileSize(TensorSpec tensor, String inputAxesOrder, long[] dims) {
boolean tiling = this.descriptor.isTilingAllowed();
int[] halo = tensor.getAxesInfo().getHaloArr();
Expand Down Expand Up @@ -75,7 +77,7 @@ public List<TileInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {

long[] tileSize = getOptimalTileSize(tt, im.getAxesOrder(), im.getDimensions());

firstIterationInputs.add(TileInfo.build(im.getTensorName(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
firstIterationInputs.add(TileInfo.build(tt.getTensorID(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
}

if (!tiling)
Expand All @@ -95,11 +97,11 @@ public List<TileInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
secondIterationInputs.add(firstIterationInputs.get(i));
}
}
if (firstIterationInputs.size() == firstIterationInputs.size())
if (firstIterationInputs.size() == secondIterationInputs.size())
return secondIterationInputs;


List<Long> outputTotByteSizes = calculateByteSizeOfAffectedOutput(affectedTensors, null, null);
List<Long> outputTotByteSizes = calculateByteSizeOfAffectedOutput(affectedTensors, firstIterationInputs);
return checkOutputSize(firstIterationInputs, affectedTensors, outputTotByteSizes);
}

Expand Down Expand Up @@ -154,7 +156,7 @@ private List<TileInfo> checkOutputSize(List<TileInfo> inputs, List<TensorSpec> a
}
}

outByteSizes = calculateByteSizeOfAffectedOutput(affected, null, null);
outByteSizes = calculateByteSizeOfAffectedOutput(affected, null);
outRatio = outByteSizes.stream().map(ss -> (double) Integer.MAX_VALUE / (double) ss).collect(Collectors.toList());

if (Collections.min(outRatio) < 1 && Collections.min(inRatio) < 1 )
Expand Down Expand Up @@ -207,55 +209,75 @@ public void getForNTiles(int nTiles, String tensorName, long[] dims, String inpu

}

private List<Long> calculateByteSizeOfAffectedOutput(List<TensorSpec> inputTensors, List<long[]> inputSize, List<String> affectedOutputs) {
if (affectedOutputs == null || affectedOutputs.size() == 0)
return LongStream.range(0, affectedOutputs.size()).map(i -> 0L).boxed().collect(Collectors.toList());
List<String> names = inputTensors.stream().map(t -> t.getTensorID()).collect(Collectors.toList());
List<String> axesOrders = inputTensors.stream().map(t -> t.getAxesOrder()).collect(Collectors.toList());
List<TensorSpec> outputTensors = this.descriptor.getOutputTensors();
outputTensors = outputTensors.stream()
.filter(t -> {
return t.getAxesInfo().getAxesList().stream()
.filter(tt -> names.contains(tt.getReferenceTensor())).findFirst().orElse(null) != null;
}).collect(Collectors.toList());
private List<Long> calculateByteSizeOfAffectedOutput(List<TensorSpec> outputTensors, List<TileInfo> inputSize) {
if (outputTensors == null || outputTensors.size() == 0)
return new ArrayList<Long>();

List<long[]> outTiles = outputTensors.stream()
.map(t -> new long[t.getAxesInfo().getAxesList().size()]).collect(Collectors.toList());

for (int i = 0; i < outputTensors.size(); i ++) {
TensorSpec tt = outputTensors.get(i);
ArrayList<String> referencesList = new ArrayList<String>();
for (int j = 0; j < outputTensors.get(i).getAxesInfo().getAxesList().size(); j ++) {
Axis ax = tt.getAxesInfo().getAxesList().get(j);
if (ax.getStep() == 0) {
String refName = ax.getReferenceTensor();
if (refName == null && ax.getMin()!= 0) {
outTiles.get(i)[j] = ax.getMin();
continue;
} else if (refName == null) {
outTiles.get(i)[j] = -1;
continue;
}
String refName = ax.getReferenceTensor();
referencesList.add(refName);
String refAxisStr = ax.getReferenceAxis();
TensorSpec refTensor = inputTensors.get(names.indexOf(refName));
long[] refTileSize = inputSize.get(names.indexOf(refName));
String axesOrder = axesOrders.get(names.indexOf(refName));
Axis refAxis = refTensor.getAxesInfo().getAxis(refAxisStr);
TensorSpec refTensor = descriptor.findInputTensor(refName);
long[] refTileSize = inputSize.stream()
.filter(tile -> tile.getName().equals(refName)).findFirst().orElse(null).getTileDims();
String axesOrder = refTensor.getAxesOrder();
outTiles.get(i)[j] =
(long) (refTileSize[axesOrder.indexOf(refAxisStr)] * refAxis.getScale() + refAxis.getOffset());
(long) (refTileSize[axesOrder.indexOf(refAxisStr)] * ax.getScale() + ax.getOffset());
}
if (referencesList.stream().distinct().count() != 1)
throw new IllegalArgumentException(""
+ "Model specs too complex for JDLL. "
+ "Please contact the team and create and issue attaching the rdf.yaml file"
+ " so we can troubleshoot at: " + Constants.ISSUES_LINK);
else {
for (int j = 0; j < outputTensors.get(i).getAxesInfo().getAxesList().size(); j ++) {
if (outTiles.get(i)[j] != -1)
continue;
TensorSpec refInput = this.descriptor.findInputTensor(referencesList.get(0));
int ind = refInput.getAxesOrder().indexOf(outputTensors.get(i).getAxesInfo().getAxesList().get(j).getAxis());
if (ind == -1)
throw new IllegalArgumentException(""
+ "Model specs too complex for JDLL. "
+ "Please contact the team and create and issue attaching the rdf.yaml file"
+ " so we can troubleshoot at: " + Constants.ISSUES_LINK);
long[] refTileSize = inputSize.stream()
.filter(tile -> tile.getName().equals(referencesList.get(0))).findFirst().orElse(null).getTileDims();
outTiles.get(i)[j] = refTileSize[ind];
}
}
}

List<Long> flatSizes = LongStream.range(0, outTiles.size()).map(i -> 1L).boxed().collect(Collectors.toList());
List<Long> flatSizes = outTiles.stream().map(arr -> {
long a = 1L;
for (long l : arr) a *= l;
return a;
}).collect(Collectors.toList());

for (int i = 0; i < flatSizes.size(); i ++) {
if (outputTensors.get(i).getDataType().toLowerCase().equals("float32")
|| outputTensors.get(i).getDataType().toLowerCase().equals("int32")
|| outputTensors.get(i).getDataType().toLowerCase().equals("uint32"))
flatSizes.set(i, flatSizes.get(i) * 8);
flatSizes.set(i, flatSizes.get(i) * 4);
else if (outputTensors.get(i).getDataType().toLowerCase().equals("int16")
|| outputTensors.get(i).getDataType().toLowerCase().equals("uint16"))
flatSizes.set(i, flatSizes.get(i) * 2);
else if (outputTensors.get(i).getDataType().toLowerCase().equals("int64")
|| outputTensors.get(i).getDataType().toLowerCase().equals("float64"))
flatSizes.set(i, flatSizes.get(i) * 4);
for (long j : outTiles.get(i))
flatSizes.set(i, flatSizes.get(i) * j);
flatSizes.set(i, flatSizes.get(i) * 8);
}
return flatSizes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ public Axes getAxesInfo() {

public String getDataType() {
// TODO
return this.data.toString();
//return this.data.toString();
return "float32";
}

@Override
Expand Down
27 changes: 19 additions & 8 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tiling.PatchSpec;
import io.bioimage.modelrunner.tiling.TileGrid;
import io.bioimage.modelrunner.tiling.TileInfo;
import io.bioimage.modelrunner.tiling.TileMaker;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
Expand Down Expand Up @@ -660,7 +663,12 @@ List<Tensor<T>> runBioimageioModelOnImgLib2WithTiling(List<Tensor<R>> inputTenso
if (descriptor == null && !(new File(modelFolder, Constants.RDF_FNAME).isFile()))
throw new IllegalArgumentException("Automatic tiling can only be done if the model contains a Bioiamge.io rdf.yaml specs file.");
else if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
try {
descriptor = ModelDescriptorFactory.readFromLocalFile(modelFolder + File.separator + Constants.RDF_FNAME);
} catch (ModelSpecsException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}

return runTiling(inputTensors, tiles, tileCounter);
}
Expand All @@ -678,15 +686,17 @@ List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors, TileMaker tiles, TilingC
}

for (int i = 0; i < tiles.getNumberOfTiles(); i ++) {
int nTile = 0 + i;
List<Tensor<R>> inputTiles = inputTensors.stream()
.map(tt -> tiles.getNthTileInput(tt.getName(), i, tt)).collect(Collectors.toList());
.map(tt -> tiles.getNthTileInput(tt, nTile)).collect(Collectors.toList());
List<Tensor<T>> outputTiles = outputTensors.stream()
.map(tt -> tiles.getNthTileOutput(tt.getName(), i, tt)).collect(Collectors.toList());
.map(tt -> tiles.getNthTileOutput(tt, nTile)).collect(Collectors.toList());
runModel(inputTiles, outputTiles);
}
return outputTensors;
}

/** TODO remove
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors,
TileMaker tiles, TilingConsumer tileCounter) throws RunModelException {
Expand Down Expand Up @@ -720,20 +730,21 @@ void doTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors,
this.runModel(inputTileList, outputTileList);
}
}
*/

public static <T extends NativeType<T> & RealType<T>> void main(String[] args) throws IOException, ModelSpecsException, LoadEngineException, RunModelException, LoadModelException {
/*
String mm = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\\\EnhancerMitochondriaEM2D_22092023_133921\\";

String mm = "/home/carlos/git/JDLL/models/NucleiSegmentationBoundaryModel_17122023_143125";
Img<T> im = (Img<T>) ArrayImgs.floats(new long[] {1, 1, 512, 512});
List<Tensor<T>> l = new ArrayList<Tensor<T>>();
l.add((Tensor<T>) Tensor.build("input0", "bcyx", im));
Model model = createBioimageioModel(mm);
model.loadModel();
Map<String, int[]> tilingList = new LinkedHashMap<String, int[]>();
tilingList.put("input0", new int[] {1, 1, 256, 256});
List<Tensor<T>> out = model.runBioimageioModelOnImgLib2WithTiling(l, tilingList);
tilingList.put("input0", new int[] {1, 2, 256, 256});
List<Tensor<T>> out = model.runBioimageioModelOnImgLib2WithTiling(l);
System.out.println(false);
*/

}

/**
Expand Down

0 comments on commit d58b82e

Please sign in to comment.