Skip to content

Commit

Permalink
always use tmp dir to load model (#2535)
Browse files Browse the repository at this point in the history
* always use tmp dir to load model

* honor temp dir as model dir

* update env setting

* fmt

* clean dir if archive failed
  • Loading branch information
lxning committed Aug 29, 2023
1 parent d3eeb07 commit 65f6005
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public static ModelArchive downloadModel(
}
}

File tempDir = ZipUtils.createTempDir(null, "models");
logger.info("createTempDir {}", tempDir.getAbsolutePath());
File directory = new File(url);
if (directory.isDirectory()) {
// handle the case that the input url is a directory.
Expand All @@ -88,9 +90,13 @@ public static ModelArchive downloadModel(
if (fileList.length == 1 && fileList[0].isDirectory()) {
// handle the case that a model tgz file
// has root dir after decompression on SageMaker
return load(url, fileList[0], false);
File targetLink = ZipUtils.createSymbolicDir(fileList[0], tempDir);
logger.info("createSymbolicDir {}", targetLink.getAbsolutePath());
return load(url, targetLink, false);
}
return load(url, directory, false);
File targetLink = ZipUtils.createSymbolicDir(directory, tempDir);
logger.info("createSymbolicDir {}", targetLink.getAbsolutePath());
return load(url, targetLink, false);
} else if (modelLocation.exists()) {
// handle the case that "/xxx/model_store/modelXXX" is directory.
// the input of url is modelXXX when torchserve is started
Expand All @@ -99,9 +105,13 @@ public static ModelArchive downloadModel(
if (fileList.length == 1 && fileList[0].isDirectory()) {
// handle the case that a model tgz file
// has root dir after decompression on SageMaker
return load(url, fileList[0], false);
File targetLink = ZipUtils.createSymbolicDir(fileList[0], tempDir);
logger.info("createSymbolicDir {}", targetLink.getAbsolutePath());
return load(url, targetLink, false);
}
return load(url, modelLocation, false);
File targetLink = ZipUtils.createSymbolicDir(modelLocation, tempDir);
logger.info("createSymbolicDir {}", targetLink.getAbsolutePath());
return load(url, targetLink, false);
}

throw new ModelNotFoundException("Model not found at: " + url);
Expand All @@ -122,8 +132,12 @@ private static ModelArchive load(String url, File dir, boolean extracted)
failed = false;
return new ModelArchive(manifest, url, dir, extracted);
} finally {
if (extracted && failed) {
FileUtils.deleteQuietly(dir);
if (failed) {
if (Files.isSymbolicLink(dir.toPath())) {
FileUtils.deleteQuietly(dir.getParentFile());
} else {
FileUtils.deleteQuietly(dir);
}
}
}
}
Expand Down Expand Up @@ -195,8 +209,12 @@ public String getModelVersion() {
}

public void clean() {
if (url != null && extracted) {
FileUtils.deleteQuietly(modelDir);
if (url != null) {
if (Files.isSymbolicLink(modelDir.toPath())) {
FileUtils.deleteQuietly(modelDir.getParentFile());
} else {
FileUtils.deleteQuietly(modelDir);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,29 @@ public static void decompressTarGzipFile(InputStream is, File dest) throws IOExc
}
}
}

public static File createTempDir(String eTag, String type) throws IOException {
File tmpDir = FileUtils.getTempDirectory();
File modelDir = new File(tmpDir, type);

if (eTag == null) {
eTag = UUID.randomUUID().toString().replaceAll("-", "");
}

File dir = new File(modelDir, eTag);
if (dir.exists()) {
FileUtils.forceDelete(dir);
}
FileUtils.forceMkdir(dir);

return dir;
}

public static File createSymbolicDir(File source, File dest) throws IOException {
String sourceDirName = source.getName();
File targetLink = new File(dest, sourceDirName);
Files.createSymbolicLink(targetLink.toPath(), source.toPath());

return targetLink;
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package org.pytorch.serve.util.messages;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.wlm.Model;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class EnvironmentUtils {

private static final Logger logger = LoggerFactory.getLogger(EnvironmentUtils.class);
private static ConfigManager configManager = ConfigManager.getInstance();

private EnvironmentUtils() {}
Expand Down Expand Up @@ -38,7 +42,19 @@ public static String[] getEnvString(String cwd, String modelPath, String handler
}

if (modelPath != null) {
pythonPath.append(modelPath).append(File.pathSeparatorChar);
File modelPathCanonical = new File(modelPath);
try {
modelPathCanonical = modelPathCanonical.getCanonicalFile();
} catch (IOException e) {
logger.error("Invalid model path {}", modelPath, e);
}
pythonPath.append(modelPathCanonical.getAbsolutePath()).append(File.pathSeparatorChar);
File dependencyPath = new File(modelPath);
if (Files.isSymbolicLink(dependencyPath.toPath())) {
pythonPath
.append(dependencyPath.getParentFile().getAbsolutePath())
.append(File.pathSeparatorChar);
}
}

if (!cwd.contains("site-packages") && !cwd.contains("dist-packages")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
Expand Down Expand Up @@ -212,10 +213,14 @@ private void setupModelDependencies(Model model)

String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);

File dependencyPath = model.getModelDir();
if (Files.isSymbolicLink(dependencyPath.toPath())) {
dependencyPath = dependencyPath.getParentFile();
}
String packageInstallCommand =
pythonRuntime
+ " -m pip install -U -t "
+ model.getModelDir().getAbsolutePath()
+ dependencyPath.getAbsolutePath()
+ " -r "
+ requirementsFilePath; // NOPMD

Expand Down Expand Up @@ -251,7 +256,6 @@ private void setupModelDependencies(Model model)
errorString.append(line);
}

logger.info("Dependency installation stdout:\n" + outputString.toString());
logger.error("Dependency installation stderr:\n" + errorString.toString());

throw new ModelException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ public void startWorker(int port, String deviceIds)
File modelPath;
setPort(port);
try {
modelPath = model.getModelDir().getCanonicalFile();
modelPath = model.getModelDir();
// Test if modelPath is valid
modelPath.getCanonicalFile();
} catch (IOException e) {
throw new WorkerInitializationException("Failed get TS home directory", e);
}
Expand Down

0 comments on commit 65f6005

Please sign in to comment.