diff --git a/src/main/java/io/bioimage/modelrunner/engine/installation/EngineInstall.java b/src/main/java/io/bioimage/modelrunner/engine/installation/EngineInstall.java index 66fa6ad6..90077a77 100644 --- a/src/main/java/io/bioimage/modelrunner/engine/installation/EngineInstall.java +++ b/src/main/java/io/bioimage/modelrunner/engine/installation/EngineInstall.java @@ -34,6 +34,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.stream.Collectors; import io.bioimage.modelrunner.bioimageio.BioimageioRepo; @@ -41,6 +42,7 @@ import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory; import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException; import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat; +import io.bioimage.modelrunner.bioimageio.download.DownloadModel; import io.bioimage.modelrunner.bioimageio.download.DownloadTracker; import io.bioimage.modelrunner.bioimageio.download.DownloadTracker.TwoParameterConsumer; import io.bioimage.modelrunner.engine.EngineInfo; @@ -69,6 +71,12 @@ public class EngineInstall { * Map containing which version should always be installed per framework */ public static LinkedHashMap ENGINES_VERSIONS = new LinkedHashMap(); + /** + * Suffix in the download progress tracker consumer that indicates that the value is total size of + * bytes to be downloaded. On the other hand, if the suffix is not present, it will represent the + * percentage of bytes that is being downloaded currently + */ + public static final String NBYTES_SUFFIX = "_SIZE"; static { ENGINES_VERSIONS.put(EngineInfo.getTensorflowKey() + "_2", "2.7.0"); @@ -542,13 +550,11 @@ private void installMissingBasicEngines() { checkEnginesInstalled(); if (missingEngineFolders.entrySet().size() == 0) return; + getBasicDownloadTotalSize(); missingEngineFolders = missingEngineFolders.entrySet().stream() .filter(v -> { - TwoParameterConsumer consumer = DownloadTracker.createConsumerProgress(); - if (this.consumersMap != null && this.consumersMap.get(v.getValue()) != null) - consumer = this.consumersMap.get(v.getValue()); try { - return!installEngineByCompleteName(v.getValue(), consumer); + return!installEngineByCompleteName(v.getValue(), consumersMap.get(v.getValue())); } catch (IOException | InterruptedException e) { return true; } @@ -557,6 +563,34 @@ private void installMissingBasicEngines() { (u, v) -> u, LinkedHashMap::new)); } + /** + * + * @return the total numebr of bytes to be downloaded in the basic installation. This also prepares the + * download tracker by setting the number of bytes per file + */ + public long getBasicDownloadTotalSize() { + if (this.consumersMap == null) + getBasicEnginesProgress(); + long totalSize = 0; + for (Entry ee : missingEngineFolders.entrySet()) { + try { + long engineSize = 0; + DeepLearningVersion dlVersion = DeepLearningVersion.fromFile(new File(ee.getValue())); + for (String link : dlVersion.getJars()) { + String key = DownloadModel.getFileNameFromURLString(link) + NBYTES_SUFFIX; + long val = DownloadModel.getFileSize(new URL(link)); + this.consumersMap.get(ee.getValue()).accept(key, (double) val); + engineSize += val; + } + this.consumersMap.get(ee.getValue()).accept("total" + NBYTES_SUFFIX, (double) engineSize); + totalSize += engineSize; + } catch (IllegalStateException | IOException e) { + continue; + } + } + return totalSize; + } + /** * Install the engine that should be located in the engine dir specified * @param engineDir @@ -1139,7 +1173,7 @@ public static boolean installEngineInDir(DeepLearningVersion engine, String engi * @param engineDir * directory where the files will be downloaded */ - private static void downloadEngineFiles(DeepLearningVersion engine, String engineDir) { + public static void downloadEngineFiles(DeepLearningVersion engine, String engineDir) { for (String jar : engine.getJars()) { try { URL website = new URL(jar);