Skip to content

Commit

Permalink
fallback if the tf names are not perfect
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 15, 2023
1 parent 31a5eda commit c73789f
Showing 1 changed file with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.bioimage.modelrunner.utils.ZipUtils;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;

import java.io.BufferedReader;
import java.io.File;
Expand Down Expand Up @@ -312,15 +313,17 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
List<String> inputListNames = new ArrayList<String>();
List<org.tensorflow.Tensor<?>> inTensors =
new ArrayList<org.tensorflow.Tensor<?>>();
int c = 0;
for (Tensor tt : inputTensors) {
inputListNames.add(tt.getName());
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
inTensors.add(inT);
runner.feed(getModelInputName(tt.getName()), inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}

c = 0;
for (Tensor tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName()));
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();

Expand Down Expand Up @@ -426,10 +429,14 @@ public void closeModel() {
* the signature input name.
*
* @param inputName Signature input name.
* @param i position of the input of interest in the list of inputs
* @return The readable input name.
*/
public static String getModelInputName(String inputName) {
public static String getModelInputName(String inputName, int i) {
TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null);
if (inputInfo == null) {
inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (inputInfo != null) {
String modelInputName = inputInfo.getName();
if (modelInputName != null) {
Expand All @@ -452,10 +459,14 @@ public static String getModelInputName(String inputName) {
* given the signature output name.
*
* @param outputName Signature output name.
* @param i position of the input of interest in the list of inputs
* @return The readable output name.
*/
public static String getModelOutputName(String outputName) {
public static String getModelOutputName(String outputName, int i) {
TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null);
if (outputInfo == null) {
outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (outputInfo != null) {
String modelOutputName = outputInfo.getName();
if (modelOutputName.endsWith(":0")) {
Expand Down Expand Up @@ -495,6 +506,20 @@ public static String getModelOutputName(String outputName) {
* @throws RunModelException if there is any error running the model
*/
public static void main(String[] args) throws LoadModelException, IOException, RunModelException {
if (args.length == 0) {
String modelFolder = "/home/carlos/git/deep-icy/models/stardist_1channel";
Tensorflow1Interface ti = new Tensorflow1Interface(false);
ti.loadModel(modelFolder, modelFolder);
Tensor<FloatType> inp = Tensor.buildBlankTensor("in", "byxc", new long[] {1, 208, 208, 1}, new FloatType());
Tensor<FloatType> out = Tensor.buildEmptyTensor("out", "byxc");
List<Tensor<?>> inps = new ArrayList<Tensor<?>>();
inps.add(inp);
List<Tensor<?>> outs = new ArrayList<Tensor<?>>();
outs.add(out);
ti.run(inps, outs);
System.out.println(false);
return;
}
// Unpack the args needed
if (args.length < 4)
throw new IllegalArgumentException("Error exectuting Tensorflow 1, "
Expand Down

0 comments on commit c73789f

Please sign in to comment.