Skip to content

Commit

Permalink
keep iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 6, 2024
1 parent c418a1d commit d6aed6f
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.axes.axis.Axis;
import io.bioimage.modelrunner.tiling.PatchGridCalculator;

public class TileFactory {
Expand All @@ -32,10 +32,9 @@ public static TileFactory init(ModelDescriptor descriptor) {

private long[] getOptimalTileSize(TensorSpec tensor, String inputAxesOrder, long[] dims) {
boolean tiling = this.descriptor.isTilingAllowed();
int[] halo = descriptor.getTotalHalo();
int[] halo = tensor.getAxesInfo().getHaloArr();
int[] min = tensor.getMinTileSizeArr();
int[] step = tensor.getTileStepArr();
double[] scale = tensor.getTileScaleArr();

long[] patch = new long[inputAxesOrder.length()];
String seqSizeAxesUpper = inputAxesOrder.toUpperCase();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public interface Axes {

public double[] getTileScaleArr();

public int[] getHaloArr();

public Axis getAxis(String abreviation);

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class AxesV04 implements Axes {

private final int[] stepArr;

private final int[] haloArr;
private int[] haloArr;

private final double[] offsetArr;

Expand Down Expand Up @@ -118,6 +118,9 @@ public double[] getOffsetArr() {
}

public int[] getHaloArr() {
haloArr = new int[this.axesList.size()];
for (int i = 0; i < this.axesList.size(); i ++)
haloArr[i] = this.axesList.get(i).getHalo();
return this.haloArr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public class AxesV05 implements Axes {

private final int[] stepArr;

private int[] haloArr;

protected AxesV05(List<Object> axesList) {
List<Axis> axesListInit = new ArrayList<Axis>();
String order = "";
Expand Down Expand Up @@ -61,6 +63,13 @@ public int[] getTileStepArr() {
return this.stepArr;
}

public int[] getHaloArr() {
haloArr = new int[this.axesList.size()];
for (int i = 0; i < this.axesList.size(); i ++)
haloArr[i] = this.axesList.get(i).getHalo();
return this.haloArr;
}

public double[] getTileScaleArr() {
return this.scaleArr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public interface TensorSpec {

public int[] getTileStepArr();

public int[] getHaloArr();

public double[] getTileScaleArr();

public Axes getAxesInfo();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,9 @@ public String getDataType() {
// TODO
return this.data.toString();
}

@Override
public int[] getHaloArr() {
return this.axes.getHaloArr();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,9 @@ public String getDataType() {
// TODO
return this.data.toString();
}

@Override
public int[] getHaloArr() {
return this.axes.getHaloArr();
}
}
21 changes: 19 additions & 2 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -171,7 +172,15 @@ public static Model createDeepLearningModel( String modelFolder, String modelSou
&& !engineInfo.getFramework().equals(EngineInfo.getTensorflowKey())
&& !engineInfo.getFramework().equals(EngineInfo.getBioimageioTfKey()) )
Objects.requireNonNull(modelSource);
return new Model( engineInfo, modelFolder, modelSource, null );
Model model = new Model( engineInfo, modelFolder, modelSource, null );

if (Paths.get(modelFolder, Constants.RDF_FNAME).toFile().isFile()) {
try {
model.descriptor = ModelDescriptorFactory.readFromLocalFile(Paths.get(modelFolder, Constants.RDF_FNAME).toAbsolutePath().toString());
} catch (ModelSpecsException | IOException e) {
}
}
return model;
}

/**
Expand Down Expand Up @@ -412,7 +421,15 @@ public static Model createDeepLearningModel( String modelFolder, String modelSou
&& !engineInfo.getFramework().equals(EngineInfo.getTensorflowKey())
&& !engineInfo.getFramework().equals(EngineInfo.getBioimageioTfKey()))
Objects.requireNonNull(modelSource);
return new Model( engineInfo, modelFolder, modelSource, classLoader );
Model model = new Model( engineInfo, modelFolder, modelSource, classLoader );

if (Paths.get(modelFolder, Constants.RDF_FNAME).toFile().isFile()) {
try {
model.descriptor = ModelDescriptorFactory.readFromLocalFile(Paths.get(modelFolder, Constants.RDF_FNAME).toAbsolutePath().toString());
} catch (ModelSpecsException | IOException e) {
}
}
return model;
}

/**
Expand Down

0 comments on commit d6aed6f

Please sign in to comment.