Skip to content

Commit

Permalink
Make MPS default device for macOS M1
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Sep 28, 2022
1 parent 1d0bc77 commit e3fdeb0
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 3 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public abstract class Engine {

private static final String DEFAULT_ENGINE = initEngine();

private Device defaultDevice;
protected Device defaultDevice;

// use object to check if it's set
private Integer seed;
Expand Down
2 changes: 1 addition & 1 deletion api/src/test/java/ai/djl/DeviceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void testDevice() {
if (engine.getGpuCount() > 0) {
Assert.assertEquals(Device.gpu(), engine.defaultDevice());
} else {
Assert.assertEquals(Device.cpu(), engine.defaultDevice());
Assert.assertEquals(engine.defaultDevice().getDeviceId(), -1);
}
Assert.assertEquals(Device.gpu(), Device.of("gpu", 0));
Assert.assertEquals(Device.gpu(3), Device.of("gpu", 3));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ public boolean hasCapability(String capability) {
return JniUtils.getFeatures().contains(capability);
}

/** {@inheritDoc} */
@Override
public Device defaultDevice() {
if (defaultDevice != null) {
return defaultDevice;
}
if ("aarch64".equals(System.getProperty("os.arch"))
&& System.getProperty("os.name").startsWith("Mac")) {
defaultDevice = Device.of("mps", -1);
return defaultDevice;
}
return super.defaultDevice();
}

/** {@inheritDoc} */
@Override
public SymbolBlock newSymbolBlock(NDManager manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,10 @@ public static PtSymbolBlock loadModule(
String[] extraFileKeys,
String[] extraFileValues) {
Device device = manager.getDevice();
// MPS doesn't support mapLocation
if ("mps".equals(device.getDeviceType())) {
mapLocation = false;
}
logger.debug("mapLocation: {}", mapLocation);
logger.debug("extraFileKeys: {}", Arrays.toString(extraFileKeys));
long handle =
Expand Down Expand Up @@ -1571,6 +1575,11 @@ public static long loadModuleHandle(
if (hasSize) {
size = new DataInputStream(is).readLong();
}
// MPS doesn't support mapLocation
if ("mps".equals(device.getDeviceType())) {
mapLocation = false;
}
logger.debug("mapLocation: {}", mapLocation);
return PyTorchLibrary.LIB.moduleLoad(
is,
new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package ai.djl.pytorch.integration;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
Expand Down Expand Up @@ -53,6 +54,7 @@ public void testProfiler()
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.optDevice(Device.cpu()) // MPS doesn't support float64 (by profiler)
.optFilter("layers", "18")
.optTranslator(translator)
.optProgress(new ProgressBar())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl.pytorch.integration;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -88,9 +89,10 @@ public void testInputOutput() throws IOException, ModelException {
try (InputStream is = Files.newInputStream(modelFile)) {
PtSymbolBlock block = JniUtils.loadModule(manager, is, true, false);
ByteArrayOutputStream os = new ByteArrayOutputStream();
// writeModule with MPS cannot be loaded back on MPS
JniUtils.writeModule(block, os, true);
ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
JniUtils.loadModule(manager, bis, true, true);
JniUtils.loadModuleHandle(bis, Device.cpu(), true, true);
bis.close();
os.close();
}
Expand Down

0 comments on commit e3fdeb0

Please sign in to comment.