Skip to content

Optimize Utf8Validator with constant input Vector.slice API #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 28 additions & 55 deletions src/main/java/org/simdjson/Utf8Validator.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import static jdk.incubator.vector.VectorOperators.LSHL;
import static jdk.incubator.vector.VectorOperators.LSHR;
import static jdk.incubator.vector.VectorOperators.NE;
import static jdk.incubator.vector.VectorOperators.UNSIGNED_GE;
import static jdk.incubator.vector.VectorOperators.UNSIGNED_GT;
import static jdk.incubator.vector.VectorOperators.UGE;
import static jdk.incubator.vector.VectorOperators.UGT;
import static jdk.incubator.vector.VectorShuffle.iota;
import static org.simdjson.VectorUtils.BYTE_SPECIES;
import static org.simdjson.VectorUtils.INT_SPECIES;
Expand Down Expand Up @@ -40,8 +40,6 @@ class Utf8Validator {
private static final byte TWO_CONTINUATIONS = (byte) (1 << 7);
private static final byte MAX_2_LEADING_BYTE = (byte) 0b110_11111;
private static final byte MAX_3_LEADING_BYTE = (byte) 0b1110_1111;
private static final int TWO_BYTES_SIZE = Byte.SIZE * 2;
private static final int THREE_BYTES_SIZE = Byte.SIZE * 3;
private static final ByteVector BYTE_1_HIGH_LOOKUP = createByte1HighLookup();
private static final ByteVector BYTE_1_LOW_LOOKUP = createByte1LowLookup();
private static final ByteVector BYTE_2_HIGH_LOOKUP = createByte2HighLookup();
Expand All @@ -52,84 +50,63 @@ class Utf8Validator {
private static final int STEP_SIZE = BYTE_SPECIES.vectorByteSize();

static void validate(byte[] buffer, int length) {
long previousIncomplete = 0;
int offset = 0;
long errors = 0;
int previousFourUtf8Bytes = 0;

long previousIncomplete = 0;
int loopBound = BYTE_SPECIES.loopBound(length);
int offset = 0;
ByteVector previousChunk = ByteVector.broadcast(BYTE_SPECIES, 0);

for (; offset < loopBound; offset += STEP_SIZE) {
ByteVector chunk = ByteVector.fromArray(BYTE_SPECIES, buffer, offset);
IntVector chunkAsInts = chunk.reinterpretAsInts();
// ASCII fast path can bypass the checks that are only required for multibyte code points.
if (chunk.and(ALL_ASCII_MASK).compare(EQ, 0).allTrue()) {
errors |= previousIncomplete;
} else {
previousIncomplete = chunk.compare(UNSIGNED_GE, INCOMPLETE_CHECK).toLong();
// Shift the input forward by four bytes to make space for the previous four bytes.
// The previous three bytes are required for validation, pulling in the last integer
// will give the previous four bytes. The switch to integer vectors is to allow for
// integer shifting instead of the more expensive shuffle / slice operations.
IntVector chunkWithPreviousFourBytes = chunkAsInts
.rearrange(FOUR_BYTES_FORWARD_SHIFT)
.withLane(0, previousFourUtf8Bytes);
// Shift the current input forward by one byte to include one byte from the previous chunk.
ByteVector previousOneByte = chunkAsInts
.lanewise(LSHL, Byte.SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, THREE_BYTES_SIZE))
.reinterpretAsBytes();
previousIncomplete = chunk.compare(UGE, INCOMPLETE_CHECK).toLong();
// Pull in last byte from previous chunk.
ByteVector previousOneByte = previousChunk.slice(BYTE_SPECIES.length() - 1, chunk);

ByteVector byte2HighNibbles = chunkAsInts.lanewise(LSHR, 4)
.reinterpretAsBytes()
.and(LOW_NIBBLE_MASK);
ByteVector byte1HighNibbles = previousOneByte.reinterpretAsInts()
.lanewise(LSHR, 4)
.reinterpretAsBytes()
.and(LOW_NIBBLE_MASK);

ByteVector byte1LowNibbles = previousOneByte.and(LOW_NIBBLE_MASK);
ByteVector byte1HighState = byte1HighNibbles.selectFrom(BYTE_1_HIGH_LOOKUP);
ByteVector byte1LowState = byte1LowNibbles.selectFrom(BYTE_1_LOW_LOOKUP);
ByteVector byte2HighState = byte2HighNibbles.selectFrom(BYTE_2_HIGH_LOOKUP);
ByteVector firstCheck = byte1HighState.and(byte1LowState).and(byte2HighState);

// All remaining checks are for invalid 3 and 4-byte sequences, which either have too many
// continuation bytes or not enough.
ByteVector previousTwoBytes = chunkAsInts
.lanewise(LSHL, TWO_BYTES_SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, TWO_BYTES_SIZE))
.reinterpretAsBytes();
ByteVector previousTwoBytes = previousChunk.slice(BYTE_SPECIES.length() - 2, chunk);

// The minimum leading byte of 3-byte sequences is always greater than the maximum leading byte of 2-byte sequences.
VectorMask<Byte> is3ByteLead = previousTwoBytes.compare(UNSIGNED_GT, MAX_2_LEADING_BYTE);
ByteVector previousThreeBytes = chunkAsInts
.lanewise(LSHL, THREE_BYTES_SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, Byte.SIZE))
.reinterpretAsBytes();
VectorMask<Byte> is3ByteLead = previousTwoBytes.compare(UGT, MAX_2_LEADING_BYTE);
ByteVector previousThreeBytes = previousChunk.slice(BYTE_SPECIES.length() - 3, chunk);

// The minimum leading byte of 4-byte sequences is always greater than the maximum leading byte of 3-byte sequences.
VectorMask<Byte> is4ByteLead = previousThreeBytes.compare(UNSIGNED_GT, MAX_3_LEADING_BYTE);
VectorMask<Byte> is4ByteLead = previousThreeBytes.compare(UGT, MAX_3_LEADING_BYTE);
// The firstCheck vector contains 0x80 values on continuation byte indexes.
// The leading bytes of 3 and 4-byte sequences should match up with these indexes and zero them out.
ByteVector secondCheck = firstCheck.add((byte) 0x80, is3ByteLead.or(is4ByteLead));
errors |= secondCheck.compare(NE, 0).toLong();
}
previousFourUtf8Bytes = chunkAsInts.lane(INT_SPECIES.length() - 1);
previousChunk = chunk;
}

