Allow streaming batches of data from a DataFrame #75

Merged (10 commits, Sep 4, 2023)
1 change: 1 addition & 0 deletions datafusion-java/build.gradle
@@ -11,6 +11,7 @@ dependencies {
implementation 'org.slf4j:slf4j-api:1.7.36'
implementation 'org.apache.arrow:arrow-format:13.0.0'
implementation 'org.apache.arrow:arrow-vector:13.0.0'
implementation 'org.apache.arrow:arrow-c-data:13.0.0'
runtimeOnly 'org.apache.arrow:arrow-memory-unsafe:13.0.0'
testImplementation 'org.junit.jupiter:junit-jupiter:5.8.1'
testImplementation 'org.apache.hadoop:hadoop-client:3.3.5'
datafusion-java/src/main/java/org/apache/arrow/datafusion/DataFrame.java
@@ -21,6 +21,14 @@ public interface DataFrame extends NativeProxy {
*/
CompletableFuture<ArrowReader> collect(BufferAllocator allocator);

/**
* Execute this DataFrame and return a stream of the result data
*
* @param allocator {@link BufferAllocator buffer allocator} to allocate vectors for the stream
* @return Stream of results
*/
CompletableFuture<RecordBatchStream> executeStream(BufferAllocator allocator);

/**
* Print results.
*
datafusion-java/src/main/java/org/apache/arrow/datafusion/DataFrames.java
@@ -15,6 +15,8 @@ private DataFrames() {}
static native void collectDataframe(
long runtime, long dataframe, BiConsumer<String, byte[]> callback);

static native void executeStream(long runtime, long dataframe, ObjectResultCallback callback);

static native void writeParquet(
long runtime, long dataframe, String path, Consumer<String> callback);

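Note: the ObjectResultCallback type used by executeStream above is not shown in this diff. Judging from the JNI signature (Ljava/lang/String;J)V invoked on the Rust side and the (String, long) lambdas passed to these native methods elsewhere in this PR, it is presumably a functional interface roughly like the following sketch (not the actual source):

package org.apache.arrow.datafusion;

/**
 * Presumed shape of the callback used by executeStream, getSchema and next:
 * errMessage is null or empty on success, and value carries a native pointer/address.
 */
@FunctionalInterface
interface ObjectResultCallback {
  void callback(String errMessage, long value);
}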
datafusion-java/src/main/java/org/apache/arrow/datafusion/DefaultDataFrame.java
@@ -41,6 +41,25 @@ public CompletableFuture<ArrowReader> collect(BufferAllocator allocator) {
return result;
}

@Override
public CompletableFuture<RecordBatchStream> executeStream(BufferAllocator allocator) {
CompletableFuture<RecordBatchStream> result = new CompletableFuture<>();
Runtime runtime = context.getRuntime();
long runtimePointer = runtime.getPointer();
long dataframe = getPointer();
DataFrames.executeStream(
runtimePointer,
dataframe,
(errString, streamId) -> {
if (containsError(errString)) {
result.completeExceptionally(new RuntimeException(errString));
} else {
result.complete(new DefaultRecordBatchStream(context, streamId, allocator));
}
});
return result;
}

private boolean containsError(String errString) {
return errString != null && !errString.isEmpty();
}
datafusion-java/src/main/java/org/apache/arrow/datafusion/DefaultRecordBatchStream.java (new file)
@@ -0,0 +1,123 @@
package org.apache.arrow.datafusion;

import java.util.Set;
import java.util.concurrent.CompletableFuture;
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.types.pojo.Schema;

class DefaultRecordBatchStream extends AbstractProxy implements RecordBatchStream {
private final SessionContext context;
private final BufferAllocator allocator;
private final CDataDictionaryProvider dictionaryProvider;
private VectorSchemaRoot vectorSchemaRoot = null;
private boolean initialized = false;

DefaultRecordBatchStream(SessionContext context, long pointer, BufferAllocator allocator) {
super(pointer);
this.context = context;
this.allocator = allocator;
this.dictionaryProvider = new CDataDictionaryProvider();
}

@Override
void doClose(long pointer) {
destroy(pointer);
dictionaryProvider.close();
if (initialized) {
vectorSchemaRoot.close();
}
}

@Override
public VectorSchemaRoot getVectorSchemaRoot() {
ensureInitialized();
return vectorSchemaRoot;
}

@Override
public CompletableFuture<Boolean> loadNextBatch() {
ensureInitialized();
Runtime runtime = context.getRuntime();
long runtimePointer = runtime.getPointer();
long recordBatchStream = getPointer();
CompletableFuture<Boolean> result = new CompletableFuture<>();
next(
runtimePointer,
recordBatchStream,
(errString, arrowArrayAddress) -> {
if (containsError(errString)) {
result.completeExceptionally(new RuntimeException(errString));
} else if (arrowArrayAddress == 0) {
// Reached end of stream
result.complete(false);
} else {
try {
ArrowArray arrowArray = ArrowArray.wrap(arrowArrayAddress);
Data.importIntoVectorSchemaRoot(
allocator, arrowArray, vectorSchemaRoot, dictionaryProvider);
result.complete(true);
} catch (Exception e) {
result.completeExceptionally(e);
}
}
});
return result;
}

@Override
public Dictionary lookup(long id) {
return dictionaryProvider.lookup(id);
}

@Override
public Set<Long> getDictionaryIds() {
return dictionaryProvider.getDictionaryIds();
}

private void ensureInitialized() {
if (!initialized) {
Schema schema = getSchema();
this.vectorSchemaRoot = VectorSchemaRoot.create(schema, allocator);
}
initialized = true;
}

private Schema getSchema() {
long recordBatchStream = getPointer();
// Native method is not async, but use a future to store the result for convenience
CompletableFuture<Schema> result = new CompletableFuture<>();
getSchema(
recordBatchStream,
(errString, arrowSchemaAddress) -> {
if (containsError(errString)) {
result.completeExceptionally(new RuntimeException(errString));
} else {
try {
ArrowSchema arrowSchema = ArrowSchema.wrap(arrowSchemaAddress);
Schema schema = Data.importSchema(allocator, arrowSchema, dictionaryProvider);
result.complete(schema);
// The FFI schema will be released from rust when it is dropped
} catch (Exception e) {
result.completeExceptionally(e);
}
}
});
return result.join();
}

private static boolean containsError(String errString) {
return errString != null && !"".equals(errString);
}

private static native void getSchema(long pointer, ObjectResultCallback callback);

private static native void next(long runtime, long pointer, ObjectResultCallback callback);

private static native void destroy(long pointer);
}
datafusion-java/src/main/java/org/apache/arrow/datafusion/RecordBatchStream.java (new file)
@@ -0,0 +1,25 @@
package org.apache.arrow.datafusion;

import java.util.concurrent.CompletableFuture;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;

/**
* A record batch stream is a stream of tabular Arrow data that can be iterated over asynchronously
*/
public interface RecordBatchStream extends AutoCloseable, NativeProxy, DictionaryProvider {
/**
* Get the VectorSchemaRoot that will be populated with data as the stream is iterated over
*
* @return the stream's VectorSchemaRoot
*/
VectorSchemaRoot getVectorSchemaRoot();

/**
* Load the next record batch in the stream into the VectorSchemaRoot
*
* @return Future that will complete with true if a batch was loaded or false if the end of the
* stream has been reached
*/
CompletableFuture<Boolean> loadNextBatch();
}
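Putting the new API together, an end-to-end consumer might look roughly like the sketch below. The class name, table name and CSV path are illustrative placeholders, not part of this change; the calls themselves (registerCsv, sql, executeStream, getVectorSchemaRoot, loadNextBatch) are the ones added or already present in this PR.

import java.nio.file.Paths;
import org.apache.arrow.datafusion.RecordBatchStream;
import org.apache.arrow.datafusion.SessionContext;
import org.apache.arrow.datafusion.SessionContexts;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

public class StreamingExample {
  public static void main(String[] args) throws Exception {
    try (SessionContext context = SessionContexts.create();
        BufferAllocator allocator = new RootAllocator()) {
      // Hypothetical CSV path; any registered table works here
      context.registerCsv("example", Paths.get("/tmp/example.csv")).join();

      try (RecordBatchStream stream =
          context
              .sql("SELECT * FROM example")
              .thenComposeAsync(df -> df.executeStream(allocator))
              .join()) {
        // The same VectorSchemaRoot instance is refilled by every loadNextBatch call
        VectorSchemaRoot root = stream.getVectorSchemaRoot();
        while (stream.loadNextBatch().join()) {
          System.out.println("Read a batch with " + root.getRowCount() + " rows");
        }
      }
    }
  }
}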
datafusion-java/src/test/java/org/apache/arrow/datafusion/TestExecuteStream.java (new file)
@@ -0,0 +1,114 @@
package org.apache.arrow.datafusion;

import static org.junit.jupiter.api.Assertions.*;

import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

public class TestExecuteStream {
@Test
public void executeStream(@TempDir Path tempDir) throws Exception {
try (SessionContext context = SessionContexts.create();
BufferAllocator allocator = new RootAllocator()) {
Path csvFilePath = tempDir.resolve("data.csv");

List<String> lines = Arrays.asList("x,y,z", "1,2,3.5", "4,5,6.5", "7,8,9.5");
Files.write(csvFilePath, lines);

context.registerCsv("test", csvFilePath).join();

try (RecordBatchStream stream =
context
.sql("SELECT y,z FROM test WHERE x > 3")
.thenComposeAsync(df -> df.executeStream(allocator))
.join()) {
VectorSchemaRoot root = stream.getVectorSchemaRoot();
Schema schema = root.getSchema();
assertEquals(2, schema.getFields().size());
assertEquals("y", schema.getFields().get(0).getName());
assertEquals("z", schema.getFields().get(1).getName());

assertTrue(stream.loadNextBatch().join());
assertEquals(2, root.getRowCount());
BigIntVector yValues = (BigIntVector) root.getVector(0);
assertEquals(5, yValues.get(0));
assertEquals(8, yValues.get(1));
Float8Vector zValues = (Float8Vector) root.getVector(1);
assertEquals(6.5, zValues.get(0));
assertEquals(9.5, zValues.get(1));

assertFalse(stream.loadNextBatch().join());
}
}
}

@Test
public void readDictionaryData() throws Exception {
try (SessionContext context = SessionContexts.create();
BufferAllocator allocator = new RootAllocator()) {

URL fileUrl = this.getClass().getResource("/dictionary_data.parquet");
Path parquetFilePath = Paths.get(fileUrl.getPath());

context.registerParquet("test", parquetFilePath).join();

try (RecordBatchStream stream =
context
.sql("SELECT x,y FROM test")
.thenComposeAsync(df -> df.executeStream(allocator))
.join()) {
VectorSchemaRoot root = stream.getVectorSchemaRoot();
Schema schema = root.getSchema();
assertEquals(2, schema.getFields().size());
assertEquals("x", schema.getFields().get(0).getName());
assertEquals("y", schema.getFields().get(1).getName());

int rowsRead = 0;
while (stream.loadNextBatch().join()) {
int batchNumRows = root.getRowCount();
BigIntVector xValuesEncoded = (BigIntVector) root.getVector(0);
long xDictionaryId = xValuesEncoded.getField().getDictionary().getId();
try (VarCharVector xValues =
(VarCharVector)
DictionaryEncoder.decode(xValuesEncoded, stream.lookup(xDictionaryId))) {
String[] expected = {"one", "two", "three"};
for (int i = 0; i < batchNumRows; ++i) {
assertEquals(
new String(xValues.get(i), StandardCharsets.UTF_8), expected[(rowsRead + i) % 3]);
}
}

BigIntVector yValuesEncoded = (BigIntVector) root.getVector(1);
long yDictionaryId = yValuesEncoded.getField().getDictionary().getId();
try (VarCharVector yValues =
(VarCharVector)
DictionaryEncoder.decode(yValuesEncoded, stream.lookup(yDictionaryId))) {
String[] expected = {"four", "five", "six"};
for (int i = 0; i < batchNumRows; ++i) {
assertEquals(
new String(yValues.get(i), StandardCharsets.UTF_8), expected[(rowsRead + i) % 3]);
}
}
rowsRead += batchNumRows;
}

assertEquals(100, rowsRead);
}
}
}
}
Binary file datafusion-java/src/test/resources/dictionary_data.parquet not shown.
18 changes: 18 additions & 0 deletions datafusion-java/write_test_files.py
@@ -0,0 +1,18 @@
import pyarrow as pa
import pyarrow.parquet as pq


num_rows = 100

dict_array_x = pa.DictionaryArray.from_arrays(
pa.array([i % 3 for i in range(num_rows)]),
pa.array(["one", "two", "three"])
)

dict_array_y = pa.DictionaryArray.from_arrays(
pa.array([i % 3 for i in range(num_rows)]),
pa.array(["four", "five", "six"])
)

table = pa.Table.from_arrays([dict_array_x, dict_array_y], ["x", "y"])
pq.write_table(table, "src/test/resources/dictionary_data.parquet")
3 changes: 2 additions & 1 deletion datafusion-jni/Cargo.toml
@@ -12,8 +12,9 @@ edition = "2021"
[dependencies]
jni = "^0.21.0"
tokio = "^1.32.0"
arrow = "^22.0"
arrow = { version = "^22.0", features = ["ffi"] }
datafusion = "^12.0"
futures = "0.3.28"

[lib]
crate_type = ["cdylib"]
39 changes: 39 additions & 0 deletions datafusion-jni/src/dataframe.rs
@@ -51,6 +51,45 @@ pub extern "system" fn Java_org_apache_arrow_datafusion_DataFrames_collectDatafr
});
}

#[no_mangle]
pub extern "system" fn Java_org_apache_arrow_datafusion_DataFrames_executeStream(
mut env: JNIEnv,
_class: JClass,
runtime: jlong,
dataframe: jlong,
callback: JObject,
) {
let runtime = unsafe { &mut *(runtime as *mut Runtime) };
let dataframe = unsafe { &mut *(dataframe as *mut Arc<DataFrame>) };
runtime.block_on(async {
let stream_result = dataframe.execute_stream().await;
match stream_result {
Ok(stream) => {
let stream = Box::into_raw(Box::new(stream)) as jlong;
env.call_method(
callback,
"callback",
"(Ljava/lang/String;J)V",
&[JValue::Void, stream.into()],
)
}
Err(err) => {
let stream = -1 as jlong;
let err_message = env
.new_string(err.to_string())
.expect("Couldn't create java string!");
env.call_method(
callback,
"callback",
"(Ljava/lang/String;J)V",
&[err_message.into(), stream.into()],
)
}
}
.expect("failed to call method");
});
}
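// Sketch (not part of this diff): the stream pointer handed to Java above via
// Box::into_raw must eventually be reclaimed on the Rust side. Assuming the matching
// native `destroy` declared on DefaultRecordBatchStream simply re-boxes the pointer as
// DataFusion's SendableRecordBatchStream, a minimal version could look like this:
#[no_mangle]
pub extern "system" fn Java_org_apache_arrow_datafusion_DefaultRecordBatchStream_destroy(
    _env: jni::JNIEnv,
    _class: jni::objects::JClass,
    pointer: jni::sys::jlong,
) {
    // Re-boxing the raw pointer lets Rust drop and free the underlying stream
    let _ = unsafe {
        Box::from_raw(pointer as *mut datafusion::physical_plan::SendableRecordBatchStream)
    };
}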

#[no_mangle]
pub extern "system" fn Java_org_apache_arrow_datafusion_DataFrames_showDataframe(
mut env: JNIEnv,