From 13d03cb1a06163aa18702944881f4accd6970d8d Mon Sep 17 00:00:00 2001 From: liurenjie1024 Date: Thu, 24 Oct 2024 10:48:43 +0800 Subject: [PATCH 1/3] Introduce kudo serialization format. --- .../nvidia/spark/rapids/jni/SlicedTable.java | 49 +++ .../nvidia/spark/rapids/jni/TableUtils.java | 131 +++++++ .../rapids/jni/kudo/ColumnOffsetInfo.java | 47 +++ .../spark/rapids/jni/kudo/ColumnViewInfo.java | 56 +++ .../jni/kudo/DataOutputStreamWriter.java | 67 ++++ .../spark/rapids/jni/kudo/DataWriter.java | 36 ++ .../rapids/jni/kudo/HostBufferMerger.java | 339 ++++++++++++++++++ .../rapids/jni/kudo/HostMergeResult.java | 46 +++ .../spark/rapids/jni/kudo/KudoSerializer.java | 209 +++++++++++ .../spark/rapids/jni/kudo/MergeMetrics.java | 66 ++++ .../spark/rapids/jni/kudo/MergedInfoCalc.java | 153 ++++++++ .../rapids/jni/kudo/MultiTableVisitor.java | 243 +++++++++++++ .../spark/rapids/jni/kudo/RefUtils.java | 73 ++++ .../rapids/jni/kudo/SerializedTable.java | 36 ++ .../jni/kudo/SerializedTableHeader.java | 155 ++++++++ .../jni/kudo/SerializedTableHeaderCalc.java | 155 ++++++++ .../spark/rapids/jni/kudo/SliceInfo.java | 34 ++ .../jni/kudo/SlicedBufferSerializer.java | 177 +++++++++ .../jni/kudo/SlicedValidityBufferInfo.java | 54 +++ .../spark/rapids/jni/kudo/TableBuilder.java | 78 ++++ .../rapids/jni/schema/HostColumnsVisitor.java | 16 + .../rapids/jni/schema/SchemaVisitor.java | 22 ++ .../spark/rapids/jni/schema/Visitors.java | 74 ++++ 23 files changed, 2316 insertions(+) create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/SlicedTable.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/HostBufferMerger.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/HostMergeResult.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiTableVisitor.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/RefUtils.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTable.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeader.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeaderCalc.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java diff --git a/src/main/java/com/nvidia/spark/rapids/jni/SlicedTable.java 
b/src/main/java/com/nvidia/spark/rapids/jni/SlicedTable.java new file mode 100644 index 000000000..b0d19748f --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/SlicedTable.java @@ -0,0 +1,49 @@ +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.Table; + +import java.util.Objects; + +import static com.nvidia.spark.rapids.jni.TableUtils.ensure; + + +/** + * A sliced view to table. + * This table doesn't change ownership of the underlying data. + */ +public class SlicedTable { + private final long startRow; + private final long numRows; + private final Table table; + + public SlicedTable(long startRow, long numRows, Table table) { + Objects.requireNonNull(table, "table must not be null"); + ensure(startRow >= 0, "startRow must be >= 0"); + ensure(startRow < table.getRowCount(), + () -> "startRow " + startRow + " is larger than table row count " + table.getRowCount()); + ensure(numRows >= 0, () -> "numRows " + numRows + " is negative"); + ensure(startRow + numRows <= table.getRowCount(), () -> "startRow + numRows is " + (startRow + numRows) + + ", must be less than table row count " + table.getRowCount()); + + this.startRow = startRow; + this.numRows = numRows; + this.table = table; + } + + public long getStartRow() { + return startRow; + } + + public long getNumRows() { + return numRows; + } + + public Table getTable() { + return table; + } + + public static SlicedTable from(Table table, long startRow, long numRows) { + return new SlicedTable(startRow, numRows, table); + } +} + diff --git a/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java new file mode 100644 index 000000000..01a15bc18 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java @@ -0,0 +1,131 @@ +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.*; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.function.Function; +import java.util.function.LongConsumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +public class TableUtils { + public static Schema schemaOf(Table t) { + Schema.Builder builder = Schema.builder(); + + for (int i = 0; i < t.getNumberOfColumns(); i++) { + ColumnVector cv = t.getColumn(i); + addToSchema(cv, "col_" + i + "_", builder); + } + + return builder.build(); + } + + public static void addToSchema(ColumnView cv, String namePrefix, Schema.Builder builder) { + toSchemaInner(cv, 0, namePrefix, builder); + } + + private static int toSchemaInner(ColumnView cv, int idx, String namePrefix, + Schema.Builder builder) { + String name = namePrefix + idx; + + Schema.Builder thisBuilder = builder.addColumn(cv.getType(), name); + int lastIdx = idx; + for (int i = 0; i < cv.getNumChildren(); i++) { + lastIdx = toSchemaInner(cv.getChildColumnView(i), lastIdx + 1, namePrefix, + thisBuilder); + } + + return lastIdx; + } + + public static void addToSchema(HostColumnVectorCore cv, String namePrefix, Schema.Builder builder) { + toSchemaInner(cv, 0, namePrefix, builder); + } + + private static int toSchemaInner(HostColumnVectorCore cv, int idx, String namePrefix, + Schema.Builder builder) { + String name = namePrefix + idx; + + Schema.Builder thisBuilder = builder.addColumn(cv.getType(), name); + int lastIdx = idx; + for (int i=0; i < cv.getNumChildren(); i++) { + lastIdx = toSchemaInner(cv.getChildColumnView(i), lastIdx + 1, namePrefix, thisBuilder); + } + + return lastIdx; + } + + public static void ensure(boolean condition, String message) { + if 
(!condition) { + throw new IllegalArgumentException(message); + } + } + + public static void ensure(boolean condition, Supplier messageSupplier) { + if (!condition) { + throw new IllegalArgumentException(messageSupplier.get()); + } + } + + /** + * This method returns the length in bytes needed to represent X number of rows + * e.g. getValidityLengthInBytes(5) => 1 byte + * getValidityLengthInBytes(7) => 1 byte + * getValidityLengthInBytes(14) => 2 bytes + */ + public static long getValidityLengthInBytes(long rows) { + return (rows + 7) / 8; + } + + /** + * This method returns the allocation size of the validity vector which is 64-byte aligned + * e.g. getValidityAllocationSizeInBytes(5) => 64 bytes + * getValidityAllocationSizeInBytes(14) => 64 bytes + * getValidityAllocationSizeInBytes(65) => 128 bytes + */ + static long getValidityAllocationSizeInBytes(long rows) { + long numBytes = getValidityLengthInBytes(rows); + return ((numBytes + 63) / 64) * 64; + } + + public static T closeIfException(R resource, Function function) { + try { + return function.apply(resource); + } catch (Exception e) { + if (resource != null) { + try { + resource.close(); + } catch (Exception inner) { + // ignore + } + } + throw e; + } + } + + public static void closeQuietly(Iterator resources) { + while (resources.hasNext()) { + try { + resources.next().close(); + } catch (Exception e) { + // ignore + } + } + } + + public static void closeQuietly(R... resources) { + closeQuietly(Arrays.stream(resources).collect(Collectors.toList())); + } + + public static void closeQuietly(Iterable resources) { + closeQuietly(resources.iterator()); + } + + public static T withTime(Supplier task, LongConsumer timeConsumer) { + long now = System.nanoTime(); + T ret = task.get(); + timeConsumer.accept(System.nanoTime() - now); + return ret; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java new file mode 100644 index 000000000..28c7454d7 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java @@ -0,0 +1,47 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import java.util.OptionalLong; + +/** + * This class is used to store the offsets of the buffer of a column in the serialized data. + */ +public class ColumnOffsetInfo { + private static final long INVALID_OFFSET = -1L; + private final long validity; + private final long offset; + private final long data; + private final long dataLen; + + public ColumnOffsetInfo(long validity, long offset, long data, long dataLen) { + this.validity = validity; + this.offset = offset; + this.data = data; + this.dataLen = dataLen; + } + + public OptionalLong getValidity() { + return (validity == INVALID_OFFSET) ? OptionalLong.empty() : OptionalLong.of(validity); + } + + public OptionalLong getOffset() { + return (offset == INVALID_OFFSET) ? OptionalLong.empty() : OptionalLong.of(offset); + } + + public OptionalLong getData() { + return (data == INVALID_OFFSET) ? 
OptionalLong.empty() : OptionalLong.of(data); + } + + public long getDataLen() { + return dataLen; + } + + @Override + public String toString() { + return "ColumnOffsets{" + + "validity=" + validity + + ", offset=" + offset + + ", data=" + data + + ", dataLen=" + dataLen + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java new file mode 100644 index 000000000..f6b00e8f7 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java @@ -0,0 +1,56 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.DType; +import ai.rapids.cudf.DeviceMemoryBuffer; + + +public class ColumnViewInfo { + private final DType dtype; + private final ColumnOffsetInfo offsetInfo; + private final long nullCount; + private final long rowCount; + + public ColumnViewInfo(DType dtype, ColumnOffsetInfo offsetInfo, + long nullCount, long rowCount) { + this.dtype = dtype; + this.offsetInfo = offsetInfo; + this.nullCount = nullCount; + this.rowCount = rowCount; + } + + public long buildColumnView(DeviceMemoryBuffer buffer, long[] childrenView) { + long bufferAddress = buffer.getAddress(); + + long dataAddress = 0; + if (offsetInfo.getData().isPresent()) { + dataAddress = buffer.getAddress() + offsetInfo.getData().getAsLong(); + } + + long validityAddress = 0; + if (offsetInfo.getValidity().isPresent()) { + validityAddress = offsetInfo.getValidity().getAsLong() + bufferAddress; + } + + long offsetsAddress = 0; + if (offsetInfo.getOffset().isPresent()) { + offsetsAddress = offsetInfo.getOffset().getAsLong() + bufferAddress; + } + + return RefUtils.makeCudfColumnView( + dtype.getTypeId().getNativeId(), dtype.getScale(), + dataAddress, offsetInfo.getDataLen(), + offsetsAddress, validityAddress, + safeLongToInt(nullCount), safeLongToInt(rowCount), + childrenView); + } + + @Override + public String toString() { + return "ColumnViewInfo{" + + "dtype=" + dtype + + ", offsetInfo=" + offsetInfo + + ", nullCount=" + nullCount + + ", rowCount=" + rowCount + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java new file mode 100644 index 000000000..99f16724b --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java @@ -0,0 +1,67 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.DataOutputStream; +import java.io.IOException; + +/** + * Visible for testing + */ +class DataOutputStreamWriter extends DataWriter { + private final byte[] arrayBuffer = new byte[1024 * 128]; + private final DataOutputStream dout; + + public DataOutputStreamWriter(DataOutputStream dout) { + this.dout = dout; + } + + @Override + public void writeByte(byte b) throws IOException { + dout.writeByte(b); + } + + @Override + public void writeShort(short s) throws IOException { + dout.writeShort(s); + } + + @Override + public void writeInt(int i) throws IOException { + dout.writeInt(i); + } + + @Override + public void writeIntNativeOrder(int i) throws IOException { + // TODO this only works on Little Endian Architectures, x86. If we need + // to support others we need to detect the endianness and switch on the right implementation. 
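+ // DataOutputStream.writeInt always writes big-endian, so reversing the bytes first lands the value in the stream in little-endian (native x86) order.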
+ writeInt(Integer.reverseBytes(i)); + } + + @Override + public void writeLong(long val) throws IOException { + dout.writeLong(val); + } + + @Override + public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException { + long dataLeft = len; + while (dataLeft > 0) { + int amountToCopy = (int)Math.min(arrayBuffer.length, dataLeft); + src.getBytes(arrayBuffer, 0, srcOffset, amountToCopy); + dout.write(arrayBuffer, 0, amountToCopy); + srcOffset += amountToCopy; + dataLeft -= amountToCopy; + } + } + + @Override + public void flush() throws IOException { + dout.flush(); + } + + @Override + public void write(byte[] arr, int offset, int length) throws IOException { + dout.write(arr, offset, length); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java new file mode 100644 index 000000000..2a725d05b --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java @@ -0,0 +1,36 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.IOException; + +/** + * Visible for testing + */ +abstract class DataWriter { + + public abstract void writeByte(byte b) throws IOException; + + public abstract void writeShort(short s) throws IOException; + + public abstract void writeInt(int i) throws IOException; + + public abstract void writeIntNativeOrder(int i) throws IOException; + + public abstract void writeLong(long val) throws IOException; + + /** + * Copy data from src starting at srcOffset and going for len bytes. + * + * @param src where to copy from. + * @param srcOffset offset to start at. + * @param len amount to copy. + */ + public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException; + + public void flush() throws IOException { + // NOOP by default + } + + public abstract void write(byte[] arr, int offset, int length) throws IOException; +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostBufferMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostBufferMerger.java new file mode 100644 index 000000000..fedc54ae5 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostBufferMerger.java @@ -0,0 +1,339 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.TableUtils; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.nio.ByteOrder; +import java.nio.IntBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.OptionalInt; + +import static com.nvidia.spark.rapids.jni.TableUtils.ensure; +import static com.nvidia.spark.rapids.jni.TableUtils.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.safeLongToInt; +import static java.lang.Math.min; +import static java.util.Objects.requireNonNull; + +public class HostBufferMerger extends MultiTableVisitor { + // Number of 1s in a byte + private static final int[] NUMBER_OF_ONES = new int[256]; + private static final byte[] ZEROS = new byte[1024]; + + static { + for (int i = 0; i < NUMBER_OF_ONES.length; i += 1) { + int count = 0; + for (int j = 0; j < 8; j += 1) { + if ((i & (1 << j)) != 0) { + count += 1; + } + } + NUMBER_OF_ONES[i] = count; + } + + Arrays.fill(ZEROS, (byte) 0); + } + + private final List 
columnOffsets; + private final HostMemoryBuffer buffer; + private final List colViewInfoList; + + public HostBufferMerger(List tables, HostMemoryBuffer buffer, List columnOffsets) { + super(tables); + requireNonNull(buffer, "buffer can't be null!"); + ensure(columnOffsets != null, "column offsets cannot be null"); + ensure(!columnOffsets.isEmpty(), "column offsets cannot be empty"); + this.columnOffsets = columnOffsets; + this.buffer = buffer; + this.colViewInfoList = new ArrayList<>(columnOffsets.size()); + } + + @Override + protected HostMergeResult doVisitTopSchema(Schema schema, List children) { + return new HostMergeResult(buffer, colViewInfoList); + } + + @Override + protected Void doVisitStruct(Schema structType, List children) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + long nullCount = deserializeValidityBuffer(); + long totalRowCount = getTotalRowCount(); + colViewInfoList.add(new ColumnViewInfo(structType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doPreVisitList(Schema listType) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + long nullCount = deserializeValidityBuffer(); + long totalRowCount = getTotalRowCount(); + deserializeOffsetBuffer(); + + colViewInfoList.add(new ColumnViewInfo(listType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { + return null; + } + + @Override + protected Void doVisit(Schema primitiveType) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + long nullCount = deserializeValidityBuffer(); + long totalRowCount = getTotalRowCount(); + if (primitiveType.getType().hasOffsets()) { + deserializeOffsetBuffer(); + deserializeDataBuffer(OptionalInt.empty()); + } else { + deserializeDataBuffer(OptionalInt.of(primitiveType.getType().getSizeInBytes())); + } + + colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(), + offsetInfo, nullCount, totalRowCount)); + + return null; + } + + private long deserializeValidityBuffer() { + ColumnOffsetInfo colOffset = getCurColumnOffsets(); + if (colOffset.getValidity().isPresent()) { + long offset = colOffset.getValidity().getAsLong(); + long validityBufferSize = padFor64byteAlignment( + getValidityLengthInBytes(getTotalRowCount())); + try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) { + int nullCountTotal = 0; + int startRow = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + long validityOffset = validifyBufferOffset(tableIdx); + if (validityOffset != -1) { + nullCountTotal += copyValidityBuffer(validityBuffer, startRow, + memoryBufferOf(tableIdx), safeLongToInt(validityOffset), + sliceInfo); + } else { + appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount()); + } + + startRow += safeLongToInt(sliceInfo.getRowCount()); + } + return nullCountTotal; + } + } else { + return 0; + } + } + + /** + * Copy a sliced validity buffer to the destination buffer, starting at the given bit offset. + * + * @return Number of nulls in the validity buffer. 
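+ * @param dest destination validity buffer being assembled + * @param startBit bit position in dest at which this slice's validity bits start + * @param src host buffer holding the serialized table's data + * @param srcOffset byte offset of this column's validity data within src + * @param sliceInfo slice describing the row count and the begin bit within the source validity buffer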
+ */ + private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit, + HostMemoryBuffer src, int srcOffset, + SliceInfo sliceInfo) { + int nullCount = 0; + int totalRowCount = safeLongToInt(sliceInfo.getRowCount()); + int curIdx = 0; + int curSrcByteIdx = srcOffset; + int curSrcBitIdx = safeLongToInt(sliceInfo.getValidityBufferInfo().getBeginBit()); + int curDestByteIdx = startBit / 8; + int curDestBitIdx = startBit % 8; + + while (curIdx < totalRowCount) { + int leftRowCount = totalRowCount - curIdx; + int appendCount; + if (curDestBitIdx == 0) { + appendCount = min(8, leftRowCount); + } else { + appendCount = min(8 - curDestBitIdx, leftRowCount); + } + + int leftBitsInCurSrcByte = 8 - curSrcBitIdx; + byte srcByte = src.getByte(curSrcByteIdx); + if (leftBitsInCurSrcByte >= appendCount) { + // Extract appendCount bits from srcByte, starting from curSrcBitIdx + byte mask = (byte) (((1 << appendCount) - 1) & 0xFF); + srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); + + nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); + + // Sets the bits in destination buffer starting from curDestBitIdx to 0 + byte destByte = dest.getByte(curDestByteIdx); + destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF); + + // Update destination byte with the bits from source byte + destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF); + dest.setByte(curDestByteIdx, destByte); + + curSrcBitIdx += appendCount; + if (curSrcBitIdx == 8) { + curSrcBitIdx = 0; + curSrcByteIdx += 1; + } + } else { + // Extract appendCount bits from srcByte, starting from curSrcBitIdx + byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF); + srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); + + byte nextSrcByte = src.getByte(curSrcByteIdx + 1); + byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1); + nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask); + nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte); + srcByte = (byte) (srcByte | nextSrcByte); + + nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); + + // Sets the bits in destination buffer starting from curDestBitIdx to 0 + byte destByte = dest.getByte(curDestByteIdx); + destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1)); + + // Update destination byte with the bits from source byte + destByte = (byte) (destByte | (srcByte << curDestBitIdx)); + dest.setByte(curDestByteIdx, destByte); + + // Update the source byte index and bit index + curSrcByteIdx += 1; + curSrcBitIdx = appendCount - leftBitsInCurSrcByte; + } + + curIdx += appendCount; + + // Update the destination byte index and bit index + curDestBitIdx += appendCount; + if (curDestBitIdx == 8) { + curDestBitIdx = 0; + curDestByteIdx += 1; + } + } + + return nullCount; + } + + private static void appendAllValid(HostMemoryBuffer dest, int startBit, long numRowsLong) { + int numRows = safeLongToInt(numRowsLong); + int curDestByteIdx = startBit / 8; + int curDestBitIdx = startBit % 8; + int curIdx = 0; + while (curIdx < numRows) { + int leftRowCount = numRows - curIdx; + int appendCount; + if (curDestBitIdx == 0) { + dest.setByte(curDestByteIdx, (byte) 0xFF); + appendCount = min(8, leftRowCount); + } else { + appendCount = min(8 - curDestBitIdx, leftRowCount); + byte mask = (byte) (((1 << appendCount) - 1) << curDestBitIdx); + byte destByte = dest.getByte(curDestByteIdx); + dest.setByte(curDestByteIdx, (byte) (destByte | mask)); + } + + curDestBitIdx += appendCount; + if (curDestBitIdx == 8) { + curDestBitIdx = 0; + 
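+ // destination bit index wrapped past the end of the current byte; advance to the next byte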
curDestByteIdx += 1; + } + + curIdx += appendCount; + } + } + + private void deserializeOffsetBuffer() { + ColumnOffsetInfo colOffset = getCurColumnOffsets(); + if (colOffset.getOffset().isPresent()) { + long offset = colOffset.getOffset().getAsLong(); + long bufferSize = Integer.BYTES * (getTotalRowCount() + 1); + + IntBuffer buf = buffer + .asByteBuffer(offset, safeLongToInt(bufferSize)) + .order(ByteOrder.LITTLE_ENDIAN) + .asIntBuffer(); + + int accumulatedDataLen = 0; + + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + + if (sliceInfo.getRowCount() > 0) { + int rowCnt = safeLongToInt(sliceInfo.getRowCount()); + + int firstOffset = offsetOf(tableIdx, 0); + int lastOffset = offsetOf(tableIdx, rowCnt); + + for (int i = 0; i < rowCnt; i += 1) { + buf.put(offsetOf(tableIdx, i) - firstOffset + accumulatedDataLen); + } + + accumulatedDataLen += (lastOffset - firstOffset); + } + } + + buf.put(accumulatedDataLen); + } + } + + private void deserializeDataBuffer(OptionalInt sizeInBytes) { + ColumnOffsetInfo colOffset = getCurColumnOffsets(); + + if (colOffset.getData().isPresent() && colOffset.getDataLen() > 0) { + long offset = colOffset.getData().getAsLong(); + long dataLen = colOffset.getDataLen(); + + try (HostMemoryBuffer buf = buffer.slice(offset, dataLen)) { + if (sizeInBytes.isPresent()) { + // Fixed size type + int elementSize = sizeInBytes.getAsInt(); + + long start = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + int thisDataLen = safeLongToInt(elementSize * sliceInfo.getRowCount()); + copyDataBuffer(buf, start, tableIdx, thisDataLen); + start += thisDataLen; + } + } + } else { + // String type + long start = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + int thisDataLen = getStrDataLenOf(tableIdx); + copyDataBuffer(buf, start, tableIdx, thisDataLen); + start += thisDataLen; + } + } + } + } + } + + + private ColumnOffsetInfo getCurColumnOffsets() { + return columnOffsets.get(getCurrentIdx()); + } + + public static HostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { + List serializedTables = mergedInfo.getTables(); + return TableUtils.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), + buffer -> { + clearHostBuffer(buffer); + HostBufferMerger merger = new HostBufferMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); + return Visitors.visitSchema(schema, merger); + }); + } + + private static void clearHostBuffer(HostMemoryBuffer buffer) { + int left = safeLongToInt(buffer.getLength()); + while (left > 0) { + int toWrite = min(left, ZEROS.length); + buffer.setBytes(buffer.getLength() - left, ZEROS, 0, toWrite); + left -= toWrite; + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostMergeResult.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostMergeResult.java new file mode 100644 index 000000000..b32ada9c5 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/HostMergeResult.java @@ -0,0 +1,46 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.TableUtils; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.util.List; + +public class HostMergeResult implements AutoCloseable { + private final List columnOffsets; + private final HostMemoryBuffer hostBuf; + + public HostMergeResult(HostMemoryBuffer hostBuf, List 
columnOffsets) { + this.columnOffsets = columnOffsets; + this.hostBuf = hostBuf; + } + + @Override + public void close() throws Exception { + if (hostBuf != null) { + hostBuf.close(); + } + } + + public ContiguousTable toContiguousTable(Schema schema) { + return TableUtils.closeIfException(DeviceMemoryBuffer.allocate(hostBuf.getLength()), + deviceMemBuf -> { + if (hostBuf.getLength() > 0) { + deviceMemBuf.copyFromHostBuffer(hostBuf); + } + + TableBuilder builder = new TableBuilder(columnOffsets, deviceMemBuf); + Table t = Visitors.visitSchema(schema, builder); + + return RefUtils.makeContiguousTable(t, deviceMemBuf); + }); + } + + @Override + public String toString() { + return "HostMergeResult{" + + "columnOffsets=" + columnOffsets + + ", hostBuf length =" + hostBuf.getLength() + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java new file mode 100644 index 000000000..50b87fd24 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java @@ -0,0 +1,209 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.io.*; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.nvidia.spark.rapids.jni.TableUtils.withTime; + + +public class KudoSerializer { + + private static final byte[] PADDING = new byte[64]; + + static { + Arrays.fill(PADDING, (byte) 0); + } + + public String version() { + return "MultiTableSerializer-v7"; + } + + public long writeToStream(Table table, OutputStream out, long rowOffset, long numRows) { + + HostColumnVector[] columns = null; + try { + columns = IntStream.range(0, table.getNumberOfColumns()) + .mapToObj(table::getColumn) + .map(ColumnView::copyToHost) + .toArray(HostColumnVector[]::new); + return writeToStream(columns, out, rowOffset, numRows); + } finally { + if (columns != null) { + for (HostColumnVector column : columns) { + column.close(); + } + } + } + } + + public long writeToStream(HostColumnVector[] columns, OutputStream out, long rowOffset, long numRows) { + if (numRows < 0) { + throw new IllegalArgumentException("numRows must be >= 0"); + } + + if (numRows == 0 || columns.length == 0) { + return 0; + } + + try { + return writeSliced(columns, writerFrom(out), rowOffset, numRows); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public long writeRowsToStream(OutputStream out, long numRows) { + if (numRows <= 0) { + throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows); + } + try { + DataWriter writer = writerFrom(out); + SerializedTableHeader header = new SerializedTableHeader(0, safeLongToInt(numRows), 0, 0, 0, new byte[0]); + header.writeTo(writer); + writer.flush(); + return header.getSerializedSize(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public SerializedTable readOneTableBuffer(InputStream in) { + Objects.requireNonNull(in, "Input stream must not be null"); + + try { + DataInputStream din = readerFrom(in); + SerializedTableHeader header = new SerializedTableHeader(din); + if (!header.wasInitialized()) { + return null; + } + + if (header.getNumRows() <= 0) { + throw new IllegalArgumentException("Number of rows must be > 0, but was " + header.getNumRows()); + } + + // Header only + if (header.getNumColumns() == 0) { + return new 
SerializedTable(header, null); + } + + HostMemoryBuffer buffer = HostMemoryBuffer.allocate(header.getTotalDataLen(), false); + RefUtils.copyFromStream(buffer, 0, din, header.getTotalDataLen()); + return new SerializedTable(header, buffer); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public Pair<HostMergeResult, MergeMetrics> mergeToHost(List<SerializedTable> serializedTables, + Schema schema) { + MergeMetrics.Builder metricsBuilder = MergeMetrics.builder(); + + MergedInfoCalc mergedInfoCalc = withTime(() -> MergedInfoCalc.calc(schema, serializedTables), + metricsBuilder::calcHeaderTime); + HostMergeResult result = withTime(() -> HostBufferMerger.merge(schema, mergedInfoCalc), + metricsBuilder::mergeIntoHostBufferTime); + return Pair.of(result, metricsBuilder.build()); + } + + public Pair<ContiguousTable, MergeMetrics> mergeTable(List<SerializedTable> buffers, + Schema schema) { + Pair<HostMergeResult, MergeMetrics> result = mergeToHost(buffers, schema); + MergeMetrics.Builder builder = MergeMetrics.builder(result.getRight()); + try (HostMergeResult children = result.getLeft()) { + ContiguousTable table = withTime(() -> children.toContiguousTable(schema), + builder::convertIntoContiguousTableTime); + + return Pair.of(table, builder.build()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static long writeSliced(HostColumnVector[] columns, DataWriter out, long rowOffset, long numRows) throws Exception { + List<HostColumnVector> columnList = Arrays.stream(columns).collect(Collectors.toList()); + + SerializedTableHeaderCalc headerCalc = new SerializedTableHeaderCalc(rowOffset, numRows); + SerializedTableHeader header = Visitors.visitColumns(columnList, headerCalc); + header.writeTo(out); + + long bytesWritten = 0; + for (BufferType bufferType : Arrays.asList(BufferType.VALIDITY, BufferType.OFFSET, BufferType.DATA)) { + bytesWritten += Visitors.visitColumns(columnList, new SlicedBufferSerializer(rowOffset, numRows, bufferType, out)); + } + + if (bytesWritten != header.getTotalDataLen()) { + throw new IllegalStateException("Header total data length: " + header.getTotalDataLen() + + " does not match actual written data length: " + bytesWritten + + ", rowOffset: " + rowOffset + " numRows: " + numRows); + } + + out.flush(); + + return header.getSerializedSize() + bytesWritten; + } + + private static DataInputStream readerFrom(InputStream in) { + if (in instanceof DataInputStream) { + return (DataInputStream) in; + } + return new DataInputStream(in); + } + + private static DataWriter writerFrom(OutputStream out) { + if (!(out instanceof DataOutputStream)) { + out = new DataOutputStream(new BufferedOutputStream(out)); + } + return new DataOutputStreamWriter((DataOutputStream) out); + } + + ///////////////////////////////////////////// + // PADDING FOR ALIGNMENT + ///////////////////////////////////////////// + static long padForHostAlignment(long orig) { + return ((orig + 3) / 4) * 4; + } + + static long padForHostAlignment(DataWriter out, long bytes) throws IOException { + final long paddedBytes = padForHostAlignment(bytes); + if (paddedBytes > bytes) { + out.write(PADDING, 0, (int) (paddedBytes - bytes)); + } + return paddedBytes; + } + + static long padFor64byteAlignment(long orig) { + return ((orig + 63) / 64) * 64; + } + + static long padFor64byteAlignment(DataWriter out, long bytes) throws IOException { + final long paddedBytes = 
padFor64byteAlignment(bytes); + if (paddedBytes > bytes) { + out.write(PADDING, 0, (int) (paddedBytes - bytes)); + } + return paddedBytes; + } + + static int safeLongToInt(long value) { +// if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) { +// throw new ArithmeticException("Overflow: long value is too large to fit in an int"); +// } + return (int) value; + } + +} \ No newline at end of file diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java new file mode 100644 index 000000000..804a7677d --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java @@ -0,0 +1,66 @@ +package com.nvidia.spark.rapids.jni.kudo; + +public class MergeMetrics { + // The time it took to calculate combined header in nanoseconds + private final long calcHeaderTime; + // The time it took to merge the buffers into the host buffer in nanoseconds + private final long mergeIntoHostBufferTime; + // The time it took to convert the host buffer into a contiguous table in nanoseconds + private final long convertIntoContiguousTableTime; + + MergeMetrics(long calcHeaderTime, long mergeIntoHostBufferTime, + long convertIntoContiguousTableTime) { + this.calcHeaderTime = calcHeaderTime; + this.mergeIntoHostBufferTime = mergeIntoHostBufferTime; + this.convertIntoContiguousTableTime = convertIntoContiguousTableTime; + } + + public long getCalcHeaderTime() { + return calcHeaderTime; + } + + public long getMergeIntoHostBufferTime() { + return mergeIntoHostBufferTime; + } + + public long getConvertIntoContiguousTableTime() { + return convertIntoContiguousTableTime; + } + + public static Builder builder() { + return new Builder(); + } + + public static Builder builder(MergeMetrics metrics) { + return new Builder() + .calcHeaderTime(metrics.calcHeaderTime) + .mergeIntoHostBufferTime(metrics.mergeIntoHostBufferTime) + .convertIntoContiguousTableTime(metrics.convertIntoContiguousTableTime); + } + + + public static class Builder { + private long calcHeaderTime; + private long mergeIntoHostBufferTime; + private long convertIntoContiguousTableTime; + + public Builder calcHeaderTime(long calcHeaderTime) { + this.calcHeaderTime = calcHeaderTime; + return this; + } + + public Builder mergeIntoHostBufferTime(long mergeIntoHostBufferTime) { + this.mergeIntoHostBufferTime = mergeIntoHostBufferTime; + return this; + } + + public Builder convertIntoContiguousTableTime(long convertIntoContiguousTableTime) { + this.convertIntoContiguousTableTime = convertIntoContiguousTableTime; + return this; + } + + public MergeMetrics build() { + return new MergeMetrics(calcHeaderTime, mergeIntoHostBufferTime, convertIntoContiguousTableTime); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java new file mode 100644 index 000000000..e16d472c8 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java @@ -0,0 +1,153 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.TableUtils.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; + + +/** + * This class is used to calculate column offsets of merged buffer, + */ +public class MergedInfoCalc 
extends MultiTableVisitor { + // Total data len in gpu, which accounts for 64 byte alignment + private long totalDataLen; + // Column offset in gpu device buffer, it has one field for each flattened column + private final List columnOffsets; + + public MergedInfoCalc(List tables) { + super(tables); + this.totalDataLen = 0; + this.columnOffsets = new ArrayList<>(tables.get(0).getHeader().getNumColumns()) ; + } + + @Override + protected Void doVisitTopSchema(Schema schema, List children) { + return null; + } + + @Override + protected Void doVisitStruct(Schema structType, List children) { + long validityBufferLen = 0; + long validityOffset = -1; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, -1, -1, 0)); + return null; + } + + @Override + protected Void doPreVisitList(Schema listType) { + long validityBufferLen = 0; + long validityOffset = -1; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long offsetBufferLen = 0; + long offsetBufferOffset = -1; + if (getTotalRowCount() > 0) { + offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferOffset = totalDataLen; + totalDataLen += offsetBufferLen; + } + + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, offsetBufferOffset, -1, 0)); + return null; + } + + @Override + protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { + return null; + } + + @Override + protected Void doVisit(Schema primitiveType) { + // String type + if (primitiveType.getType().hasOffsets()) { + long validityBufferLen = 0; + long validityOffset = -1; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long offsetBufferLen = 0; + long offsetBufferOffset = -1; + if (getTotalRowCount() > 0) { + offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferOffset = totalDataLen; + totalDataLen += offsetBufferLen; + } + + long dataBufferLen = 0; + long dataBufferOffset = -1; + if (getTotalStrDataLen() > 0) { + dataBufferLen = padFor64byteAlignment(getTotalStrDataLen()); + dataBufferOffset = totalDataLen; + totalDataLen += dataBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, offsetBufferOffset, dataBufferOffset, dataBufferLen)); + } else { + long totalRowCount = getTotalRowCount(); + long validityBufferLen = 0; + long validityOffset = -1; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(totalRowCount)); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long offsetBufferOffset = -1; + + long dataBufferLen = 0; + long dataBufferOffset = -1; + if (totalRowCount > 0) { + dataBufferLen = padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes()); + dataBufferOffset = totalDataLen; + totalDataLen += dataBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, offsetBufferOffset, dataBufferOffset, dataBufferLen)); + } + + return null; + } + + + public long getTotalDataLen() { + return totalDataLen; + } + + public List getColumnOffsets() { + return 
Collections.unmodifiableList(columnOffsets); + } + + @Override + public String toString() { + return "MergedInfoCalc{" + + "totalDataLen=" + totalDataLen + + ", columnOffsets=" + columnOffsets + + '}'; + } + + public static MergedInfoCalc calc(Schema schema, List table) { + MergedInfoCalc calc = new MergedInfoCalc(table); + Visitors.visitSchema(schema, calc); + return calc; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiTableVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiTableVisitor.java new file mode 100644 index 000000000..8879fc193 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiTableVisitor.java @@ -0,0 +1,243 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; + +import java.util.*; + +import static com.nvidia.spark.rapids.jni.TableUtils.ensure; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.safeLongToInt; + + +public abstract class MultiTableVisitor implements SchemaVisitor { + private final List tables; + private final long[] currentValidityOffsets; + private final long[] currentOffsetOffsets; + private final long[] currentDataOffset; + private final Deque[] sliceInfoStack; + private final Deque totalRowCountStack; + // A temporary variable to keep if current column has null + private boolean hasNull; + private int currentIdx; + // Temporary buffer to store data length of string column to avoid repeated allocation + private final int[] strDataLen; + // Temporary variable to calcluate total data length of string column + private long totalStrDataLen; + + protected MultiTableVisitor(List inputTables) { + Objects.requireNonNull(inputTables, "tables cannot be null"); + ensure(!inputTables.isEmpty(), "tables cannot be empty"); + this.tables = inputTables instanceof ArrayList ? 
inputTables : new ArrayList<>(inputTables); + this.currentValidityOffsets = new long[tables.size()]; + this.currentOffsetOffsets = new long[tables.size()]; + this.currentDataOffset = new long[tables.size()]; + this.sliceInfoStack = new Deque[tables.size()]; + for (int i = 0; i < tables.size(); i++) { + this.currentValidityOffsets[i] = 0; + SerializedTableHeader header = tables.get(i).getHeader(); + this.currentOffsetOffsets[i] = header.getValidityBufferLen(); + this.currentDataOffset[i] = header.getValidityBufferLen() + header.getOffsetBufferLen(); + this.sliceInfoStack[i] = new ArrayDeque<>(16); + this.sliceInfoStack[i].add(new SliceInfo(header.getOffset(), header.getNumRows())); + } + long totalRowCount = tables.stream().mapToLong(t -> t.getHeader().getNumRows()).sum(); + this.totalRowCountStack = new ArrayDeque<>(16); + totalRowCountStack.addLast(totalRowCount); + this.hasNull = true; + this.currentIdx = 0; + this.strDataLen = new int[tables.size()]; + this.totalStrDataLen = 0; + } + + List getTables() { + return tables; + } + + @Override + public R visitTopSchema(Schema schema, List children) { + return doVisitTopSchema(schema, children); + } + + protected abstract R doVisitTopSchema(Schema schema, List children); + + @Override + public T visitStruct(Schema structType, List children) { + updateHasNull(); + T t = doVisitStruct(structType, children); + updateOffsets(false, false, false, -1); + currentIdx += 1; + return t; + } + + protected abstract T doVisitStruct(Schema structType, List children); + + @Override + public T preVisitList(Schema listType) { + updateHasNull(); + T t = doPreVisitList(listType); + updateOffsets(true, false, true, Integer.BYTES); + currentIdx += 1; + return t; + } + + protected abstract T doPreVisitList(Schema listType); + + @Override + public T visitList(Schema listType, T preVisitResult, T childResult) { + T t = doVisitList(listType, preVisitResult, childResult); + for (int tableIdx = 0; tableIdx < tables.size(); tableIdx++) { + sliceInfoStack[tableIdx].removeLast(); + } + totalRowCountStack.removeLast(); + return t; + } + + protected abstract T doVisitList(Schema listType, T preVisitResult, T childResult); + + @Override + public T visit(Schema primitiveType) { + updateHasNull(); + if (primitiveType.getType().hasOffsets()) { + // string type + updateDataLen(); + } + + T t = doVisit(primitiveType); + if (primitiveType.getType().hasOffsets()) { + updateOffsets(true, true, false, -1); + } else { + updateOffsets(false, true, false, primitiveType.getType().getSizeInBytes()); + } + currentIdx += 1; + return t; + } + + protected abstract T doVisit(Schema primitiveType); + + private void updateHasNull() { + hasNull = false; + for (SerializedTable table : tables) { + if (table.getHeader().hasValidityBuffer(currentIdx)) { + hasNull = true; + return; + } + } + } + + // For string column only + private void updateDataLen() { + totalStrDataLen = 0; + // String's data len needs to be calculated from offset buffer + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + int offset = offsetOf(tableIdx, 0); + int endOffset = offsetOf(tableIdx, safeLongToInt(sliceInfo.getRowCount())); + + strDataLen[tableIdx] = endOffset - offset; + totalStrDataLen += strDataLen[tableIdx]; + } else { + strDataLen[tableIdx] = 0; + } + } + } + + private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo, int sizeInBytes) { + long totalRowCount = 0; + for (int 
tableIdx = 0; tableIdx < tables.size(); tableIdx++) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + if (updateSliceInfo) { + int startOffset = offsetOf(tableIdx, 0); + int endOffset = offsetOf(tableIdx, safeLongToInt(sliceInfo.getRowCount())); + int rowCount = endOffset - startOffset; + totalRowCount += rowCount; + + sliceInfoStack[tableIdx].addLast(new SliceInfo(startOffset, rowCount)); + } + + if (tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) { + currentValidityOffsets[tableIdx] += padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength()); + } + + if (updateOffset) { + currentOffsetOffsets[tableIdx] += padForHostAlignment((sliceInfo.getRowCount() + 1) * Integer.BYTES); + if (updateData) { + // string type + currentDataOffset[tableIdx] += padForHostAlignment(strDataLen[tableIdx]); + } + // otherwise list type + } else { + if (updateData) { + // primitive type + currentDataOffset[tableIdx] += padForHostAlignment(sliceInfo.getRowCount() * sizeInBytes); + } + } + + } else { + if (updateSliceInfo) { + sliceInfoStack[tableIdx].addLast(new SliceInfo(0, 0)); + } + } + } + + if (updateSliceInfo) { + totalRowCountStack.addLast(totalRowCount); + } + } + + // Below parts are information about current column + + protected long getTotalRowCount() { + return totalRowCountStack.getLast(); + } + + + protected boolean hasNull() { + return hasNull; + } + + protected SliceInfo sliceInfoOf(int tableIdx) { + return sliceInfoStack[tableIdx].getLast(); + } + + protected HostMemoryBuffer memoryBufferOf(int tableIdx) { + return tables.get(tableIdx).getBuffer(); + } + + protected int offsetOf(int tableIdx, long rowIdx) { + long startOffset = currentOffsetOffsets[tableIdx]; + return tables.get(tableIdx).getBuffer().getInt(startOffset + rowIdx * Integer.BYTES); + } + + protected long validifyBufferOffset(int tableIdx) { + if (tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) { + return currentValidityOffsets[tableIdx]; + } else { + return -1; + } + } + + protected void copyDataBuffer(HostMemoryBuffer dst, long dstOffset, int tableIdx, int dataLen) { + long startOffset = currentDataOffset[tableIdx]; + dst.copyFromHostBuffer(dstOffset, tables.get(tableIdx).getBuffer(), startOffset, dataLen); + } + + protected long getTotalStrDataLen() { + return totalStrDataLen; + } + + protected int getStrDataLenOf(int tableIdx) { + return strDataLen[tableIdx]; + } + + protected int getCurrentIdx() { + return currentIdx; + } + + public int getTableSize() { + return this.tables.size(); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/RefUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/RefUtils.java new file mode 100644 index 000000000..7eb781976 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/RefUtils.java @@ -0,0 +1,73 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; + +import java.io.InputStream; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; + +public class RefUtils { + private static Method MAKE_CUDF_COLUMN_VIEW; + private static Method FROM_VIEW_WITH_CONTIGUOUS_ALLOCATION; + private static Constructor CONTIGUOUS_TABLE_CONSTRUCTOR; + private static Method COPY_FROM_STREAM; + + static { + try { + MAKE_CUDF_COLUMN_VIEW = ColumnView.class.getDeclaredMethod("makeCudfColumnView", + int.class, int.class, long.class, long.class, long.class, long.class, int.class, + int.class, long[].class); + MAKE_CUDF_COLUMN_VIEW.setAccessible(true); + + 
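+ // Resolve the package-private factory that builds a ColumnVector from a native view backed by a contiguous device allocation.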
FROM_VIEW_WITH_CONTIGUOUS_ALLOCATION = ColumnVector.class.getDeclaredMethod( + "fromViewWithContiguousAllocation", + long.class, DeviceMemoryBuffer.class); + FROM_VIEW_WITH_CONTIGUOUS_ALLOCATION.setAccessible(true); + + CONTIGUOUS_TABLE_CONSTRUCTOR = ContiguousTable.class.getDeclaredConstructor(Table.class, + DeviceMemoryBuffer.class); + CONTIGUOUS_TABLE_CONSTRUCTOR.setAccessible(true); + + COPY_FROM_STREAM = HostMemoryBuffer.class.getDeclaredMethod("copyFromStream", + long.class, InputStream.class, long.class); + COPY_FROM_STREAM.setAccessible(true); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + public static long makeCudfColumnView(int typeId, int scale, long dataAddress, long dataLen, + long offsetsAddress, long validityAddress, int nullCount, int rowCount, long[] childrenView) { + try { + return (long) MAKE_CUDF_COLUMN_VIEW.invoke(null, typeId, scale, dataAddress, dataLen, + offsetsAddress, validityAddress, nullCount, rowCount, childrenView); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static ColumnVector fromViewWithContiguousAllocation(long colView, DeviceMemoryBuffer buffer) { + try { + return (ColumnVector) FROM_VIEW_WITH_CONTIGUOUS_ALLOCATION.invoke(null, colView, buffer); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static ContiguousTable makeContiguousTable(Table table, DeviceMemoryBuffer buffer) { + try { + return CONTIGUOUS_TABLE_CONSTRUCTOR.newInstance(table, buffer); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static void copyFromStream(HostMemoryBuffer buffer, long offset, InputStream in, + long len) { + try { + COPY_FROM_STREAM.invoke(buffer, offset, in, len); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTable.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTable.java new file mode 100644 index 000000000..ec5f26280 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTable.java @@ -0,0 +1,36 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +public class SerializedTable implements AutoCloseable { + private final SerializedTableHeader header; + private final HostMemoryBuffer buffer; + + SerializedTable(SerializedTableHeader header, HostMemoryBuffer buffer) { + this.header = header; + this.buffer = buffer; + } + + public SerializedTableHeader getHeader() { + return header; + } + + public HostMemoryBuffer getBuffer() { + return buffer; + } + + @Override + public String toString() { + return "SerializedTable{" + + "header=" + header + + ", buffer=" + buffer + + '}'; + } + + @Override + public void close() throws Exception { + if (buffer != null) { + buffer.close(); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeader.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeader.java new file mode 100644 index 000000000..02455815f --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeader.java @@ -0,0 +1,155 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.util.Arrays; +import java.util.Optional; + +/** + * Holds the metadata about a serialized table. If this is being read from a stream + * isInitialized will return true if the metadata was read correctly from the stream. 
+ * It will return false if an EOF was encountered at the beginning indicating that + * there was no data to be read. + */ +public final class SerializedTableHeader { + /** + * Magic number "KUDO" in ASCII. + */ + private static final int SER_FORMAT_MAGIC_NUMBER = 0x4B55444F; + private static final short VERSION_NUMBER = 0x0001; + + // Useful for reducing calculations in writing. + private long offset; + private long numRows; + private long validityBufferLen; + private long offsetBufferLen; + private long totalDataLen; + // This is used to indicate the validity buffer for the columns. + // 1 means that this column has validity data, 0 means it does not. + private byte[] hasValidityBuffer; + + private boolean initialized = false; + + + public SerializedTableHeader(DataInputStream din) throws IOException { + readFrom(din); + } + + SerializedTableHeader(long offset, long numRows, long validityBufferLen, long offsetBufferLen, long totalDataLen, byte[] hasValidityBuffer) { + this.offset = offset; + this.numRows = numRows; + this.validityBufferLen = validityBufferLen; + this.offsetBufferLen = offsetBufferLen; + this.totalDataLen = totalDataLen; + this.hasValidityBuffer = hasValidityBuffer; + + this.initialized = true; + } + + /** + * Returns the size of a buffer needed to read data into the stream. + */ + public long getTotalDataLen() { + return totalDataLen; + } + + /** + * Returns the number of rows stored in this table. + */ + public long getNumRows() { + return numRows; + } + + public long getOffset() { + return offset; + } + + /** + * Returns true if the metadata for this table was read, else false indicating an EOF was + * encountered. + */ + public boolean wasInitialized() { + return initialized; + } + + public boolean hasValidityBuffer(int columnIndex) { + return hasValidityBuffer[columnIndex] != 0; + } + + public long getSerializedSize() { + return 4 + 2 + 8 + 8 + 8 + 8 + 8 + 4 + hasValidityBuffer.length; + } + + public int getNumColumns() { + return Optional.ofNullable(hasValidityBuffer).map(arr -> arr.length).orElse(0); + } + + public long getValidityBufferLen() { + return validityBufferLen; + } + + public long getOffsetBufferLen() { + return offsetBufferLen; + } + + public boolean isInitialized() { + return initialized; + } + + private void readFrom(DataInputStream din) throws IOException { + try { + int num = din.readInt(); + if (num != SER_FORMAT_MAGIC_NUMBER) { + throw new IllegalStateException("THIS DOES NOT LOOK LIKE CUDF SERIALIZED DATA. " + "Expected magic number " + SER_FORMAT_MAGIC_NUMBER + " Found " + num); + } + } catch (EOFException e) { + // If we get an EOF at the very beginning don't treat it as an error because we may + // have finished reading everything... 
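+ // Leave 'initialized' as false so callers can detect end-of-stream via wasInitialized().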
+ return; + } + short version = din.readShort(); + if (version != VERSION_NUMBER) { + throw new IllegalStateException("READING THE WRONG SERIALIZATION FORMAT VERSION FOUND " + version + " EXPECTED " + VERSION_NUMBER); + } + + offset = din.readLong(); + numRows = din.readLong(); + + validityBufferLen = din.readLong(); + offsetBufferLen = din.readLong(); + totalDataLen = din.readLong(); + int validityBufferLength = din.readInt(); + hasValidityBuffer = new byte[validityBufferLength]; + din.readFully(hasValidityBuffer); + + initialized = true; + } + + public void writeTo(DataWriter dout) throws IOException { + // Now write out the data + dout.writeInt(SER_FORMAT_MAGIC_NUMBER); + dout.writeShort(VERSION_NUMBER); + + dout.writeLong(offset); + dout.writeLong(numRows); + dout.writeLong(validityBufferLen); + dout.writeLong(offsetBufferLen); + dout.writeLong(totalDataLen); + dout.writeInt(hasValidityBuffer.length); + dout.write(hasValidityBuffer, 0, hasValidityBuffer.length); + } + + @Override + public String toString() { + return "SerializedTableHeader{" + + "offset=" + offset + + ", numRows=" + numRows + + ", validityBufferLen=" + validityBufferLen + + ", offsetBufferLen=" + offsetBufferLen + + ", totalDataLen=" + totalDataLen + + ", hasValidityBuffer=" + Arrays.toString(hasValidityBuffer) + + ", initialized=" + initialized + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeaderCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeaderCalc.java new file mode 100644 index 000000000..db4d02ce3 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SerializedTableHeaderCalc.java @@ -0,0 +1,155 @@ +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.BufferType; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVectorCore; +import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; + + +class SerializedTableHeaderCalc implements HostColumnsVisitor { + private final SliceInfo root; + private final List hasValidityBuffer = new ArrayList<>(1024); + private long validityBufferLen; + private long offsetBufferLen; + private long totalDataLen; + + private Deque sliceInfos = new ArrayDeque<>(); + + SerializedTableHeaderCalc(long rowOffset, long numRows) { + this.root = new SliceInfo(rowOffset, numRows); + this.totalDataLen = 0; + sliceInfos.addLast(this.root); + } + + @Override + public SerializedTableHeader visitTopSchema(List children) { + byte[] hasValidityBuffer = new byte[this.hasValidityBuffer.size()]; + for (int i = 0; i < this.hasValidityBuffer.size(); i++) { + hasValidityBuffer[i] = (byte) (this.hasValidityBuffer.get(i) ? 
+    }
+    return new SerializedTableHeader(root.offset, root.rowCount,
+        validityBufferLen, offsetBufferLen,
+        totalDataLen, hasValidityBuffer);
+  }
+
+  @Override
+  public Void visitStruct(HostColumnVectorCore col, List<Void> children) {
+    SliceInfo parent = sliceInfos.getLast();
+
+    long validityBufferLength = 0;
+    if (col.hasValidityVector()) {
+      validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength());
+    }
+
+    this.validityBufferLen += validityBufferLength;
+
+    totalDataLen += validityBufferLength;
+    hasValidityBuffer.add(col.getValidity() != null);
+    return null;
+  }
+
+  @Override
+  public Void preVisitList(HostColumnVectorCore col) {
+    SliceInfo parent = sliceInfos.getLast();
+
+
+    long validityBufferLength = 0;
+    if (col.hasValidityVector() && parent.rowCount > 0) {
+      validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength());
+    }
+
+    long offsetBufferLength = 0;
+    if (col.getOffsets() != null && parent.rowCount > 0) {
+      offsetBufferLength = padForHostAlignment((parent.rowCount + 1) * Integer.BYTES);
+    }
+
+    this.validityBufferLen += validityBufferLength;
+    this.offsetBufferLen += offsetBufferLength;
+    this.totalDataLen += validityBufferLength + offsetBufferLength;
+
+    hasValidityBuffer.add(col.getValidity() != null);
+
+    SliceInfo current;
+
+    if (col.getOffsets() != null) {
+      long start = col.getOffsets().getInt(parent.offset * Integer.BYTES);
+      long end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES);
+      long rowCount = end - start;
+      current = new SliceInfo(start, rowCount);
+    } else {
+      current = new SliceInfo(0, 0);
+    }
+
+    sliceInfos.addLast(current);
+    return null;
+  }
+
+  @Override
+  public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childResult) {
+    sliceInfos.removeLast();
+
+    return null;
+  }
+
+
+  @Override
+  public Void visit(HostColumnVectorCore col) {
+    SliceInfo parent = sliceInfos.peekLast();
+    long validityBufferLen = calcPrimitiveDataLen(col, BufferType.VALIDITY, parent);
+    long offsetBufferLen = calcPrimitiveDataLen(col, BufferType.OFFSET, parent);
+    long dataBufferLen = calcPrimitiveDataLen(col, BufferType.DATA, parent);
+
+    this.validityBufferLen += validityBufferLen;
+    this.offsetBufferLen += offsetBufferLen;
+    this.totalDataLen += validityBufferLen + offsetBufferLen + dataBufferLen;
+
+    hasValidityBuffer.add(col.getValidity() != null);
+
+    return null;
+  }
+
+  private long calcPrimitiveDataLen(HostColumnVectorCore col,
+                                    BufferType bufferType,
+                                    SliceInfo info) {
+    switch (bufferType) {
+      case VALIDITY:
+        if (col.hasValidityVector() && info.getRowCount() > 0) {
+          return padForHostAlignment(info.getValidityBufferInfo().getBufferLength());
+        } else {
+          return 0;
+        }
+      case OFFSET:
+        if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) {
+          return padForHostAlignment((info.rowCount + 1) * Integer.BYTES);
+        } else {
+          return 0;
+        }
+      case DATA:
+        if (DType.STRING.equals(col.getType())) {
+          if (col.getOffsets() != null) {
+            long startByteOffset = col.getOffsets().getInt(info.offset * Integer.BYTES);
+            long endByteOffset = col.getOffsets().getInt((info.offset + info.rowCount) * Integer.BYTES);
+            return padForHostAlignment(endByteOffset - startByteOffset);
+          } else {
+            return 0;
+          }
+        } else {
+          if (col.getType().getSizeInBytes() > 0) {
+            return padForHostAlignment(col.getType().getSizeInBytes() * info.rowCount);
+          } else {
+            return 0;
+          }
+        }
+      default:
+        throw new IllegalArgumentException("Unexpected buffer type: " + bufferType);
+
+    }
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java
new file mode 100644
index 000000000..e7d6dd331
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java
@@ -0,0 +1,34 @@
+package com.nvidia.spark.rapids.jni.kudo;
+
+public class SliceInfo {
+  final long offset;
+  final long rowCount;
+  private final SlicedValidityBufferInfo validityBufferInfo;
+
+  SliceInfo(long offset, long rowCount) {
+    this.offset = offset;
+    this.rowCount = rowCount;
+    this.validityBufferInfo = SlicedValidityBufferInfo.calc(offset, rowCount);
+  }
+
+  public SlicedValidityBufferInfo getValidityBufferInfo() {
+    return validityBufferInfo;
+  }
+
+  public long getOffset() {
+    return offset;
+  }
+
+  public long getRowCount() {
+    return rowCount;
+  }
+
+  @Override
+  public String toString() {
+    return "SliceInfo{" +
+        "offset=" + offset +
+        ", rowCount=" + rowCount +
+        ", validityBufferInfo=" + validityBufferInfo +
+        '}';
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
new file mode 100644
index 000000000..95ecc4e25
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
@@ -0,0 +1,177 @@
+package com.nvidia.spark.rapids.jni.kudo;
+
+import ai.rapids.cudf.BufferType;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.HostColumnVectorCore;
+import ai.rapids.cudf.HostMemoryBuffer;
+import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.List;
+
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
+
+
+class SlicedBufferSerializer implements HostColumnsVisitor<Long, Long> {
+  private final SliceInfo root;
+  private final BufferType bufferType;
+  private final DataWriter writer;
+
+  private final Deque<SliceInfo> sliceInfos = new ArrayDeque<>();
+
+  SlicedBufferSerializer(long rowOffset, long numRows, BufferType bufferType, DataWriter writer) {
+    this.root = new SliceInfo(rowOffset, numRows);
+    this.bufferType = bufferType;
+    this.writer = writer;
+    this.sliceInfos.addLast(root);
+  }
+
+  @Override
+  public Long visitTopSchema(List<Long> children) {
+    return children.stream().mapToLong(Long::longValue).sum();
+  }
+
+  @Override
+  public Long visitStruct(HostColumnVectorCore col, List<Long> children) {
+    SliceInfo parent = sliceInfos.peekLast();
+
+    long bytesCopied = children.stream().mapToLong(Long::longValue).sum();
+    try {
+      switch (bufferType) {
+        case VALIDITY:
+          bytesCopied += this.copySlicedValidity(col, parent);
+          return bytesCopied;
+        case OFFSET:
+        case DATA:
+          return bytesCopied;
+        default:
+          throw new IllegalArgumentException("Unexpected buffer type: " + bufferType);
+      }
+
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  public Long preVisitList(HostColumnVectorCore col) {
+    SliceInfo parent = sliceInfos.getLast();
+
+
+    long bytesCopied = 0;
+    try {
+      switch (bufferType) {
+        case VALIDITY:
+          bytesCopied = this.copySlicedValidity(col, parent);
+          break;
+        case OFFSET:
+          bytesCopied = this.copySlicedOffset(col, parent);
+          break;
+        case DATA:
+          break;
+        default:
+          throw new IllegalArgumentException("Unexpected buffer type: " + bufferType);
+      }
+
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    SliceInfo current;
+    if (col.getOffsets() != null) {
+      long start = col.getOffsets().getInt(parent.offset * Integer.BYTES);
+      long end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES);
+      long rowCount = end - start;
+
+      current = new SliceInfo(start, rowCount);
+    } else {
+      current = new SliceInfo(0, 0);
+    }
+
+    sliceInfos.addLast(current);
+    return bytesCopied;
+  }
+
+  @Override
+  public Long visitList(HostColumnVectorCore col, Long preVisitResult, Long childResult) {
+    sliceInfos.removeLast();
+    return preVisitResult + childResult;
+  }
+
+  @Override
+  public Long visit(HostColumnVectorCore col) {
+    SliceInfo parent = sliceInfos.getLast();
+    try {
+      switch (bufferType) {
+        case VALIDITY:
+          return this.copySlicedValidity(col, parent);
+        case OFFSET:
+          return this.copySlicedOffset(col, parent);
+        case DATA:
+          return this.copySlicedData(col, parent);
+        default:
+          throw new IllegalArgumentException("Unexpected buffer type: " + bufferType);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+    if (column.getValidity() != null && sliceInfo.getRowCount() > 0) {
+      HostMemoryBuffer buff = column.getValidity();
+      long len = sliceInfo.getValidityBufferInfo().getBufferLength();
+      writer.copyDataFrom(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(),
+          len);
+      return padForHostAlignment(writer, len);
+    } else {
+      return 0;
+    }
+  }
+
+  private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+    if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) {
+      // Don't copy anything, there are no rows
+      return 0;
+    }
+    long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES;
+    long srcOffset = sliceInfo.offset * Integer.BYTES;
+    HostMemoryBuffer buff = column.getOffsets();
+    writer.copyDataFrom(buff, srcOffset, bytesToCopy);
+    return padForHostAlignment(writer, bytesToCopy);
+  }
+
+  private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+    if (sliceInfo.rowCount > 0) {
+      DType type = column.getType();
+      if (type.equals(DType.STRING)) {
+        long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES);
+        long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
+        long bytesToCopy = endByteOffset - startByteOffset;
+        if (column.getData() == null) {
+          if (bytesToCopy != 0) {
+            throw new IllegalStateException("String column has no data buffer, " +
+                "but bytes to copy is not zero: " + bytesToCopy);
+          }
+
+          return 0;
+        } else {
+          writer.copyDataFrom(column.getData(), startByteOffset, bytesToCopy);
+          return padForHostAlignment(writer, bytesToCopy);
+        }
+      } else if (type.getSizeInBytes() > 0) {
+        long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes();
+        long srcOffset = sliceInfo.offset * type.getSizeInBytes();
+        writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy);
+        return padForHostAlignment(writer, bytesToCopy);
+      } else {
+        return 0;
+      }
+    } else {
+      return 0;
+    }
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java
new file mode 100644
index 000000000..749f4b2f1
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java
@@ -0,0 +1,54 @@
+package com.nvidia.spark.rapids.jni.kudo;
+
+class SlicedValidityBufferInfo {
+  private final long bufferOffset;
+  private final long bufferLength;
+  // The bit offset within the buffer where the slice starts
+  private final long beginBit;
+  private final long endBit; // Exclusive
+
+  SlicedValidityBufferInfo(long bufferOffset, long bufferLength, long beginBit, long endBit) {
+    this.bufferOffset = bufferOffset;
+    this.bufferLength = bufferLength;
+    this.beginBit = beginBit;
+    this.endBit = endBit;
+  }
+
+  @Override
+  public String toString() {
+    return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + bufferLength + ", beginBit=" + beginBit + ", endBit=" + endBit + '}';
+  }
+
+  public long getBufferOffset() {
+    return bufferOffset;
+  }
+
+  public long getBufferLength() {
+    return bufferLength;
+  }
+
+  public long getBeginBit() {
+    return beginBit;
+  }
+
+  public long getEndBit() {
+    return endBit;
+  }
+
+  static SlicedValidityBufferInfo calc(long rowOffset, long numRows) {
+    if (rowOffset < 0) {
+      throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset);
+    }
+    if (numRows < 0) {
+      throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows);
+    }
+    long bufferOffset = rowOffset / 8;
+    long beginBit = rowOffset % 8;
+    long bufferLength = 0;
+    if (numRows > 0) {
+      bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1;
+    }
+    long endBit = beginBit + numRows;
+    return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit, endBit);
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java
new file mode 100644
index 000000000..f7ef384b2
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java
@@ -0,0 +1,78 @@
+package com.nvidia.spark.rapids.jni.kudo;
+
+import ai.rapids.cudf.*;
+import com.nvidia.spark.rapids.jni.schema.SchemaVisitor;
+
+import java.util.List;
+
+import static com.nvidia.spark.rapids.jni.TableUtils.ensure;
+import static java.util.Objects.requireNonNull;
+
+public class TableBuilder implements SchemaVisitor<Object, Table> {
+  // Current column index
+  private int curIdx;
+  private final DeviceMemoryBuffer buffer;
+  private final List<ColumnViewInfo> colViewInfoList;
+
+  public TableBuilder(List<ColumnViewInfo> colViewInfoList, DeviceMemoryBuffer buffer) {
+    requireNonNull(colViewInfoList, "colViewInfoList cannot be null");
+    ensure(!colViewInfoList.isEmpty(), "colViewInfoList cannot be empty");
+    requireNonNull(buffer, "Device buffer can't be null!");
+
+    this.curIdx = 0;
+    this.buffer = buffer;
+    this.colViewInfoList = colViewInfoList;
+  }
+
+  @Override
+  public Table visitTopSchema(Schema schema, List<Object> children) {
+    try (CloseableArray<ColumnVector> arr = CloseableArray.wrap(new ColumnVector[children.size()])) {
+      for (int i = 0; i < children.size(); i++) {
+        long colView = (long) children.get(i);
+        arr.set(i, RefUtils.fromViewWithContiguousAllocation(colView, buffer));
+      }
+
+      return new Table(arr.getArray());
+    }
+  }
+
+  @Override
+  public Long visitStruct(Schema structType, List<Object> children) {
+    ColumnViewInfo colViewInfo = getCurrentColumnViewInfo();
+
+    long[] childrenView = children.stream().mapToLong(o -> (long) o).toArray();
+    long columnView = colViewInfo.buildColumnView(buffer, childrenView);
+    curIdx += 1;
+    return columnView;
+  }
+
+  @Override
+  public ColumnViewInfo preVisitList(Schema listType) {
+    ColumnViewInfo colViewInfo = getCurrentColumnViewInfo();
+
+    curIdx += 1;
+    return colViewInfo;
+  }
+
+  @Override
+  public Long visitList(Schema listType, Object preVisitResult, Object childResult) {
+    ColumnViewInfo colViewInfo = (ColumnViewInfo) preVisitResult;
+
+    long[] children = new long[] { (long) childResult };
+
+    return colViewInfo.buildColumnView(buffer, children);
+  }
+
+  @Override
+  public Long visit(Schema primitiveType) {
+    ColumnViewInfo colViewInfo = getCurrentColumnViewInfo();
+
+    long columnView = colViewInfo.buildColumnView(buffer, null);
+    curIdx += 1;
+    return columnView;
+  }
+
+  private ColumnViewInfo getCurrentColumnViewInfo() {
+    return colViewInfoList.get(curIdx);
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java
new file mode 100644
index 000000000..9c62f00ee
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java
@@ -0,0 +1,16 @@
+package com.nvidia.spark.rapids.jni.schema;
+
+import ai.rapids.cudf.HostColumnVectorCore;
+
+import java.util.List;
+
+public interface HostColumnsVisitor<T, R> {
+  R visitTopSchema(List<T> children);
+
+  T visitStruct(HostColumnVectorCore col, List<T> children);
+
+  T preVisitList(HostColumnVectorCore col);
+  T visitList(HostColumnVectorCore col, T preVisitResult, T childResult);
+
+  T visit(HostColumnVectorCore col);
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java
new file mode 100644
index 000000000..4d980ac00
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java
@@ -0,0 +1,22 @@
+package com.nvidia.spark.rapids.jni.schema;
+
+import ai.rapids.cudf.Schema;
+
+import java.util.List;
+
+/**
+ * Interface for visiting a schema in post order.
+ */
+public interface SchemaVisitor<T, R> {
+  R visitTopSchema(Schema schema, List<T> children);
+
+  T visitStruct(Schema structType, List<T> children);
+
+  T preVisitList(Schema listType);
+
+  T visitList(Schema listType, T preVisitResult, T childResult);
+
+  T visit(Schema primitiveType);
+
+
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java
new file mode 100644
index 000000000..1da0dfb05
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java
@@ -0,0 +1,74 @@
+package com.nvidia.spark.rapids.jni.schema;
+
+import ai.rapids.cudf.HostColumnVector;
+import ai.rapids.cudf.HostColumnVectorCore;
+import ai.rapids.cudf.Schema;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class Visitors {
+  public static <T, R> R visitSchema(Schema schema, SchemaVisitor<T, R> visitor) {
+    Objects.requireNonNull(schema, "schema cannot be null");
+    Objects.requireNonNull(visitor, "visitor cannot be null");
+
+    List<T> childrenResult = IntStream.range(0, schema.getNumChildren())
+        .mapToObj(i -> visitSchemaInner(schema.getChild(i), visitor))
+        .collect(Collectors.toList());
+
+    return visitor.visitTopSchema(schema, childrenResult);
+  }
+
+  private static <T, R> T visitSchemaInner(Schema schema, SchemaVisitor<T, R> visitor) {
+    switch (schema.getType().getTypeId()) {
+      case STRUCT:
+        List<T> children = IntStream.range(0, schema.getNumChildren())
+            .mapToObj(childIdx -> visitSchemaInner(schema.getChild(childIdx), visitor))
+            .collect(Collectors.toList());
+        return visitor.visitStruct(schema, children);
+      case LIST:
+        T preVisitResult = visitor.preVisitList(schema);
+        T childResult = visitSchemaInner(schema.getChild(0), visitor);
+        return visitor.visitList(schema, preVisitResult, childResult);
+      default:
+        return visitor.visit(schema);
+    }
+  }
+
+
+  /**
+   * Entry point for visiting a schema with columns.
+   */
+  public static <T, R> R visitColumns(List<HostColumnVector> cols,
+                                      HostColumnsVisitor<T, R> visitor) {
+    Objects.requireNonNull(cols, "cols cannot be null");
+    Objects.requireNonNull(visitor, "visitor cannot be null");
+
+    List<T> childrenResult = new ArrayList<>(cols.size());
+
+    for (HostColumnVector col : cols) {
+      childrenResult.add(visitSchema(col, visitor));
+    }
+
+    return visitor.visitTopSchema(childrenResult);
+  }
+
+  private static <T, R> T visitSchema(HostColumnVectorCore col, HostColumnsVisitor<T, R> visitor) {
+    switch (col.getType().getTypeId()) {
+      case STRUCT:
+        List<T> children = IntStream.range(0, col.getNumChildren())
+            .mapToObj(childIdx -> visitSchema(col.getChildColumnView(childIdx), visitor))
+            .collect(Collectors.toList());
+        return visitor.visitStruct(col, children);
+      case LIST:
+        T preVisitResult = visitor.preVisitList(col);
+        T childResult = visitSchema(col.getChildColumnView(0), visitor);
+        return visitor.visitList(col, preVisitResult, childResult);
+      default:
+        return visitor.visit(col);
+    }
+  }
+}
From d7d756d6e824b0aa2a366096dc4deef38e97f5a4 Mon Sep 17 00:00:00 2001
From: liurenjie1024
Date: Thu, 24 Oct 2024 10:51:57 +0800
Subject: [PATCH 2/3] Fix pair

Signed-off-by: liurenjie1024
---
 .../com/nvidia/spark/rapids/jni/Pair.java     | 23 +++++++++++++++++++
 .../spark/rapids/jni/kudo/ColumnViewInfo.java |  2 ++
 .../spark/rapids/jni/kudo/KudoSerializer.java |  3 ++-
 3 files changed, 27 insertions(+), 1 deletion(-)
 create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/Pair.java

diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Pair.java b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java
new file mode 100644
index 000000000..f0a1c0955
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java
@@ -0,0 +1,23 @@
+package com.nvidia.spark.rapids.jni;
+
+public class Pair<K, V> {
+  private final K left;
+  private final V right;
+
+  public Pair(K left, V right) {
+    this.left = left;
+    this.right = right;
+  }
+
+  public K getLeft() {
+    return left;
+  }
+
+  public V getRight() {
+    return right;
+  }
+
+  public static <K, V> Pair<K, V> of(K left, V right) {
+    return new Pair<>(left, right);
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
index f6b00e8f7..d97b469b3 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
@@ -3,6 +3,8 @@
 import ai.rapids.cudf.DType;
 import ai.rapids.cudf.DeviceMemoryBuffer;
 
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.safeLongToInt;
+
 public class ColumnViewInfo {
 
   private final DType dtype;
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
index 50b87fd24..cc52aa1fc 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -1,6 +1,7 @@
 package com.nvidia.spark.rapids.jni.kudo;
 
 import ai.rapids.cudf.*;
+import com.nvidia.spark.rapids.jni.Pair;
 import com.nvidia.spark.rapids.jni.schema.Visitors;
 
 import java.io.*;
@@ -102,7 +103,7 @@ public SerializedTable readOneTableBuffer(InputStream in) {
   }
 
   public Pair<HostMergeResult, MergeMetrics> mergeToHost(List<SerializedTable> serializedTables,
-                                                          Schema schema) {
+                                                         Schema schema) {
     MergeMetrics.Builder metricsBuilder = MergeMetrics.builder();
     MergedInfoCalc mergedInfoCalc = withTime(() ->
         MergedInfoCalc.calc(schema, serializedTables),
From 2db4365999a5e24152da59c18fb651f5cf33f67b Mon Sep 17 00:00:00 2001
From: liurenjie1024
Date: Tue, 29 Oct 2024 09:50:00 +0800
Subject: [PATCH 3/3] Remove unused

---
 .../nvidia/spark/rapids/jni/TableUtils.java   | 58 ------------------
 1 file changed, 58 deletions(-)

diff --git a/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java
index 01a15bc18..9db0be9a4 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/TableUtils.java
@@ -1,7 +1,5 @@
 package com.nvidia.spark.rapids.jni;
 
-import ai.rapids.cudf.*;
-
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.function.Function;
@@ -10,51 +8,6 @@
 import java.util.stream.Collectors;
 
 public class TableUtils {
-  public static Schema schemaOf(Table t) {
-    Schema.Builder builder = Schema.builder();
-
-    for (int i = 0; i < t.getNumberOfColumns(); i++) {
-      ColumnVector cv = t.getColumn(i);
-      addToSchema(cv, "col_" + i + "_", builder);
-    }
-
-    return builder.build();
-  }
-
-  public static void addToSchema(ColumnView cv, String namePrefix, Schema.Builder builder) {
-    toSchemaInner(cv, 0, namePrefix, builder);
-  }
-
-  private static int toSchemaInner(ColumnView cv, int idx, String namePrefix,
-                                   Schema.Builder builder) {
-    String name = namePrefix + idx;
-
-    Schema.Builder thisBuilder = builder.addColumn(cv.getType(), name);
-    int lastIdx = idx;
-    for (int i = 0; i < cv.getNumChildren(); i++) {
-      lastIdx = toSchemaInner(cv.getChildColumnView(i), lastIdx + 1, namePrefix,
-          thisBuilder);
-    }
-
-    return lastIdx;
-  }
-
-  public static void addToSchema(HostColumnVectorCore cv, String namePrefix, Schema.Builder builder) {
-    toSchemaInner(cv, 0, namePrefix, builder);
-  }
-
-  private static int toSchemaInner(HostColumnVectorCore cv, int idx, String namePrefix,
-                                   Schema.Builder builder) {
-    String name = namePrefix + idx;
-
-    Schema.Builder thisBuilder = builder.addColumn(cv.getType(), name);
-    int lastIdx = idx;
-    for (int i=0; i < cv.getNumChildren(); i++) {
-      lastIdx = toSchemaInner(cv.getChildColumnView(i), lastIdx + 1, namePrefix, thisBuilder);
-    }
-
-    return lastIdx;
-  }
 
   public static void ensure(boolean condition, String message) {
     if (!condition) {
@@ -78,17 +31,6 @@ public static long getValidityLengthInBytes(long rows) {
     return (rows + 7) / 8;
   }
 
-  /**
-   * This method returns the allocation size of the validity vector which is 64-byte aligned
-   * e.g. getValidityAllocationSizeInBytes(5) => 64 bytes
-   *      getValidityAllocationSizeInBytes(14) => 64 bytes
-   *      getValidityAllocationSizeInBytes(65) => 128 bytes
-   */
-  static long getValidityAllocationSizeInBytes(long rows) {
-    long numBytes = getValidityLengthInBytes(rows);
-    return ((numBytes + 63) / 64) * 64;
-  }
-
   public static <R extends AutoCloseable, T> T closeIfException(R resource, Function<R, T> function) {
     try {
       return function.apply(resource);