Skip to content

Commit

Permalink
fix null equality
Browse files Browse the repository at this point in the history
  • Loading branch information
edgao committed Dec 17, 2024
1 parent 893f209 commit 8f3afd0
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,17 @@ object MockDestinationBackend {
// Assume that in dedup mode, we don't have duplicates - so we can just find the first
// record with the same PK as the incoming record
val existingRecord =
file.firstOrNull { RecordDiffer.comparePks(incomingPk, getPk(it)) == 0 }
file.firstOrNull {
RecordDiffer.comparePks(incomingPk, getPk(it), nullEqualsUnset = false) == 0
}
if (existingRecord == null) {
file.add(incomingRecord)
} else {
val incomingCursor = getCursor(incomingRecord)
val existingCursor = getCursor(existingRecord)
val compare = RecordDiffer.valueComparator.compare(incomingCursor, existingCursor)
val compare =
RecordDiffer.getValueComparator(nullEqualsUnset = false)
.compare(incomingCursor, existingCursor)
// If the incoming record has a later cursor,
// or the same cursor but a later extractedAt,
// then upsert. (otherwise discard the incoming record.)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package io.airbyte.cdk.load.test.util

import java.time.OffsetDateTime
import kotlin.test.assertEquals
import kotlin.test.assertNull
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

Expand Down Expand Up @@ -155,7 +156,7 @@ class RecordDifferTest {
),
)
)
assertEquals(null, diff)
assertNull(diff)
}

/** Verify that the differ can sort records which are identical other than the cursor */
Expand Down Expand Up @@ -193,7 +194,7 @@ class RecordDifferTest {
),
),
)
assertEquals(null, diff)
assertNull(diff)
}

/** Verify that the differ can sort records which are identical other than extractedAt */
Expand Down Expand Up @@ -231,6 +232,49 @@ class RecordDifferTest {
),
)
)
assertEquals(null, diff)
assertNull(diff)
}

/**
 * Verify that, with `nullEqualsUnset = true`, a record that omits some fields produces no diff
 * against a record that carries those same fields as explicit nulls. Covers a top-level field,
 * a field inside a nested object, and a field inside an object nested in a list.
 */
