diff --git a/src/main/java/io/bioimage/modelrunner/example/ExampleLoadAndRunModel.java b/src/main/java/io/bioimage/modelrunner/example/ExampleLoadAndRunModel.java index 6fadf5b6..4a1b2a23 100644 --- a/src/main/java/io/bioimage/modelrunner/example/ExampleLoadAndRunModel.java +++ b/src/main/java/io/bioimage/modelrunner/example/ExampleLoadAndRunModel.java @@ -19,12 +19,18 @@ */ package io.bioimage.modelrunner.example; +import io.bioimage.modelrunner.bioimageio.BioimageioRepo; import io.bioimage.modelrunner.engine.EngineInfo; +import io.bioimage.modelrunner.engine.installation.EngineInstall; import io.bioimage.modelrunner.exceptions.LoadEngineException; import io.bioimage.modelrunner.model.Model; import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.versionmanagement.AvailableEngines; +import io.bioimage.modelrunner.versionmanagement.DeepLearningVersion; +import io.bioimage.modelrunner.versionmanagement.InstalledEngines; import java.io.File; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -62,26 +68,33 @@ public class ExampleLoadAndRunModel { * @throws Exception */ public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadEngineException, Exception { + // Tag for the DL framework (engine) that wants to be used - String engine = "torchscript"; + String framework = "torchscript"; // Version of the engine String engineVersion = "1.13.1"; // Directory where all the engines are stored String enginesDir = ENGINES_DIR; - // Path to the model folder - String modelFolder = new File(MODELS_DIR, "EnhancerMitochondriaEM2D_13012023_130426").getAbsolutePath(); + downloadTorchscriptEngine(framework, engineVersion, enginesDir); + + String bmzModelName = "EnhancerMitochondriaEM2D"; + String modelFolder = downloadBMZModel(bmzModelName, MODELS_DIR); + // Path to the model source. The model source locally is the path to the source file defined in the // yaml inside the model folder String modelSource = new File(modelFolder, "weights-torchscript.pt").getAbsolutePath(); // Whether the engine is supported by CPu or not boolean cpu = true; + List installedList = + InstalledEngines.checkEngineWithArgsInstalledForOS(framework, engineVersion, + cpu, null, enginesDir); // Whether the engine is supported by GPU or not - boolean gpu = true; + boolean gpu = installedList.get(0).getGPU(); // Create the EngineInfo object. It is needed to load the wanted DL framework // among all the installed ones. The EngineInfo loads the corresponding engine by looking // at the enginesDir at searching for the folder that is named satisfying the characteristics specified. // REGARD THAT the engine folders need to follow a naming convention - EngineInfo engineInfo = createEngineInfo(engine, engineVersion, enginesDir, cpu, gpu); + EngineInfo engineInfo = createEngineInfo(framework, engineVersion, enginesDir, cpu, gpu); // Load the corresponding model Model model = loadModel(modelFolder, modelSource, engineInfo); // Create an image that will be the backend of the Input Tensor @@ -119,6 +132,24 @@ public static < T extends RealType< T > & NativeType< T > > void main(String[] a System.out.print("Success!!"); } + public static void downloadTorchscriptEngine(String framework, String engineVersion, + String enginesDir) throws IOException, InterruptedException { + List possibleEngines = + AvailableEngines.getEnginesForOsByParams(framework, engineVersion, true, null); + boolean success = EngineInstall.installEngineInDir(possibleEngines.get(0), enginesDir); + + if (!success) + throw new IOException("The wanted DL engine was not downloaed correctly: " + + possibleEngines.get(0).folderName()); + } + + + public static String downloadBMZModel(String bmzModelName, String modelsDir) throws IOException, InterruptedException { + // Create an instance of the BioimageRepo object + BioimageioRepo br = BioimageioRepo.connect(); + return br.downloadByName(bmzModelName, modelsDir); + } + /** * Method that creates the {@link EngineInfo} object. * @param engine