Skip to content

Commit

Permalink
do some small changes to allow more flexibility when tracking a download
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 15, 2024
1 parent f4e00b5 commit d0696a5
Showing 1 changed file with 39 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.bioimageio.download.DownloadModel;
import io.bioimage.modelrunner.bioimageio.download.DownloadTracker;
import io.bioimage.modelrunner.bioimageio.download.DownloadTracker.TwoParameterConsumer;
import io.bioimage.modelrunner.engine.EngineInfo;
Expand Down Expand Up @@ -69,6 +71,12 @@ public class EngineInstall {
* Map containing which version should always be installed per framework
*/
public static LinkedHashMap<String, String> ENGINES_VERSIONS = new LinkedHashMap<String, String>();
/**
* Suffix in the download progress tracker consumer that indicates that the value is total size of
* bytes to be downloaded. On the other hand, if the suffix is not present, it will represent the
* percentage of bytes that is being downloaded currently
*/
public static final String NBYTES_SUFFIX = "_SIZE";

static {
ENGINES_VERSIONS.put(EngineInfo.getTensorflowKey() + "_2", "2.7.0");
Expand Down Expand Up @@ -542,13 +550,11 @@ private void installMissingBasicEngines() {
checkEnginesInstalled();
if (missingEngineFolders.entrySet().size() == 0)
return;
getBasicDownloadTotalSize();
missingEngineFolders = missingEngineFolders.entrySet().stream()
.filter(v -> {
TwoParameterConsumer<String, Double> consumer = DownloadTracker.createConsumerProgress();
if (this.consumersMap != null && this.consumersMap.get(v.getValue()) != null)
consumer = this.consumersMap.get(v.getValue());
try {
return!installEngineByCompleteName(v.getValue(), consumer);
return!installEngineByCompleteName(v.getValue(), consumersMap.get(v.getValue()));
} catch (IOException | InterruptedException e) {
return true;
}
Expand All @@ -557,6 +563,34 @@ private void installMissingBasicEngines() {
(u, v) -> u, LinkedHashMap::new));
}

/**
*
* @return the total numebr of bytes to be downloaded in the basic installation. This also prepares the
* download tracker by setting the number of bytes per file
*/
public long getBasicDownloadTotalSize() {
if (this.consumersMap == null)
getBasicEnginesProgress();
long totalSize = 0;
for (Entry<String, String> ee : missingEngineFolders.entrySet()) {
try {
long engineSize = 0;
DeepLearningVersion dlVersion = DeepLearningVersion.fromFile(new File(ee.getValue()));
for (String link : dlVersion.getJars()) {
String key = DownloadModel.getFileNameFromURLString(link) + NBYTES_SUFFIX;
long val = DownloadModel.getFileSize(new URL(link));
this.consumersMap.get(ee.getValue()).accept(key, (double) val);
engineSize += val;
}
this.consumersMap.get(ee.getValue()).accept("total" + NBYTES_SUFFIX, (double) engineSize);
totalSize += engineSize;
} catch (IllegalStateException | IOException e) {
continue;
}
}
return totalSize;
}

/**
* Install the engine that should be located in the engine dir specified
* @param engineDir
Expand Down Expand Up @@ -1139,7 +1173,7 @@ public static boolean installEngineInDir(DeepLearningVersion engine, String engi
* @param engineDir
* directory where the files will be downloaded
*/
private static void downloadEngineFiles(DeepLearningVersion engine, String engineDir) {
public static void downloadEngineFiles(DeepLearningVersion engine, String engineDir) {
for (String jar : engine.getJars()) {
try {
URL website = new URL(jar);
Expand Down

0 comments on commit d0696a5

Please sign in to comment.