diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8a1fc8871ac..8e330e3c40d 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -63,7 +63,7 @@ public abstract class Engine { private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); - private Device defaultDevice; + protected Device defaultDevice; // use object to check if it's set private Integer seed; diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java index a69a502739b..c2f0315616a 100644 --- a/api/src/test/java/ai/djl/DeviceTest.java +++ b/api/src/test/java/ai/djl/DeviceTest.java @@ -28,7 +28,7 @@ public void testDevice() { if (engine.getGpuCount() > 0) { Assert.assertEquals(Device.gpu(), engine.defaultDevice()); } else { - Assert.assertEquals(Device.cpu(), engine.defaultDevice()); + Assert.assertEquals(engine.defaultDevice().getDeviceId(), -1); } Assert.assertEquals(Device.gpu(), Device.of("gpu", 0)); Assert.assertEquals(Device.gpu(3), Device.of("gpu", 3)); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index f909a57885c..26f51a3c87c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -121,6 +121,20 @@ public boolean hasCapability(String capability) { return JniUtils.getFeatures().contains(capability); } + /** {@inheritDoc} */ + @Override + public Device defaultDevice() { + if (defaultDevice != null) { + return defaultDevice; + } + if ("aarch64".equals(System.getProperty("os.arch")) + && System.getProperty("os.name").startsWith("Mac")) { + defaultDevice = Device.of("mps", -1); + return defaultDevice; + } + return super.defaultDevice(); + } + /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java index 80d67c88cc2..9721ec2c433 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java @@ -14,6 +14,7 @@ package ai.djl.pytorch.integration; import ai.djl.Application; +import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; @@ -58,6 +59,7 @@ public void testProfiler() Criteria.builder() .setTypes(Image.class, Classifications.class) .optApplication(Application.CV.IMAGE_CLASSIFICATION) + .optDevice(Device.cpu()) // MPS doesn't support float64 (by profiler) .optFilter("layers", "18") .optTranslator(translator) .optProgress(new ProgressBar()) diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java index 10409eb0030..435ec8ec98e 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.integration; +import ai.djl.Device; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.ndarray.NDList; @@ -93,9 +94,10 @@ public void testInputOutput() throws IOException, ModelException { try (InputStream is = Files.newInputStream(modelFile)) { PtSymbolBlock block = JniUtils.loadModule(manager, is, true, false); ByteArrayOutputStream os = new ByteArrayOutputStream(); + // writeModule with MPS cannot be loaded back on MPS JniUtils.writeModule(block, os, true); ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray()); - JniUtils.loadModule(manager, bis, true, true); + JniUtils.loadModuleHandle(bis, Device.cpu(), true, true); bis.close(); os.close(); }