Skip to content

Commit

Permalink
implement steerable random walk
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 23, 2024
1 parent 7a57c90 commit 7e284ad
Show file tree
Hide file tree
Showing 4 changed files with 18,732 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ open class MarkovChain<T>(
if (variables.size == 1) counter.rawCounts.getEstimate(variables[0]) / counter.total.toDouble()
else get(*variables.mapIndexed { i, t -> i to t }.toTypedArray())

/**
 * Scores an n-gram window as a negative log-ratio: the estimate for [seq] (plus one)
 * over the estimate for the same window with its final element replaced by `null`
 * (plus the dictionary size). Smaller values correspond to larger ratios.
 *
 * NOTE(review): `getAtLeastOne` already returns a value divided by `counter.total`,
 * so the `+ 1` / `+ dictionary.size` terms are applied to normalized quantities
 * rather than raw counts — confirm this is the intended smoothing.
 *
 * @param seq the window to score; presumably of length `memory`, with `null`
 *   standing for padding / an unconstrained position — TODO confirm
 * @return `-ln(numerator / denominator)` as described above
 */
fun scoreChunk(seq: List<T?>): Double {
  val numerator = getAtLeastOne(seq) + 1
  // Same window, but with the last position blanked out
  val context: List<T?> = seq.dropLast(1) + null
  val denominator = getAtLeastOne(context) + dictionary.size
  return -ln(numerator / denominator)
}

private fun getAtLeastOne(variables: List<T?>): Double =
// variables.allMasks().sumOf { mask ->
(counter.nrmCounts.getEstimate(variables) + 1).toDouble() / counter.total.toDouble()
Expand All @@ -143,8 +146,7 @@ open class MarkovChain<T>(
}

// https://www.cs.utah.edu/~jeffp/papers/merge-summ.pdf
/**
 * Combines this chain with [mc] by adding their counters (`counter + mc.counter`).
 * The result is a new chain that keeps this chain's [memory].
 *
 * NOTE(review): `mc.memory` is not inspected here; presumably both chains are
 * expected to share the same memory — confirm against callers.
 */
operator fun plus(mc: MarkovChain<T>) = MarkovChain<T>(memory = memory, counter = counter + mc.counter)

/**
* TODO: construct [Dist] using precomputed normalization constants [Counter.nrmCounts]
Expand Down
91 changes: 78 additions & 13 deletions src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,42 @@ import Grammars
import Grammars.shortS2PParikhMap
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.kaliningraph.visualization.alsoCopy
import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Transition
import net.jhoogland.jautomata.*
import net.jhoogland.jautomata.Automaton
import net.jhoogland.jautomata.operations.*
import net.jhoogland.jautomata.operations.Concatenation
import net.jhoogland.jautomata.semirings.RealSemiring
import java.io.File
import java.util.PriorityQueue
import kotlin.random.Random
import kotlin.test.*
import kotlin.time.*

// Two automaton libraries are used side by side in this test, and both expose a
// type named `Automaton`. The B* aliases refer to dk.brics, the J* alias to jautomata.
typealias BState = dk.brics.automaton.State
typealias BAutomaton = dk.brics.automaton.Automaton
// Fully qualified so the alias cannot be confused with dk.brics.automaton.Automaton
typealias JAutomaton<S, K> = net.jhoogland.jautomata.Automaton<S, K>

class WFSATest {

// Order of the n-gram files to load (selects ngrams_<corpus>_<MARKOV_MEMORY>.csv)
val MARKOV_MEMORY = 4

/**
 * Reads and deserializes the n-gram table for [corpus] from
 * `src/jvmTest/resources/ngrams_<corpus>_<MARKOV_MEMORY>.csv`, resolved against
 * the current working directory, and logs how many n-grams were loaded.
 *
 * @param corpus dataset tag embedded in the file name, e.g. "BIFI" or "PY150"
 * @return the deserialized Markov chain
 */
private fun loadNgrams(corpus: String): MarkovChain<Σᐩ> {
  val csv = File(File("").absolutePath + "/src/jvmTest/resources/ngrams_${corpus}_$MARKOV_MEMORY.csv")
  return MarkovChain.deserialize(csv.readText())
    .also { println("Loaded ${it.counter.total} $corpus $MARKOV_MEMORY-grams from ${csv.absolutePath}") }
}

// Python3 snippets
// https://github.com/michiyasunaga/BIFI?tab=readme-ov-file#about-the-github-python-dataset
val P_BIFI: MarkovChain<Σᐩ> by lazy {
// readBIFIContents()
  loadNgrams("BIFI")
}

// Python2 snippets, roughly 20x longer on average than BIFI
// https://www.sri.inf.ethz.ch/py150
val P_PY150: MarkovChain<Σᐩ> by lazy { loadNgrams("PY150") }

// Both corpora merged via MarkovChain.plus; loading is deferred until first use
val P_BIFI_PY150: MarkovChain<Σᐩ> by lazy { P_BIFI + P_PY150 }

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testWFSA"
*/
Expand Down Expand Up @@ -90,19 +111,60 @@ class WFSATest {
* previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
*/

fun <S, K> JAutomaton<S, K>.randomWalk(mc: MarkovChain<S>, topK: Int = 1000): Sequence<S> {
val init = initialStates().first()
val padding = List(mc.memory - 1) { null }
val ts = transitionsOut(init).map { (it as BasicTransition<S, K>).label() }.map { it to mc.score(padding + it) }
return TODO()
/**
 * A path through a [BAutomaton] together with its accumulated score.
 *
 * @property toks tokens emitted so far, stored most-recent-first; `null` entries
 *   are skipped when rendering (see [toString])
 * @property lastState the automaton state this path currently ends in
 * @property score accumulated score of the path
 */
data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) {
  // Evaluated once, when the trajectory is constructed
  val isComplete: Boolean = lastState.isAccept

  // Renders the tokens oldest-first, space separated, omitting null entries
  override fun toString() =
    toks.asReversed().filterNotNull().joinToString(separator = " ")
}

/**
 * Steers a walk over this automaton using the last n-1 transitions and the Markov chain.
 *
 * Partial trajectories are kept in a priority queue ordered by `score / toks.size`
 * and the lowest-valued one is expanded first. Every outgoing transition is expanded
 * over its whole `min..max` character interval; each character is translated through
 * [dec] and scored with `mc.scoreChunk` on the window of the previous `mc.memory - 1`
 * tokens plus the new one. A trajectory that reaches an accepting state is recorded
 * as complete and, if that state still has outgoing transitions, is also expanded further.
 *
 * The search stops when at least [topK] complete trajectories have been collected or
 * no partial trajectories remain. Because a whole state is expanded per iteration,
 * slightly more than [topK] results can be returned.
 *
 * NOTE(review): characters missing from [dec] yield a `null` token, which is then
 * silently omitted from the rendered string — confirm that [dec] always covers the
 * automaton's alphabet.
 *
 * @param mc n-gram model used for scoring
 * @param dec maps the automaton's characters back to string tokens
 * @param topK number of complete trajectories after which expansion stops
 * @return the complete trajectories rendered as space-separated token strings,
 *   sorted ascending by `score / toks.size`
 */
fun BAutomaton.steerableRandomWalk(
  mc: MarkovChain<Σᐩ>,
  // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
  // string-based alphabet, so we need a way to translate between the two
  dec: Map<Char, String>, // Maps unicode characters back to strings
  topK: Int = 10_000_000 // Total number of top-K results to return
): List<Σᐩ> {
  val startTime = TimeSource.Monotonic.markNow()
  // Length-normalized score, shared by the frontier and the final ranking
  val byNormalizedScore = compareBy<FSATrajectory> { it.score / it.toks.size }
  val fullTrajectories = mutableListOf<FSATrajectory>()
  val partTrajectories = PriorityQueue<FSATrajectory>(byNormalizedScore)
    .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }

  while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) {
    val partTraj = partTrajectories.remove()
    // toks is most-recent-first, so reverse the n-1 newest into chronological order
    val lastToks = partTraj.toks.take(mc.memory - 1).reversed()
    partTraj.lastState.transitions.forEach { next: Transition ->
      (next.min..next.max).forEach { tok ->
        val decTok = dec[tok]
        val nextToks = lastToks + decTok
        val nextScore = partTraj.score + mc.scoreChunk(nextToks)
        val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore)
        if (!traj.isComplete) partTrajectories.add(traj)
        else {
          fullTrajectories.add(traj)
          // An accepting state may still be extendable into longer accepted strings
          if (traj.lastState.transitions.isNotEmpty())
            partTrajectories.add(traj)
        }
      }
    }
  }

  // Rank explicitly: iterating a PriorityQueue does not visit elements in priority
  // order, so the results must be sorted before printing or returning them
  val ranked = fullTrajectories.sortedWith(byNormalizedScore)

  println("Top 10 trajectories:")
  ranked.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
  println("Took ${startTime.elapsedNow()} to decode ${ranked.size} trajectories")

  return ranked.map { it.toString() }
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testPTreeVsWFSA"
*/
@Test
fun testPTreeVsWFSA() {
println("${P_BIFI_PY150.memory}-gram Markov chain is ready.")
// val toRepair = "from NAME import NAME NEWLINE NAME = NAME ( STRING , STRING ) NEWLINE NAME STRING . NAME ( NAME ) NEWLINE"
// val groundTr = "NEWLINE from NAME import NAME NEWLINE NAME = NAME ( STRING , STRING ) NEWLINE NAME ( STRING . NAME ( NAME ) ) NEWLINE"
// val radius = 3
val toRepair = "NAME : NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE"
val groundTr = "+ NAME : True NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE"
val radius = 2
val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap)
fun Char.toUnicodeEscaped() = "\\u${code.toString(16).padStart(4, '0')}"
Expand Down Expand Up @@ -146,8 +208,11 @@ class WFSATest {
}")
}
}
.let { it?.getFiniteStrings(-1) ?: emptySet() }
// ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet()
?.steerableRandomWalk(P_BIFI_PY150, ctbl) ?: emptyList()
}.also {
assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs")
println("Index: ${it.value.indexOf(groundTr)}")
// // Print side by side comparison of repairs
// repairs.sorted().forEach {
// val a = it
Expand All @@ -156,9 +221,9 @@ class WFSATest {
// val colorB = if (b.isEmpty()) "" else levenshteinAlign(toRepair, b).paintANSIColors()
// println("$colorA\n$colorB\n")
// }
assertEquals(it.value.size, ptreeRepairs.value.size)
assertEquals(ptreeRepairs.value.size, it.value.size)

it.value.map { it.map { ctbl[it] }.joinToString(" ") }.forEach {
it.value.forEach {
// println(levenshteinAlign(toRepair, it).paintANSIColors())
assertTrue(levenshtein(toRepair, it) <= radius)
assertTrue(it in Grammars.seq2parsePythonCFG.language)
Expand Down
Loading

0 comments on commit 7e284ad

Please sign in to comment.