From d0f9f3e6136bffaa94b25d4a1c95576d4747773d Mon Sep 17 00:00:00 2001 From: hellishfire Date: Mon, 9 Sep 2024 15:26:15 +0800 Subject: [PATCH] GH-43966: [Java] Check for nullabilities when comparing StructVector (#43968) ### Rationale for this change See #43966 ### What changes are included in this PR? Check for nullabilities when comparing StructVector with RangeEqualsVisitor. ### Are these changes tested? Yes ### Are there any user-facing changes? No * GitHub Issue: #43966 Authored-by: youming.whl Signed-off-by: David Li --- .../vector/compare/RangeEqualsVisitor.java | 61 +++++++++++++++++-- .../compare/TestRangeEqualsVisitor.java | 20 +++++- 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index 9aa1bffb8463e..ed51f748af577 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -41,6 +41,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; /** Visitor to compare a range of values for vectors. */ @@ -345,6 +346,20 @@ protected boolean compareDenseUnionVectors(Range range) { return true; } + private boolean compareStructVectorsInternal( + NonNullableStructVector leftVector, NonNullableStructVector rightVector, Range range) { + List leftChildNames = leftVector.getChildFieldNames(); + for (String name : leftChildNames) { + RangeEqualsVisitor visitor = + createInnerVisitor( + leftVector.getChild(name), rightVector.getChild(name), /*type comparator*/ null); + if (!visitor.rangeEquals(range)) { + return false; + } + } + return true; + } + protected boolean compareStructVectors(Range range) { NonNullableStructVector leftVector = (NonNullableStructVector) left; NonNullableStructVector rightVector = (NonNullableStructVector) right; @@ -354,15 +369,49 @@ protected boolean compareStructVectors(Range range) { return false; } - for (String name : leftChildNames) { - RangeEqualsVisitor visitor = - createInnerVisitor( - leftVector.getChild(name), rightVector.getChild(name), /*type comparator*/ null); - if (!visitor.rangeEquals(range)) { + if (!(leftVector instanceof StructVector || rightVector instanceof StructVector)) { + // neither struct vector is nullable + return compareStructVectorsInternal(leftVector, rightVector, range); + } + + Range subRange = new Range(0, 0, 0); + boolean lastIsNull = true; + int lastNullIndex = -1; + for (int i = 0; i < range.getLength(); i++) { + int leftIndex = range.getLeftStart() + i; + int rightIndex = range.getRightStart() + i; + boolean isLeftNull = leftVector.isNull(leftIndex); + boolean isRightNull = rightVector.isNull(rightIndex); + + if (isLeftNull != isRightNull) { + // exactly one slot is null, unequal return false; } + if (isLeftNull) { + // slots are null + if (!lastIsNull) { + subRange + .setLeftStart(range.getLeftStart() + lastNullIndex + 1) + .setRightStart(range.getRightStart() + lastNullIndex + 1) + .setLength(i - (lastNullIndex + 1)); + if (!compareStructVectorsInternal(leftVector, rightVector, subRange)) { + return false; + } + } + lastIsNull = true; + lastNullIndex = i; + } else { + // slots are not null + lastIsNull = false; + } + } + if (!lastIsNull) { + subRange + .setLeftStart(range.getLeftStart() + lastNullIndex + 1) + .setRightStart(range.getRightStart() + lastNullIndex + 1) + .setLength(range.getLength() - (lastNullIndex + 1)); + return compareStructVectorsInternal(leftVector, rightVector, subRange); } - return true; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java index eca5c2d9b2a83..08da786eb272c 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java @@ -434,17 +434,18 @@ public void testStructVectorRangeEquals() { NullableStructWriter writer1 = vector1.getWriter(); writer1.allocate(); + writeStructVector(writer1, 0, 0L); writeStructVector(writer1, 1, 10L); writeStructVector(writer1, 2, 20L); writeStructVector(writer1, 3, 30L); writeStructVector(writer1, 4, 40L); writeStructVector(writer1, 5, 50L); - writer1.setValueCount(5); + writer1.setValueCount(6); NullableStructWriter writer2 = vector2.getWriter(); writer2.allocate(); - writeStructVector(writer2, 0, 00L); + writeStructVector(writer2, 0, 0L); writeStructVector(writer2, 2, 20L); writeStructVector(writer2, 3, 30L); writeStructVector(writer2, 4, 40L); @@ -452,7 +453,20 @@ public void testStructVectorRangeEquals() { writer2.setValueCount(5); RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2); - assertTrue(visitor.rangeEquals(new Range(1, 1, 3))); + assertTrue(visitor.rangeEquals(new Range(2, 1, 3))); + + // different nullability but same values + vector1.setNull(3); + assertFalse(visitor.rangeEquals(new Range(2, 1, 3))); + // both null and same values + vector2.setNull(2); + assertTrue(visitor.rangeEquals(new Range(2, 1, 3))); + // both not null but different values + assertFalse(visitor.rangeEquals(new Range(2, 1, 4))); + // both null but different values + vector1.setNull(5); + vector2.setNull(4); + assertTrue(visitor.rangeEquals(new Range(2, 1, 4))); } }