forked from deepjavalibrary/djl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e436b57
commit 1558d98
Showing
5 changed files
with
355 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.paddlepaddle.jni; | ||
|
||
/** | ||
* A class containing utilities to interact with the Paddle Engine's Java Native Interface (JNI) | ||
* layer. | ||
*/ | ||
@SuppressWarnings("MissingJavadocMethod") | ||
public final class JniUtils { | ||
private JniUtils() {} | ||
|
||
public static int getTensorDType(long handle) { | ||
return PaddleLibrary.LIB.getTensorDType(handle); | ||
} | ||
} |
322 changes: 322 additions & 0 deletions
322
paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/LibUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
/* | ||
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.paddlepaddle.jni; | ||
|
||
import ai.djl.util.Platform; | ||
import ai.djl.util.Utils; | ||
import java.io.File; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.net.URL; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.nio.file.Paths; | ||
import java.nio.file.StandardCopyOption; | ||
import java.util.Enumeration; | ||
import java.util.List; | ||
import java.util.Properties; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
import java.util.zip.GZIPInputStream; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
/** | ||
* Utilities for finding the Paddle Engine binary on the System. | ||
* | ||
* <p>The Engine will be searched for in a variety of locations in the following order: | ||
* | ||
* <ol> | ||
* <li>In the path specified by the Paddle_LIBRARY_PATH environment variable | ||
* <li>In a jar file location in the classpath. These jars can be created with the paddle-native | ||
* module. | ||
* </ol> | ||
*/ | ||
@SuppressWarnings("MissingJavadocMethod") | ||
public final class LibUtils { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(LibUtils.class); | ||
|
||
private static final String NATIVE_LIB_NAME = "paddle_fluid"; | ||
private static final String LIB_NAME = "djl_paddle"; | ||
private static final Pattern VERSION_PATTERN = | ||
Pattern.compile("(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?"); | ||
|
||
private LibUtils() {} | ||
|
||
public static void loadLibrary() { | ||
String libName = getLibName(); | ||
logger.debug("Loading paddle library from: {}", libName); | ||
|
||
System.load(libName); // NOPMD | ||
} | ||
|
||
public static String getLibName() { | ||
String libName = LibUtils.findOverrideLibrary(); | ||
if (libName == null) { | ||
libName = LibUtils.findLibraryInClasspath(); | ||
if (libName == null) { | ||
throw new IllegalStateException("Native library not found"); | ||
} | ||
Path nativeLibDir = Paths.get(libName).getParent(); | ||
if (nativeLibDir == null || !nativeLibDir.toFile().isDirectory()) { | ||
throw new IllegalStateException("Native folder cannot be found"); | ||
} | ||
libName = copyJniLibraryFromClasspath(nativeLibDir); | ||
} | ||
return libName; | ||
} | ||
|
||
private static String findOverrideLibrary() { | ||
String libPath = System.getenv("PADDLE_LIBRARY_PATH"); | ||
if (libPath != null) { | ||
String libName = findLibraryInPath(libPath); | ||
if (libName != null) { | ||
return libName; | ||
} | ||
} | ||
|
||
libPath = System.getProperty("java.library.path"); | ||
if (libPath != null) { | ||
return findLibraryInPath(libPath); | ||
} | ||
return null; | ||
} | ||
|
||
private static String copyJniLibraryFromClasspath(Path nativeDir) { | ||
String name = System.mapLibraryName(LIB_NAME); | ||
Platform platform = Platform.fromSystem(); | ||
String classifier = platform.getClassifier(); | ||
String flavor = platform.getFlavor(); | ||
if (flavor.isEmpty()) { | ||
flavor = "cpu"; | ||
} | ||
Properties prop = new Properties(); | ||
try (InputStream stream = | ||
LibUtils.class.getResourceAsStream("/jnilib/paddlepaddle.properties")) { | ||
prop.load(stream); | ||
} catch (IOException e) { | ||
throw new IllegalStateException("Cannot find paddle property file", e); | ||
} | ||
String version = prop.getProperty("version"); | ||
Path path = nativeDir.resolve(version + '-' + flavor + '-' + name); | ||
if (Files.exists(path)) { | ||
return path.toAbsolutePath().toString(); | ||
} | ||
|
||
Path tmp = null; | ||
try (InputStream stream = | ||
LibUtils.class.getResourceAsStream( | ||
"/jnilib/" + classifier + '/' + flavor + '/' + name)) { | ||
tmp = Files.createTempFile(nativeDir, "jni", "tmp"); | ||
Files.copy(stream, tmp, StandardCopyOption.REPLACE_EXISTING); | ||
Utils.moveQuietly(tmp, path); | ||
return path.toAbsolutePath().toString(); | ||
} catch (IOException e) { | ||
throw new IllegalStateException("Cannot copy jni files", e); | ||
} finally { | ||
if (tmp != null) { | ||
Utils.deleteQuietly(tmp); | ||
} | ||
} | ||
} | ||
|
||
private static synchronized String findLibraryInClasspath() { | ||
Enumeration<URL> urls; | ||
try { | ||
urls = | ||
Thread.currentThread() | ||
.getContextClassLoader() | ||
.getResources("native/lib/paddlepaddle.properties"); | ||
} catch (IOException e) { | ||
logger.warn("", e); | ||
return null; | ||
} | ||
|
||
// No native jars | ||
if (!urls.hasMoreElements()) { | ||
logger.debug("paddlepaddle.properties not found in class path."); | ||
return null; | ||
} | ||
|
||
Platform systemPlatform = Platform.fromSystem(); | ||
try { | ||
Platform matching = null; | ||
Platform placeholder = null; | ||
while (urls.hasMoreElements()) { | ||
URL url = urls.nextElement(); | ||
Platform platform = Platform.fromUrl(url); | ||
if (platform.isPlaceholder()) { | ||
placeholder = platform; | ||
} else if (platform.matches(systemPlatform)) { | ||
matching = platform; | ||
break; | ||
} | ||
} | ||
|
||
if (matching != null) { | ||
return loadLibraryFromClasspath(matching); | ||
} | ||
|
||
if (placeholder != null) { | ||
try { | ||
return downloadLibrary(placeholder); | ||
} catch (IOException e) { | ||
throw new IllegalStateException( | ||
"Failed to download PaddlePaddle native library", e); | ||
} | ||
} | ||
} catch (IOException e) { | ||
throw new IllegalStateException( | ||
"Failed to read PaddlePaddle native library jar properties", e); | ||
} | ||
|
||
throw new IllegalStateException( | ||
"Your PaddlePaddle native library jar does not match your operating system. Make sure that the Maven Dependency Classifier matches your system type."); | ||
} | ||
|
||
private static String loadLibraryFromClasspath(Platform platform) { | ||
Path tmp = null; | ||
try { | ||
String libName = System.mapLibraryName(NATIVE_LIB_NAME); | ||
Path cacheFolder = getCacheDir(); | ||
logger.debug("Using cache dir: {}", cacheFolder); | ||
|
||
Path dir = cacheFolder.resolve(platform.getVersion() + platform.getClassifier()); | ||
Path path = dir.resolve(libName); | ||
if (Files.exists(path)) { | ||
return path.toAbsolutePath().toString(); | ||
} | ||
Files.createDirectories(cacheFolder); | ||
tmp = Files.createTempDirectory(cacheFolder, "tmp"); | ||
for (String file : platform.getLibraries()) { | ||
String libPath = "/native/lib/" + file; | ||
try (InputStream is = LibUtils.class.getResourceAsStream(libPath)) { | ||
logger.info("Extracting {} to cache ...", file); | ||
Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING); | ||
} | ||
} | ||
|
||
Utils.moveQuietly(tmp, dir); | ||
return path.toAbsolutePath().toString(); | ||
} catch (IOException e) { | ||
throw new IllegalStateException("Failed to extract PaddlePaddle native library", e); | ||
} finally { | ||
if (tmp != null) { | ||
Utils.deleteQuietly(tmp); | ||
} | ||
} | ||
} | ||
|
||
private static String findLibraryInPath(String libPath) { | ||
String[] paths = libPath.split(File.pathSeparator); | ||
String mappedLibNames = System.mapLibraryName(NATIVE_LIB_NAME); | ||
|
||
for (String path : paths) { | ||
File p = new File(path); | ||
if (!p.exists()) { | ||
continue; | ||
} | ||
if (p.isFile() && p.getName().endsWith(mappedLibNames)) { | ||
return p.getAbsolutePath(); | ||
} | ||
|
||
File file = new File(path, mappedLibNames); | ||
if (file.exists() && file.isFile()) { | ||
return file.getAbsolutePath(); | ||
} | ||
} | ||
return null; | ||
} | ||
|
||
private static String downloadLibrary(Platform platform) throws IOException { | ||
String version = platform.getVersion(); | ||
String flavor = platform.getFlavor(); | ||
if (flavor.isEmpty()) { | ||
flavor = "cpu"; | ||
} | ||
String classifier = platform.getClassifier(); | ||
String os = platform.getOsPrefix(); | ||
|
||
String libName = System.mapLibraryName(NATIVE_LIB_NAME); | ||
Path cacheDir = getCacheDir(); | ||
logger.debug("Using cache dir: {}", cacheDir); | ||
Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier); | ||
Path path = dir.resolve(libName); | ||
if (Files.exists(path)) { | ||
return path.toAbsolutePath().toString(); | ||
} | ||
|
||
Files.createDirectories(cacheDir); | ||
Matcher matcher = VERSION_PATTERN.matcher(version); | ||
if (!matcher.matches()) { | ||
throw new IllegalArgumentException("Unexpected version: " + version); | ||
} | ||
|
||
Path tmp = null; | ||
String link = "https://publish.djl.ai/paddlepaddle-" + matcher.group(1); | ||
try (InputStream is = new URL(link + "/files.txt").openStream()) { | ||
List<String> lines = Utils.readLines(is); | ||
if (flavor.startsWith("cu") | ||
&& !lines.contains(flavor + '/' + os + "/native/lib/" + libName)) { | ||
logger.warn("No matching cuda flavor for {} found: {}.", os, flavor); | ||
// fallback to CPU | ||
flavor = "cpu"; | ||
|
||
// check again | ||
dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier); | ||
path = dir.resolve(libName); | ||
if (Files.exists(path)) { | ||
return cacheDir.toAbsolutePath().toString(); | ||
} | ||
} | ||
|
||
tmp = Files.createTempDirectory(cacheDir, "tmp"); | ||
for (String line : lines) { | ||
if (line.startsWith(os + '/' + flavor + '/')) { | ||
URL url = new URL(link + '/' + line); | ||
String fileName = line.substring(line.lastIndexOf('/') + 1, line.length() - 3); | ||
logger.info("Downloading {} ...", url); | ||
try (InputStream fis = new GZIPInputStream(url.openStream())) { | ||
Files.copy(fis, tmp.resolve(fileName), StandardCopyOption.REPLACE_EXISTING); | ||
} | ||
} | ||
} | ||
|
||
Utils.moveQuietly(tmp, dir); | ||
return path.toAbsolutePath().toString(); | ||
} finally { | ||
if (tmp != null) { | ||
Utils.deleteQuietly(tmp); | ||
} | ||
} | ||
} | ||
|
||
private static Path getCacheDir() { | ||
String cacheDir = System.getProperty("ENGINE_CACHE_DIR"); | ||
if (cacheDir == null || cacheDir.isEmpty()) { | ||
cacheDir = System.getenv("ENGINE_CACHE_DIR"); | ||
if (cacheDir == null || cacheDir.isEmpty()) { | ||
cacheDir = System.getProperty("DJL_CACHE_DIR"); | ||
if (cacheDir == null || cacheDir.isEmpty()) { | ||
cacheDir = System.getenv("DJL_CACHE_DIR"); | ||
if (cacheDir == null || cacheDir.isEmpty()) { | ||
String userHome = System.getProperty("user.home"); | ||
return Paths.get(userHome, ".djl.ai").resolve("paddle"); | ||
} | ||
} | ||
} | ||
} | ||
return Paths.get(cacheDir, "paddle"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters