Skip to content

Commit

Permalink
steerable random walk through dfa
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 21, 2024
1 parent b8b2f64 commit 20a2afe
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up JDK 17
- name: Set up JDK 21
uses: actions/setup-java@v1
with:
java-version: 17
java-version: 21
- name: Build with Gradle
run: ./gradlew -PleaseExcludeBenchmarks allTests --stacktrace
17 changes: 12 additions & 5 deletions src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ open class MarkovChain<T>(
val dists: LRUCache<List<Int>, Dist> = LRUCache()

// Computes perplexity of a sequence normalized by sequence length (lower is better)
fun score(seq: List<T>): Double =
fun score(seq: List<T?>): Double =
if (memory < seq.size) -seq.windowed(memory)
.map { (getAtLeastOne(it) + 1) / (getAtLeastOne(it.dropLast(1) + null) + dictionary.size) }
.sumOf { ln(it) } / seq.size
Expand Down Expand Up @@ -257,14 +257,21 @@ open class MarkovChain<T>(
var total = 0L
lines.map { it.substringBefore(CSVSEP).split(" ") to it.substringAfter(CSVSEP).toLong() }
.forEach { (ngram, count) ->
total += count
nrmCounts.update(ngram, count)
val padding = List(memory - 1) { null }
val windows = (padding + ngram + padding).windowed(memory, 1)
total += count * windows.size
windows.forEach { nrmCounts.update(it, count) }
ngram.forEach { rawCounts.update(it, count) }
}
return MarkovChain(
train = sequenceOf(),
train = sequenceOf(), // Empty since we already know the counts, no need to retrain
memory = memory,
Counter(total = AtomicInteger(total.toInt()), memory = memory, rawCounts = rawCounts, nrmCounts = nrmCounts)
Counter(
total = AtomicInteger(total.toInt()),
memory = memory,
rawCounts = rawCounts,
nrmCounts = nrmCounts
)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Grammars
import Grammars.shortS2PParikhMap
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.MarkovChain
import net.jhoogland.jautomata.*
import net.jhoogland.jautomata.Automaton
import net.jhoogland.jautomata.operations.*
Expand Down Expand Up @@ -62,6 +63,20 @@ class WFSATest {
.replace("Mrecord", "circle") // FSA states should be circular
.replace("null", "ε") // null label = ε-transition

/*
* Returns a sequence of trajectories through a DFA sampled using the Markov chain.
* The DFA is expected to be deterministic. We use the Markov chain to steer the
* random walk through the DFA by sampling the best transitions conditioned on the
* previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
*/

// NOTE: unfinished stub. It scores the labels leaving the first initial state and then
// hits TODO(); `topK` and the scored labels are not consumed yet.
fun <S, K> Automaton<S, K>.randomWalk(mc: MarkovChain<S>, topK: Int = 1000): Sequence<S> {
  // Begin at the first initial state (throws NoSuchElementException if there is none).
  val start = initialStates().first()

  // (memory - 1) nulls, prepended so a single label is scored with an all-null context.
  val nullContext = List(mc.memory - 1) { null }

  // Extract every outgoing label first, then score each one, keeping the original order.
  val outgoingLabels = transitionsOut(start).map { edge -> (edge as BasicTransition<S, K>).label() }
  val scoredLabels = outgoingLabels.map { label -> label to mc.score(nullContext + label) }

  return TODO()
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testPTreeVsWFSA"
*/
Expand All @@ -70,7 +85,7 @@ class WFSATest {
val toRepair = "NAME : NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE"
val radius = 2
val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap)
println(pt.totalTrees.toString())
println("Total trees: " + pt.totalTrees.toString())
val maxResults = 10_000
val ptreeRepairs = measureTimedValue {
pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet()
Expand Down

0 comments on commit 20a2afe

Please sign in to comment.