Skip to content

Commit

Permalink
Fixes java format
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Nov 26, 2023
1 parent d4d6e48 commit a5c9623
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
5 changes: 3 additions & 2 deletions api/src/main/java/ai/djl/modality/Classifications.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ public Classifications(List<String> classNames, NDArray probabilities) {
*/
public Classifications(List<String> classNames, NDArray probabilities, int topK) {
this.classNames = classNames;
if (probabilities.getDataType().equals(DataType.FLOAT32)) {
if (probabilities.getDataType() == DataType.FLOAT32) {
// Avoid converting float32 to float64 as this is not supported on MPS device
this.probabilities = new ArrayList<>();
for (float prob : probabilities.toFloatArray())
for (float prob : probabilities.toFloatArray()) {
this.probabilities.add((double) prob);
}
} else {
NDArray array = probabilities.toType(DataType.FLOAT64, false);
this.probabilities =
Expand Down
5 changes: 3 additions & 2 deletions api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,11 @@ default NDArray toTensor() {
result = result.expandDims(0);
}
// For Apple Silicon MPS it is important not to switch to 64-bit float here
if (result.getDataType().equals(DataType.FLOAT32))
if (result.getDataType() == DataType.FLOAT32) {
result = result.div(255.0f).transpose(0, 3, 1, 2);
else
} else {
result = result.div(255.0).transpose(0, 3, 1, 2);
}
if (dim == 3) {
result = result.squeeze(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(ByteBuffer.wrap(buf),
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) {
try (NDArray array =
manager.create(
ByteBuffer.wrap(buf),
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1),
DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand All @@ -131,7 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
try (NDArray array =
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
try (NDArray array =
manager.create(
ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand All @@ -124,7 +126,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
throw new AssertionError("Failed skip data.");
}
byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
try (NDArray array =
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class MpsTest {
Expand All @@ -42,8 +43,8 @@ public void testMps() {
}

private static boolean checkMpsCompatible() {
return "aarch64".equals(System.getProperty("os.arch")) &&
System.getProperty("os.name").startsWith("Mac");
return "aarch64".equals(System.getProperty("os.arch"))
&& System.getProperty("os.name").startsWith("Mac");
}

@Test
Expand All @@ -54,9 +55,9 @@ public void testToTensorMPS() {

// Test that toTensor does not fail on MPS (e.g. due to use of float64 for division)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);;
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);
NDArray tensor = array.getNDArrayInternal().toTensor();
Assert.assertEquals(tensor.toFloatArray(), new float[]{127f/255f});
Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f});
}
}

Expand All @@ -66,16 +67,13 @@ public void testClassificationsMPS() {
throw new SkipException("MPS classification test requires Apple Silicon macOS.");
}

// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to float64)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to
// float64)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
List<String> names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth");
NDArray tensor = manager.create(new float[]{0f, 0.125f, 1f, 0.5f, 0.25f});
Classifications classifications = new Classifications(
names,
tensor
);
Assert.assertNotNull(classifications.topK(1).equals(Arrays.asList("Third")));
NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f});
Classifications classifications = new Classifications(names, tensor);
Assert.assertEquals(classifications.topK(1), Collections.singletonList("Third"));
}
}

}

0 comments on commit a5c9623

Please sign in to comment.