Skip to content

Commit

Permalink
Expose numpy encode and decode
Browse files Browse the repository at this point in the history
This exposes the numpy encoding on the NDArray class and the numpy decoding on
the NDManager class. It also refactors the NDSerializer to use a
ByteBufferBackedInputStream (added as a utility) so that decoding from input
streams and from byte buffers shares a single implementation.
  • Loading branch information
zachgk committed Oct 11, 2023
1 parent d432a65 commit 4d3e107
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 44 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ default byte[] encode() {
return NDSerializer.encode(this);
}

/**
* Encodes this {@code NDArray} as a byte array in numpy .npy format.
*
* <p>The encoding itself is delegated to {@code NDSerializer.encodeAsNumpy}.
*
* @return the .npy encoded bytes
* @see NDManager#decodeNumpy(byte[])
*/
default byte[] encodeAsNumpy() {
return NDSerializer.encodeAsNumpy(this);
}

/**
* Moves this {@code NDArray} to a different {@link Device}.
*
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOEx
ZipEntry entry;
while ((entry = zis.getNextEntry()) != null) {
String name = entry.getName();
NDArray array = NDSerializer.decodeNumpy(manager, zis);
NDArray array = manager.decodeNumpy(zis);
if (!name.startsWith("arr_") && name.endsWith(".npy")) {
array.setName(name.substring(0, name.length() - 4));
}
Expand Down
21 changes: 21 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,27 @@ default NDArray decode(InputStream is) throws IOException {
return NDSerializer.decode(this, is);
}

/**
* Decodes an {@link NDArray} from a byte array in numpy .npy format.
*
* <p>The byte array is wrapped in a {@link ByteBuffer} and decoded by {@code NDSerializer}, using
* this manager as the manager passed to the decoder.
*
* @param bytes the .npy encoded bytes to decode
* @return the decoded {@link NDArray}
*/
default NDArray decodeNumpy(byte[] bytes) {
return NDSerializer.decodeNumpy(this, ByteBuffer.wrap(bytes));
}

/**
* Decodes an {@link NDArray} in numpy .npy format from an {@link InputStream}.
*
* <p>Decoding is delegated to {@code NDSerializer}, using this manager as the manager passed to the
* decoder.
*
* @param is the input stream to read the .npy data from
* @return the decoded {@link NDArray}
* @throws IOException if the data cannot be read
*/
default NDArray decodeNumpy(InputStream is) throws IOException {
return NDSerializer.decodeNumpy(this, is);
}

/**
* Loads the NDArrays saved to a file.
*
Expand Down
68 changes: 26 additions & 42 deletions api/src/main/java/ai/djl/ndarray/NDSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.ByteBufferBackedInputStream;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
Expand Down Expand Up @@ -109,6 +110,16 @@ static void encode(NDArray array, OutputStream os) throws IOException {
dos.flush();
}

/**
 * Encodes an {@link NDArray} into a byte array in numpy .npy format.
 *
 * @param array the array to encode
 * @return the .npy encoded bytes
 * @throws ArithmeticException if the estimated buffer size does not fit in an {@code int}
 */
static byte[] encodeAsNumpy(NDArray array) {
    // Estimate the buffer size as raw data bytes plus 100 spare bytes (presumably headroom for
    // the .npy header). Exact long arithmetic is used so that an oversized array fails with a
    // clear ArithmeticException instead of silently wrapping to a negative or tiny capacity.
    long dataBytes = Math.multiplyExact(array.size(), (long) array.getDataType().getNumOfBytes());
    int total = Math.toIntExact(Math.addExact(dataBytes, 100L));
    try (ByteArrayOutputStream baos = new ByteArrayOutputStream(total)) {
        encodeAsNumpy(array, baos);
        return baos.toByteArray();
    } catch (IOException e) {
        // ByteArrayOutputStream writes to memory only, so an IOException is not expected here.
        throw new AssertionError("This should never happen", e);
    }
}