// If the input file doesn't align with the vector width, pad the missing bytes with zeros.
VectorMask<Byte> remainingBytes = BYTE_SPECIES.indexInRange(offset, length);
ByteVector chunk = ByteVector.fromArray(BYTE_SPECIES, buffer, offset, remainingBytes);
if (!chunk.and(ALL_ASCII_MASK).compare(EQ, 0).allTrue()) {
IntVector chunkAsInts = chunk.reinterpretAsInts();
previousIncomplete = chunk.compare(UNSIGNED_GE, INCOMPLETE_CHECK).toLong();
// Shift the input forward by four bytes to make space for the previous four bytes.
// The previous three bytes are required for validation, pulling in the last integer
// will give the previous four bytes. The switch to integer vectors is to allow for
// integer shifting instead of the more expensive shuffle / slice operations.
IntVector chunkWithPreviousFourBytes = chunkAsInts
.rearrange(FOUR_BYTES_FORWARD_SHIFT)
.withLane(0, previousFourUtf8Bytes);
// Shift the current input forward by one byte to include one byte from the previous chunk.
ByteVector previousOneByte = chunkAsInts
.lanewise(LSHL, Byte.SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, THREE_BYTES_SIZE))
.reinterpretAsBytes();
previousIncomplete = chunk.compare(UGE, INCOMPLETE_CHECK).toLong();
// Pull in last byte from previous chunk.
ByteVector previousOneByte = previousChunk.slice(BYTE_SPECIES.length() - 1, chunk);
ByteVector byte2HighNibbles = chunkAsInts.lanewise(LSHR, 4)
.reinterpretAsBytes()
.and(LOW_NIBBLE_MASK);
Expand All @@ -144,18 +121,14 @@ static void validate(byte[] buffer, int length) {
ByteVector firstCheck = byte1HighState.and(byte1LowState).and(byte2HighState);
// All remaining checks are for invalid 3 and 4-byte sequences, which either have too many
// continuation bytes or not enough.
ByteVector previousTwoBytes = chunkAsInts
.lanewise(LSHL, TWO_BYTES_SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, TWO_BYTES_SIZE))
.reinterpretAsBytes();
// Pull in last two bytes from previous chunk.
ByteVector previousTwoBytes = previousChunk.slice(BYTE_SPECIES.length() - 2, chunk);
// The minimum leading byte of 3-byte sequences is always greater than the maximum leading byte of 2-byte sequences.
VectorMask<Byte> is3ByteLead = previousTwoBytes.compare(UNSIGNED_GT, MAX_2_LEADING_BYTE);
ByteVector previousThreeBytes = chunkAsInts
.lanewise(LSHL, THREE_BYTES_SIZE)
.or(chunkWithPreviousFourBytes.lanewise(LSHR, Byte.SIZE))
.reinterpretAsBytes();
VectorMask<Byte> is3ByteLead = previousTwoBytes.compare(UGT, MAX_2_LEADING_BYTE);
ByteVector previousThreeBytes = previousChunk.slice(BYTE_SPECIES.length() - 3, chunk);
// The minimum leading byte of 4-byte sequences is always greater than the maximum leading byte of 3-byte sequences.
VectorMask<Byte> is4ByteLead = previousThreeBytes.compare(UNSIGNED_GT, MAX_3_LEADING_BYTE);
// previousThreeBytes (computed above) holds the last three bytes of the previous chunk followed by the current chunk.
VectorMask<Byte> is4ByteLead = previousThreeBytes.compare(UGT, MAX_3_LEADING_BYTE);
// The firstCheck vector contains 0x80 values on continuation byte indexes.
// The leading bytes of 3 and 4-byte sequences should match up with these indexes and zero them out.
ByteVector secondCheck = firstCheck.add((byte) 0x80, is3ByteLead.or(is4ByteLead));
Expand Down