Skip to content

Commit

Permalink
refactor method
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 24, 2024
1 parent f0c566a commit 568aa53
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ public List<TileInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
List<TileInfo> firstIterationInputs = new ArrayList<TileInfo>();
for (TensorSpec tt : this.descriptor.getInputTensors()) {
ImageInfo im = inputInfo.stream()
.filter(ii -> ii.getTensorName().equals(tt.getTensorID())).findFirst().orElse(null);
.filter(ii -> ii.getTensorName().equals(tt.getName())).findFirst().orElse(null);
if (im == null)
throw new IllegalArgumentException("No data was provided for input tensor: " + tt.getTensorID());
throw new IllegalArgumentException("No data was provided for input tensor: " + tt.getName());

long[] tileSize = getOptimalTileSize(tt, im.getAxesOrder(), im.getDimensions());

firstIterationInputs.add(TileInfo.build(tt.getTensorID(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
firstIterationInputs.add(TileInfo.build(tt.getName(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
}

if (!tiling)
Expand Down Expand Up @@ -174,7 +174,7 @@ private List<TileInfo> checkOutputSize(List<TileInfo> inputs, List<TensorSpec> a
if (ax.getReferenceTensor() == null)
continue;
TensorSpec inputT = this.descriptor.findInputTensor(ax.getReferenceTensor());
TileInfo im = inputs.stream().filter(in -> in.getName().equals(inputT.getTensorID())).findFirst().orElse(null);
TileInfo im = inputs.stream().filter(in -> in.getName().equals(inputT.getName())).findFirst().orElse(null);
String refAxis = ax.getReferenceAxis();
int index = im.getTileAxesOrder().indexOf(refAxis);
Axis inAx = inputT.getAxesInfo().getAxis(refAxis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public static TensorSpec findTensorInList(String name, List<TensorSpec> tts)
}

return tts.stream()
.filter(t -> t.getTensorID().equals(name))
.filter(t -> t.getName().equals(name))
.findAny().orElse(null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ public TensorSpec findInputTensor(String name)
}

return input_tensors.stream()
.filter(t -> t.getTensorID().equals(name))
.filter(t -> t.getName().equals(name))
.findAny().orElse(null);
}

Expand All @@ -630,7 +630,7 @@ public TensorSpec findOutputTensor(String name)
}

return output_tensors.stream()
.filter(t -> t.getTensorID().equals(name))
.filter(t -> t.getName().equals(name))
.findAny().orElse(null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ public TensorSpec findInputTensor(String name)
}

return input_tensors.stream()
.filter(t -> t.getTensorID().equals(name))
.filter(t -> t.getName().equals(name))
.findAny().orElse(null);
}

Expand All @@ -495,7 +495,7 @@ public TensorSpec findOutputTensor(String name)
}

return output_tensors.stream()
.filter(t -> t.getTensorID().equals(name))
.filter(t -> t.getName().equals(name))
.findAny().orElse(null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface TensorSpec {



public String getTensorID();
public String getName();

public String getDescription();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ protected void setTestTensor(String testTensorName) {
this.testTensorName = testTensorName;
}

public String getTensorID() {
public String getName() {
return this.id;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ protected TensorSpecV05(Map<String, Object> tensorSpecMap, boolean input)
}
}

public String getTensorID() {
public String getName() {
return this.id;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;

/**
Expand Down Expand Up @@ -121,11 +123,12 @@ public static void main(String[] args) {
* descriptor containing the rdf.yaml information
* @throws Exception if any error occurs
*/
public static void loadAndRunModel(String modelFolder, ModelDescriptor descriptor) throws Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void loadAndRunModel(String modelFolder, ModelDescriptor descriptor) throws Exception {
Model model = Model.createBioimageioModel(modelFolder, ENGINES_DIR);
model.loadModel();
List<Tensor<?>> inputs = createInputs(descriptor);
List<Tensor<?>> outputs = createOutputs(descriptor);
List<Tensor<T>> inputs = createInputs(descriptor);
List<Tensor<R>> outputs = createOutputs(descriptor);
model.runModel(inputs, outputs);
for (Tensor<?> tt : outputs) {
if (tt.isEmpty())
Expand All @@ -144,19 +147,19 @@ public static void loadAndRunModel(String modelFolder, ModelDescriptor descripto
* file containing the information
* @return the input Tensor list
*/
private static List<Tensor<?>> createInputs(ModelDescriptor descriptor) {
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createInputs(ModelDescriptor descriptor) {
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );

for ( TensorSpec it : descriptor.getInputTensors()) {
String axesStr = it.getAxesOrder();
String name = it.getName();
int[] min = it.getShape().getTileMinimumSize();
int[] step = it.getShape().getTileStep();
int[] min = it.getMinTileSizeArr();
int[] step = it.getTileStepArr();
long[] imSize = LongStream.range(0, step.length)
.map(i -> min[(int) i] + step[(int) i]).toArray();
Tensor<FloatType> tt = Tensor.build(name, axesStr, imgFactory.create(imSize));
inputs.add(tt);
inputs.add((Tensor<T>) tt);
}
return inputs;
}
Expand All @@ -168,13 +171,13 @@ private static List<Tensor<?>> createInputs(ModelDescriptor descriptor) {
* file containing the information
* @return the output Tensor list
*/
private static List<Tensor<?>> createOutputs(ModelDescriptor descriptor) {
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputs(ModelDescriptor descriptor) {
List<Tensor<T>> outputs = new ArrayList<Tensor<T>>();

for ( TensorSpec ot : descriptor.getOutputTensors()) {
String axesStr = ot.getAxesOrder();
String name = ot.getName();
Tensor<?> tt = Tensor.buildEmptyTensor(name, axesStr);
Tensor<T> tt = Tensor.buildEmptyTensor(name, axesStr);
outputs.add(tt);
}
return outputs;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,8 @@ else if (descriptor == null)
List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors, TileMaker tiles, TilingConsumer tileCounter) throws RunModelException {
List<Tensor<T>> outputTensors = new ArrayList<Tensor<T>>();
for (TensorSpec tt : descriptor.getOutputTensors()) {
long[] dims = tiles.getOutputImageSize(tt.getTensorID());
outputTensors.add((Tensor<T>) Tensor.buildBlankTensor(tt.getTensorID(),
long[] dims = tiles.getOutputImageSize(tt.getName());
outputTensors.add((Tensor<T>) Tensor.buildBlankTensor(tt.getName(),
tt.getAxesOrder(),
dims,
(T) new FloatType()));
Expand Down
28 changes: 14 additions & 14 deletions src/main/java/io/bioimage/modelrunner/tiling/TileMaker.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private void getOutputTiles() {
TensorSpec intt = descriptor.getInputTensors().stream()
.filter(t -> t.isImage()).findFirst().orElse(null);
TileInfo inTile = inputTileInfo.stream()
.filter(t -> t.getName().equals(intt.getTensorID())).findFirst().orElse(null);
.filter(t -> t.getName().equals(intt.getName())).findFirst().orElse(null);
int indTile = inTile.getTileAxesOrder().indexOf(ax.getAxis());
int indIm = inTile.getImageAxesOrder().indexOf(ax.getAxis());
if (indTile == -1 || indIm == -1)
Expand Down Expand Up @@ -112,7 +112,7 @@ private void getOutputTiles() {
+ " so we can troubleshoot at: " + Constants.ISSUES_LINK);
}
}
outputTileInfo.add(TileInfo.build(tt.getTensorID(), imagSize, outAxesOrder, tileSize, outAxesOrder));
outputTileInfo.add(TileInfo.build(tt.getName(), imagSize, outAxesOrder, tileSize, outAxesOrder));
}
}

Expand Down Expand Up @@ -247,9 +247,9 @@ private void checkTilesCombine() {
private void checkAllTensorsDefined() {
for (TensorSpec tensor : this.descriptor.getInputTensors()) {
TileInfo info = inputTileInfo.stream()
.filter(tt -> tt.getName().equals(tensor.getTensorID())).findFirst().orElse(null);
.filter(tt -> tt.getName().equals(tensor.getName())).findFirst().orElse(null);
if (info == null) {
throw new IllegalArgumentException("Tiling info for input tensor '" + tensor.getTensorID()
throw new IllegalArgumentException("Tiling info for input tensor '" + tensor.getName()
+ "' not defined.");
}
}
Expand All @@ -258,17 +258,17 @@ private void checkAllTensorsDefined() {
private void calculate() {
for (TensorSpec tt : this.descriptor.getInputTensors()) {
TileInfo tile = inputTileInfo.stream()
.filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null);
.filter(til -> til.getName().equals(tt.getName())).findFirst().orElse(null);
PatchSpec patch = computePatchSpecs(tt, tile);
input.put(tt.getTensorID(), patch);
inputGrid.put(tt.getTensorID(), TileGrid.create(patch));
input.put(tt.getName(), patch);
inputGrid.put(tt.getName(), TileGrid.create(patch));
}
for (TensorSpec tt : this.descriptor.getOutputTensors()) {
TileInfo tile = outputTileInfo.stream()
.filter(til -> til.getName().equals(tt.getTensorID())).findFirst().orElse(null);
.filter(til -> til.getName().equals(tt.getName())).findFirst().orElse(null);
PatchSpec patch = computePatchSpecs(tt, tile);
output.put(tt.getTensorID(), patch);
outputGrid.put(tt.getTensorID(), TileGrid.create(patch));
output.put(tt.getName(), patch);
outputGrid.put(tt.getName(), TileGrid.create(patch));
}
}

Expand Down Expand Up @@ -321,11 +321,11 @@ private PatchSpec computePatchSpecs(TensorSpec spec, TileInfo tile)
.map(i -> (int) Math.max( paddingSize[1][i],
tileSize[i] - imSize[i] - paddingSize[0][i])).toArray();

return PatchSpec.create(spec.getTensorID(), tileSize, patchGridSize, paddingSize, imSize);
return PatchSpec.create(spec.getName(), tileSize, patchGridSize, paddingSize, imSize);
}

public int getNumberOfTiles() {
return inputGrid.get(this.descriptor.getInputTensors().get(0).getTensorID()).getRoiPostionsInImage().size();
return inputGrid.get(this.descriptor.getInputTensors().get(0).getName()).getRoiPostionsInImage().size();
}

public Map<String, Integer> getTilesPerAxis() {
Expand All @@ -352,7 +352,7 @@ public void getOutputInsertionPoints(String tensorID, int nTile, String axes) {
*/
public List<String> getInputTensorNames() {
return descriptor.getInputTensors().stream()
.map(tt -> tt.getTensorID()).collect(Collectors.toList());
.map(tt -> tt.getName()).collect(Collectors.toList());
}

/**
Expand All @@ -361,7 +361,7 @@ public List<String> getInputTensorNames() {
*/
public List<String> getOutputTensorNames() {
return descriptor.getOutputTensors().stream()
.map(tt -> tt.getTensorID()).collect(Collectors.toList());
.map(tt -> tt.getName()).collect(Collectors.toList());
}

/**
Expand Down

0 comments on commit 568aa53

Please sign in to comment.