static void encodeAsNumpy(NDArray array, OutputStream os) throws IOException {
StringBuilder sb = new StringBuilder(80);
sb.append("{'descr': '")
Expand Down Expand Up @@ -141,48 +152,12 @@ static void encodeAsNumpy(NDArray array, OutputStream os) throws IOException {
}

static NDArray decode(NDManager manager, ByteBuffer bb) {
if (!"NDAR".equals(readUTF(bb))) {
throw new IllegalArgumentException("Malformed NDArray data");
}

// NDArray encode version
int version = bb.getInt();
if (version < 1 || version > VERSION) {
throw new IllegalArgumentException("Unexpected NDArray encode version " + version);
}

String name = null;
if (version > 1) {
byte flag = bb.get();
if (flag == 1) {
name = readUTF(bb);
}
}

readUTF(bb); // ignore SparseFormat

// DataType
DataType dataType = DataType.valueOf(readUTF(bb));

// Shape
Shape shape = Shape.decode(bb);

// Data
ByteOrder order;
if (version > 2) {
order = bb.get() == '>' ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN;
} else {
order = ByteOrder.nativeOrder();
try {
return decode(manager, new ByteBufferBackedInputStream(bb));
} catch (IOException e) {
throw new IllegalStateException(
"Unexpected IOException with ByteBufferBackedInputStream", e);
}
int length = bb.getInt();
ByteBuffer data = bb.slice();
data.limit(length);
data.order(order);

NDArray array = manager.create(data, shape, dataType);
array.setName(name);
bb.position(bb.position() + length);
return array;
}

/**
Expand All @@ -201,7 +176,7 @@ static NDArray decode(NDManager manager, InputStream is) throws IOException {
dis = new DataInputStream(is);
}

if (!"NDAR".equals(dis.readUTF())) {
if (!MAGIC_NUMBER.equals(dis.readUTF())) {
throw new IllegalArgumentException("Malformed NDArray data");
}

Expand Down Expand Up @@ -244,6 +219,15 @@ static NDArray decode(NDManager manager, InputStream is) throws IOException {
return array;
}

/**
 * Decodes a numpy .npy {@link NDArray} from a {@link ByteBuffer}.
 *
 * <p>The buffer is exposed as a stream and handed to the {@link InputStream} based decoder, so both
 * entry points share one implementation.
 *
 * @param manager the manager passed on to the stream based decoder
 * @param bb the buffer holding the .npy data
 * @return the decoded {@link NDArray}
 * @throws IllegalStateException if the stream based decoder reports an {@link IOException}
 */
static NDArray decodeNumpy(NDManager manager, ByteBuffer bb) {
    ByteBufferBackedInputStream stream = new ByteBufferBackedInputStream(bb);
    try {
        return decodeNumpy(manager, stream);
    } catch (IOException ioe) {
        String message = "Unexpected IOException with ByteBufferBackedInputStream";
        throw new IllegalStateException(message, ioe);
    }
}

static NDArray decodeNumpy(NDManager manager, InputStream is) throws IOException {
DataInputStream dis;
if (is instanceof DataInputStream) {
Expand Down
74 changes: 74 additions & 0 deletions api/src/main/java/ai/djl/util/ByteBufferBackedInputStream.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;

/**
 * A utility for reading a {@link ByteBuffer} with an {@link InputStream}.
 *
 * <p>The stream reads directly from the supplied buffer without copying it, so every read or skip
 * advances the buffer's position. {@link #mark(int)} and {@link #reset()} are mapped onto the
 * buffer's own mark.
 *
 * @see <a href="https://stackoverflow.com/a/6603018">source</a>
 */
public class ByteBufferBackedInputStream extends InputStream {

    private final ByteBuffer buf;

    /**
     * Constructs a new {@link ByteBufferBackedInputStream}.
     *
     * @param buf the backing buffer
     * @throws NullPointerException if {@code buf} is {@code null}
     */
    public ByteBufferBackedInputStream(ByteBuffer buf) {
        if (buf == null) {
            throw new NullPointerException("buf");
        }
        this.buf = buf;
    }

    /** {@inheritDoc} */
    @Override
    public int read() {
        if (!buf.hasRemaining()) {
            return -1;
        }
        // Mask to return the byte as an unsigned value in the range 0-255.
        return buf.get() & 0xFF;
    }

    /** {@inheritDoc} */
    @Override
    public int read(byte[] bytes, int off, int len) {
        // The InputStream contract requires 0 (not -1) for a zero-length request, even at the end.
        if (len == 0) {
            return 0;
        }
        if (!buf.hasRemaining()) {
            return -1;
        }

        int count = Math.min(len, buf.remaining());
        buf.get(bytes, off, count);
        return count;
    }

    /** {@inheritDoc} */
    @Override
    public long skip(long n) {
        if (n <= 0) {
            return 0;
        }
        // Advance the position directly instead of reading the skipped bytes into a scratch array.
        int skipped = (int) Math.min(n, buf.remaining());
        buf.position(buf.position() + skipped);
        return skipped;
    }

    /** {@inheritDoc} */
    @Override
    public int available() {
        return buf.remaining();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void mark(int readlimit) {
        // The read limit is ignored: the buffer keeps its mark until it is replaced or discarded.
        buf.mark();
    }

    /** {@inheritDoc} */
    @Override
    public synchronized void reset() throws IOException {
        try {
            buf.reset();
        } catch (java.nio.InvalidMarkException e) {
            // Report a missing mark as the checked exception the InputStream contract promises.
            throw new IOException("Mark has not been set", e);
        }
    }

    /** {@inheritDoc} */
    @Override
    public boolean markSupported() {
        return true;
    }
}
2 changes: 1 addition & 1 deletion api/src/test/java/ai/djl/ndarray/NDSerializerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ private static byte[] encode(NDArray array) throws IOException {

private static NDArray decode(NDManager manager, byte[] data) throws IOException {
try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) {
return NDSerializer.decodeNumpy(manager, bis);
return manager.decodeNumpy(bis);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,16 @@ public void testEncodeDecode() {
}
}

/** Verifies that an array encoded as numpy .npy bytes decodes back to an equal array. */
@Test
public void testEncodeDecodeNumpy() {
    try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
        byte[] values = {0, 3};
        NDArray original = manager.create(values, new Shape(2));
        // Round trip: NDArray -> .npy bytes -> NDArray.
        NDArray roundTripped = manager.decodeNumpy(original.encodeAsNumpy());
        Assertions.assertAlmostEquals(roundTripped, original);
    }
}

@Test
public void testErfinv() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down

0 comments on commit 4d3e107

Please sign in to comment.