@Test
fun testNullEqualsUnset() {
    val differ =
        RecordDiffer(primaryKey = listOf(listOf("id")), cursor = null, nullEqualsUnset = true)

    // Record in which the optional fields are simply absent.
    val recordWithUnsetFields =
        OutputRecord(
            extractedAt = 1,
            generationId = 0,
            data =
                mapOf(
                    "id" to 1,
                    "sub_object" to
                        mapOf(
                            "foo" to "bar",
                            "sub_list" to listOf(mapOf<String, Any?>()),
                        )
                ),
            airbyteMeta = null,
        )

    // Same record, but with the optional fields present and set to null.
    val recordWithExplicitNulls =
        OutputRecord(
            extractedAt = 1,
            generationId = 0,
            data =
                mapOf(
                    "id" to 1,
                    "name" to null,
                    "sub_object" to
                        mapOf(
                            "foo" to "bar",
                            "bar" to null,
                            "sub_list" to listOf(mapOf("foo" to null)),
                        )
                ),
            airbyteMeta = null,
        )

    val diff =
        differ.diffRecords(
            listOf(recordWithUnsetFields),
            listOf(recordWithExplicitNulls),
        )

    assertNull(diff)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.airbyte.cdk.load.test.util

import io.airbyte.cdk.load.data.AirbyteValue
import io.airbyte.cdk.load.data.ArrayValue
import io.airbyte.cdk.load.data.DateValue
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.data.NullValue
Expand Down Expand Up @@ -48,6 +49,8 @@ class RecordDiffer(
*/
val allowUnexpectedRecord: Boolean = false,
) {
private val valueComparator = getValueComparator(nullEqualsUnset)

private fun extract(data: Map<String, AirbyteValue>, path: List<String>): AirbyteValue {
return when (path.size) {
0 -> throw IllegalArgumentException("Empty path")
Expand Down Expand Up @@ -87,7 +90,7 @@ class RecordDiffer(
)
}

comparePks(pk1, pk2)
comparePks(pk1, pk2, nullEqualsUnset)
}

/**
Expand Down Expand Up @@ -276,30 +279,39 @@ class RecordDiffer(
}

companion object {
val valueComparator: Comparator<AirbyteValue> =
Comparator.nullsFirst { v1, v2 -> compare(v1!!, v2!!) }
/**
 * Builds a comparator over [AirbyteValue]s that sorts `null` before every non-null value and
 * delegates the comparison of two non-null values to [compare], forwarding [nullEqualsUnset].
 *
 * A new comparator instance is created on every call.
 */
fun getValueComparator(nullEqualsUnset: Boolean): Comparator<AirbyteValue> =
// `Comparator.nullsFirst` handles null operands itself and only invokes the wrapped
// comparator when both values are non-null, so the `!!` assertions cannot fail here.
Comparator.nullsFirst { v1, v2 -> compare(v1!!, v2!!, nullEqualsUnset) }

/**
* Compare each PK field in order, until we find a field that the two records differ in. If
* all the fields are equal, then these two records have the same PK.
*/
fun comparePks(pk1: List<AirbyteValue?>, pk2: List<AirbyteValue?>) =
(pk1.zip(pk2)
.map { (pk1Field, pk2Field) -> valueComparator.compare(pk1Field, pk2Field) }
fun comparePks(
    pk1: List<AirbyteValue?>,
    pk2: List<AirbyteValue?>,
    nullEqualsUnset: Boolean,
): Int {
    // The comparator depends only on nullEqualsUnset, so build it once rather than
    // allocating a fresh one for every PK field.
    val fieldComparator = getValueComparator(nullEqualsUnset)
    // zip pairs the fields positionally and stops at the shorter list. The first non-zero
    // field comparison decides the result; if every pair compares equal, the PKs are equal.
    return pk1.zip(pk2)
        .map { (pk1Field, pk2Field) -> fieldComparator.compare(pk1Field, pk2Field) }
        .firstOrNull { it != 0 }
        ?: 0
}

private fun compare(v1: AirbyteValue, v2: AirbyteValue): Int {
private fun compare(v1: AirbyteValue, v2: AirbyteValue, nullEqualsUnset: Boolean): Int {
if (v1 is UnknownValue) {
return compare(
JsonToAirbyteValue().fromJson(v1.value),
v2,
nullEqualsUnset,
)
}
if (v2 is UnknownValue) {
return compare(
v1,
JsonToAirbyteValue().fromJson(v2.value),
nullEqualsUnset,
)
}

Expand Down Expand Up @@ -348,6 +360,37 @@ class RecordDiffer(
}
}
}
// Object comparison: only equality is meaningful. When nullEqualsUnset is set, explicit
// NullValue entries inside objects are stripped from both sides before the equality check.
is ObjectValue -> {
fun objComp(a: ObjectValue, b: ObjectValue): Int {
// objects aren't really comparable, so just do an equality check
// NOTE(review): this returns 1 for any unequal pair regardless of argument order, so it
// only distinguishes "equal" from "not equal" and does not define a consistent ordering.
return if (a == b) 0 else 1
}
if (nullEqualsUnset) {
// Walk through the airbyte value, removing any NullValue entries
// from ObjectValues.
// Recurses through nested ObjectValues and through the elements of ArrayValues.
// A NullValue that is a direct element of an ArrayValue is kept: only object
// entries are filtered, and every other value is returned unchanged.
fun removeObjectNullValues(value: AirbyteValue): AirbyteValue =
when (value) {
is ObjectValue ->
ObjectValue(
value.values
// Drop entries whose value is NullValue at this level...
.filterTo(linkedMapOf()) { (_, v) ->
v !is NullValue
}
// ...then clean each remaining value recursively.
.mapValuesTo(linkedMapOf()) { (_, v) ->
removeObjectNullValues(v)
}
)
is ArrayValue ->
ArrayValue(value.values.map { removeObjectNullValues(it) })
else -> value
}
// The helper is declared to return AirbyteValue, hence the casts; an ObjectValue input
// always yields an ObjectValue.
// NOTE(review): the v2 cast assumes v2 is also an ObjectValue — presumably guaranteed
// by a same-type check earlier in the enclosing function; confirm.
val filteredV1 = removeObjectNullValues(v1) as ObjectValue
val filteredV2 = removeObjectNullValues(v2) as ObjectValue
objComp(filteredV1, filteredV2)
} else {
// Strict mode: an explicit null entry and a missing entry make the objects unequal.
objComp(v1, v2 as ObjectValue)
}
}
// otherwise, just be a terrible person.
// we know these are the same type, so this is safe to do.
is Comparable<*> ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class IcebergV2WriteTest(
}

/**
 * Runs the inherited container-types test unchanged. The `@Disabled` annotation below is
 * commented out, so this test is enabled.
 */
@Test
// @Disabled
// @Disabled
override fun testContainerTypes() {
super.testContainerTypes()
}
Expand All @@ -82,7 +82,7 @@ abstract class IcebergV2WriteTest(
}

/**
 * Runs the inherited unions test unchanged. The `@Disabled` annotation below is commented
 * out, so this test is enabled.
 */
@Test
// @Disabled
// @Disabled
override fun testUnions() {
super.testUnions()
}
Expand Down

0 comments on commit 8f3afd0

Please sign in to comment.