Skip to content

Commit

Permalink
Fix overeager query invalidation, optimize validation, bump gradle wr…
Browse files Browse the repository at this point in the history
…apper version
  • Loading branch information
ty1824 committed May 4, 2024
1 parent b4bfc65 commit 5a7fd58
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 64 deletions.
2 changes: 1 addition & 1 deletion gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
1 change: 0 additions & 1 deletion inkt/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ publishing {
}
}
}
repositories.forEach { println((it as MavenArtifactRepository).url)}
publications {
register<MavenPublication>("default") {
from(components["java"])
Expand Down
92 changes: 41 additions & 51 deletions inkt/src/main/kotlin/dev/dialector/inkt/next/QueryDatabaseImpl.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package dev.dialector.inkt.next

internal data class QueryKey<K : Any, V>(val queryDef: QueryDefinition<K, V>, val key: K) {
fun presentation(): String = "(${queryDef.name}, $key"
override fun toString(): String =
"(${queryDef.name}, $key)"
}

internal sealed interface Value<V> {
Expand Down Expand Up @@ -78,11 +79,23 @@ public class QueryDatabaseImpl : QueryDatabase {
private fun <K : Any, V> set(inputDef: QueryDefinition<K, V>, key: K, value: V) {
val queryStorage = getQueryStorage(inputDef)
when (val currentValue = queryStorage[key]) {
null -> queryStorage[key] = InputValue(value, ++currentRevision)
is InputValue -> {
currentValue.value = value
currentValue.changedAt = ++currentRevision
// Only register a change if the value actually has changed
if (currentValue.value != value) {
currentValue.value = value
currentValue.changedAt = ++currentRevision
}
}
else -> {
// If the new input value is equivalent to the existing value, backdate it.
val revision = if (currentValue.value != value) {
++currentRevision
} else {
currentValue.changedAt
}
queryStorage[key] = InputValue(value, revision)
}
else -> queryStorage[key] = InputValue(value, ++currentRevision)
}
}

Expand Down Expand Up @@ -112,7 +125,7 @@ public class QueryDatabaseImpl : QueryDatabase {
}

is DerivedValue<V> -> {
if (deepVerify(context, existingValue)) {
if (verify(context, existingValue)) {
context.addDependency(queryKey, existingValue.changedAt)
existingValue.value
} else {
Expand Down Expand Up @@ -150,42 +163,29 @@ public class QueryDatabaseImpl : QueryDatabase {
}

/**
* Checks whether a value is guaranteed to be up-to-date as of this revision. Does not check dependencies.
*/
private fun shallowVerify(value: Value<*>): Boolean {
return when (value) {
is InputValue<*> -> value.changedAt <= currentRevision
is DerivedValue<*> -> value.verifiedAt == currentRevision
}
}

/**
* Checks whether a value is up-to-date based on its dependencies.
* Checks whether a derived value is up-to-date based on its dependencies.
*
* Returns true if the value is considered up-to-date, false if it must be recomputed.
*/
private fun deepVerify(context: QueryExecutionContext, value: Value<*>): Boolean {
return when (value) {
is InputValue<*> -> shallowVerify(value)
is DerivedValue<*> -> {
if (shallowVerify(value)) {
return true
}

val noDepsChanged = value.dependencies.none { dep ->
// If the dependency exists, check if it may have changed.
// If it does not exist, it has "changed" (likely removed) and thus must be recomputed.
get(dep)?.let {
maybeChangedAfter(context, dep, it, value.verifiedAt)
} ?: true
}
private fun verify(context: QueryExecutionContext, value: DerivedValue<*>): Boolean {
// Short-circuit if possible
if (value.verifiedAt == currentRevision) {
return true
}

if (noDepsChanged) {
value.verifiedAt = currentRevision
}
val noDepsChanged = value.dependencies.none { dep ->
// If the dependency exists, check if it may have changed.
// If it does not exist, it has "changed" (likely removed) and thus must be recomputed.
get(dep)?.let {
maybeChangedAfter(context, dep, it, value.verifiedAt)
} ?: true
}

return false
}
return if (noDepsChanged) {
value.verifiedAt = currentRevision
true
} else {
false
}
}

Expand All @@ -198,24 +198,14 @@ public class QueryDatabaseImpl : QueryDatabase {
value: Value<*>,
asOfRevision: Int,
): Boolean {
if (value is InputValue<*>) {
return shallowVerify(value)
}

if (shallowVerify(value)) {
return value.changedAt > asOfRevision
}

if (deepVerify(context, value)) {
// If the value is not derived (is an input) or is a verified derived value, return if it has changed
if (value !is DerivedValue<*> || verify(context, value)) {
return value.changedAt > asOfRevision
}

// If the value is not verified, re-run and check if it produces the same result.
val newValue = execute(context, key)
if (value == newValue) {
return false
}

return true
return value.value != newValue
}

public fun print() {
Expand All @@ -241,7 +231,7 @@ public class QueryDatabaseImpl : QueryDatabase {
checkCanceled()
if (queryStack.any { it.queryKey == key }) {
throw IllegalStateException(
"Cycle detected: ${key.presentation()} already in ${queryStack.joinToString { it.queryKey.presentation() }}",
"Cycle detected: $key already in ${queryStack.joinToString { it.queryKey.toString() }}",
)
}
queryStack.add(QueryFrame(key))
Expand Down
5 changes: 4 additions & 1 deletion inkt/src/main/kotlin/dev/dialector/inkt/next/QueryDslImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ internal class QueryDefinitionDelegate<K : Any, V>(private val value: QueryDefin
override operator fun getValue(thisRef: Any?, property: KProperty<*>): QueryDefinition<K, V> = value
}

internal data class QueryDefinitionImpl<K : Any, V>(
internal class QueryDefinitionImpl<K : Any, V>(
override val name: String,
val logic: QueryFunction<K, V>,
) : QueryDefinition<K, V> {

override fun toString(): String = "QueryDefinition($name)"

override fun execute(context: QueryContext, key: K): V = context.logic(key)
}
110 changes: 100 additions & 10 deletions inkt/src/test/kotlin/dev/dialector/inkt/DatabaseTest.kt
Original file line number Diff line number Diff line change
@@ -1,35 +1,66 @@
package dev.dialector.inkt

import dev.dialector.inkt.next.QueryDatabase
import dev.dialector.inkt.next.QueryDatabaseImpl
import dev.dialector.inkt.next.QueryDefinition
import dev.dialector.inkt.next.defineQuery
import dev.dialector.inkt.next.query
import dev.dialector.inkt.next.remove
import dev.dialector.inkt.next.set
import org.junit.jupiter.api.Test
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import kotlin.test.BeforeTest
import kotlin.test.assertEquals
import kotlin.test.assertFails

class InvocationCounter {
private var counter = 0
fun increment() {
counter++
}

fun checkAndReset(): Int {
val ret = counter
counter = 0
return ret
}

fun reset() {
counter = 0
}
}

class DatabaseTest {
private val someInput by defineQuery<Int>("daInput")
private val someInputTimesTwo by defineQuery<Int>("derived2") { query(someInput) * 2 }
private val otherInput by defineQuery<Int>()
private val timesTwoCounter = InvocationCounter()
private val someInputTimesTwo by defineQuery<Int>("derived2") {
timesTwoCounter.increment()
query(someInput) * 2
}
private val doubleArgument by defineQuery<Int, Int> { it + it }

private var transitiveInvokeCount = 0
private var transitiveInvokeCount = InvocationCounter()
private val transitive by defineQuery<String, Int> { arg ->
transitiveInvokeCount++
transitiveInvokeCount.increment()
val doubledSomeInput = query(doubleArgument, query(someInput))
doubledSomeInput + arg.length
}

private lateinit var database: QueryDatabase
private val otherTransitiveCounter = InvocationCounter()
private val otherTransitive by defineQuery<Int> {
otherTransitiveCounter.increment()
query(someInputTimesTwo) + query(otherInput)
}

private lateinit var database: QueryDatabaseImpl

@BeforeTest
fun init() {
database = QueryDatabaseImpl()
transitiveInvokeCount = 0
timesTwoCounter.reset()
transitiveInvokeCount.reset()
otherTransitiveCounter.reset()
}

@Test
Expand All @@ -45,8 +76,7 @@ class DatabaseTest {
assertEquals(13, database.query(transitive, "hi!"))

// Verify that the `transitive` query was only invoked twice, once for each unique argument
assertEquals(2, transitiveInvokeCount)
transitiveInvokeCount = 0
assertEquals(2, transitiveInvokeCount.checkAndReset())

// Change someInput and repeat
database.set(someInput, 100)
Expand All @@ -57,8 +87,7 @@ class DatabaseTest {
assertEquals(202, database.query(transitive, "hi"))
assertEquals(202, database.query(transitive, "hi"))
assertEquals(203, database.query(transitive, "hi!"))
assertEquals(2, transitiveInvokeCount)
transitiveInvokeCount = 0
assertEquals(2, transitiveInvokeCount.checkAndReset())

// All calls should fail after removing dependency
database.remove(someInput)
Expand All @@ -68,6 +97,31 @@ class DatabaseTest {
assertFails { database.query(transitive, "hi!") }
}

@Test
fun `caching and invalidation of queries upon change`() {
database.writeTransaction {
set(someInput, 5)
set(otherInput, 100)
query(otherTransitive) // Should run fully here
set(otherInput, 100)
query(otherTransitive) // Should not recompute, setting to the same value
set(someInput, 6)
query(otherTransitive) // Should recompute fully
set(otherInput, 5)
query(otherTransitive) // Should not re-run times two
set(someInputTimesTwo, 12)
query(otherTransitive) // Should not recompute, derived query was explicitly assigned the same value
remove(someInputTimesTwo)
query(otherTransitive) // Should recompute fully
set(someInput, 5)
set(someInput, 6)
query(otherTransitive) // Should only recompute intermediate value, result should be the same.
}

assertEquals(4, timesTwoCounter.checkAndReset())
assertEquals(4, otherTransitiveCounter.checkAndReset())
}

@Test
fun implementationWithExplicitValue() {
assertEquals(4, database.query(doubleArgument, 2))
Expand All @@ -79,6 +133,7 @@ class DatabaseTest {
// Result for 3 should be unchanged
assertEquals(6, database.query(doubleArgument, 3))

// 2 + 2 = 4, whew
database.remove(doubleArgument, 2)
assertEquals(4, database.query(doubleArgument, 2))
assertEquals(6, database.query(doubleArgument, 3))
Expand Down Expand Up @@ -106,4 +161,39 @@ class DatabaseTest {
assertFails { database.query(possiblyCyclic, 4) }
assertFails { database.query(possiblyCyclic, 8) }
}

@Test
fun `print database`() {
database.writeTransaction {
set(someInput, 5)
set(otherInput, 10)
query(otherTransitive)
set(someInput, 100)
}

val expected = """
|=========================
|Current revision = 3
|Query store: QueryDefinition(daInput)
| kotlin.Unit to InputValue(value=100, changedAt=3)
|Query store: QueryDefinition(otherInput)
| kotlin.Unit to InputValue(value=10, changedAt=2)
|Query store: QueryDefinition(otherTransitive)
| kotlin.Unit to DerivedValue(value=20, dependencies=[(derived2, kotlin.Unit), (otherInput, kotlin.Unit)], verifiedAt=2, changedAt=2)
|Query store: QueryDefinition(derived2)
| kotlin.Unit to DerivedValue(value=10, dependencies=[(daInput, kotlin.Unit)], verifiedAt=2, changedAt=1)
|=========================
|
""".trimMargin()

val os = ByteArrayOutputStream(1024)
val originalOut = System.out
try {
System.setOut(PrintStream(os))
database.print()
assertEquals(expected, os.toString())
} finally {
System.setOut(originalOut)
}
}
}

0 comments on commit 5a7fd58

Please sign in to comment.