Skip to content

Commit

Permalink
update example to run end to end
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 5, 2023
1 parent 6aca72e commit 016390a
Showing 1 changed file with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@
*/
package io.bioimage.modelrunner.example;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.model.Model;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.versionmanagement.AvailableEngines;
import io.bioimage.modelrunner.versionmanagement.DeepLearningVersion;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -62,26 +68,33 @@ public class ExampleLoadAndRunModel {
* @throws Exception
*/
public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadEngineException, Exception {

// Tag for the DL framework (engine) that wants to be used
String engine = "torchscript";
String framework = "torchscript";
// Version of the engine
String engineVersion = "1.13.1";
// Directory where all the engines are stored
String enginesDir = ENGINES_DIR;
// Path to the model folder
String modelFolder = new File(MODELS_DIR, "EnhancerMitochondriaEM2D_13012023_130426").getAbsolutePath();
downloadTorchscriptEngine(framework, engineVersion, enginesDir);

String bmzModelName = "EnhancerMitochondriaEM2D";
String modelFolder = downloadBMZModel(bmzModelName, MODELS_DIR);

// Path to the model source. The model source locally is the path to the source file defined in the
// yaml inside the model folder
String modelSource = new File(modelFolder, "weights-torchscript.pt").getAbsolutePath();
// Whether the engine is supported by CPu or not
boolean cpu = true;
List<DeepLearningVersion> installedList =
InstalledEngines.checkEngineWithArgsInstalledForOS(framework, engineVersion,
cpu, null, enginesDir);
// Whether the engine is supported by GPU or not
boolean gpu = true;
boolean gpu = installedList.get(0).getGPU();
// Create the EngineInfo object. It is needed to load the wanted DL framework
// among all the installed ones. The EngineInfo loads the corresponding engine by looking
// at the enginesDir at searching for the folder that is named satisfying the characteristics specified.
// REGARD THAT the engine folders need to follow a naming convention
EngineInfo engineInfo = createEngineInfo(engine, engineVersion, enginesDir, cpu, gpu);
EngineInfo engineInfo = createEngineInfo(framework, engineVersion, enginesDir, cpu, gpu);
// Load the corresponding model
Model model = loadModel(modelFolder, modelSource, engineInfo);
// Create an image that will be the backend of the Input Tensor
Expand Down Expand Up @@ -119,6 +132,24 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a
System.out.print("Success!!");
}

public static void downloadTorchscriptEngine(String framework, String engineVersion,
String enginesDir) throws IOException, InterruptedException {
List<DeepLearningVersion> possibleEngines =
AvailableEngines.getEnginesForOsByParams(framework, engineVersion, true, null);
boolean success = EngineInstall.installEngineInDir(possibleEngines.get(0), enginesDir);

if (!success)
throw new IOException("The wanted DL engine was not downloaed correctly: "
+ possibleEngines.get(0).folderName());
}


public static String downloadBMZModel(String bmzModelName, String modelsDir) throws IOException, InterruptedException {
// Create an instance of the BioimageRepo object
BioimageioRepo br = BioimageioRepo.connect();
return br.downloadByName(bmzModelName, modelsDir);
}

/**
* Method that creates the {@link EngineInfo} object.
* @param engine
Expand Down

0 comments on commit 016390a

Please sign in to comment.