update levenshtein automata indices
breandan committed Jun 30, 2024
1 parent bf6bf91 commit 4a4f5b0
Showing 7 changed files with 76 additions and 40 deletions.
Binary file modified latex/popl2025/popl.pdf
2 changes: 1 addition & 1 deletion latex/popl2025/popl.tex
@@ -563,7 +563,7 @@
\begin{prooftree}
\AxiomC{$i \in [0, n] \phantom{\land} j \in [1, k]$}
\RightLabel{$\duparrow$}
\UnaryInfC{$(q_{i, j-1} \overset{{\color{orange}[\neq \sigma_i]}}{\rightarrow} q_{i,j}) \in \delta$}
\UnaryInfC{$(q_{i, j-1} \overset{{\color{orange}[\neq \sigma_{i+1}]}}{\rightarrow} q_{i,j}) \in \delta$}
\DisplayProof
\hskip 1.5em
\AxiomC{$i \in [1, n] \phantom{\land} j \in [1, k]$}
@@ -15,7 +15,9 @@ class NOM(override val Q: TSA, override val init: Set<Σᐩ>, override val final
}

fun Σᐩ.predicate(): (Σᐩ) -> Boolean =
if (startsWith("[!=]")) { s: Σᐩ -> s != drop(4) } else { s: Σᐩ -> s == this }
if (this == "[.*]") { s: Σᐩ -> true }
else if (startsWith("[!=]")) { s: Σᐩ -> s != drop(4) }
else { s: Σᐩ -> s == this }

val mapF: Map<Σᐩ, List<Π2<StrPred, Σᐩ>>> by lazy {
Q.map { q -> q.first to q.second.predicate() to q.third }.groupBy { it.first }
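The new wildcard case makes the arc-label semantics easy to check by hand. Below is a minimal, self-contained sketch of the three predicate forms, assuming Σᐩ is simply an alias for String as in the surrounding code:

// Sketch only: Σᐩ is assumed to be a String alias here.
typealias Σᐩ = String

fun Σᐩ.predicate(): (Σᐩ) -> Boolean =
  if (this == "[.*]") { _: Σᐩ -> true }                   // wildcard arc: accepts any token
  else if (startsWith("[!=]")) { s: Σᐩ -> s != drop(4) }  // negated literal: accepts every token but one
  else { s: Σᐩ -> s == this }                             // plain label: exact match

fun main() {
  check("[.*]".predicate()("foo"))     // wildcard matches anything
  check("[!=]foo".predicate()("bar"))  // "[!=]foo" matches tokens other than "foo"
  check(!"[!=]foo".predicate()("foo")) // ...but not "foo" itself
  check("foo".predicate()("foo"))      // plain labels only match themselves
}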
@@ -92,7 +92,7 @@ private fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap
// such that δ(p, σ) = q we have the production [p, A, q] → σ in P′.
fun CFG.unitProdRules(fsa: FSA): List<Pair<String, List<Σᐩ>>> =
(unitProductions * fsa.nominalize().flattenedTriples)
.filter { (_, σ, arc) -> (arc.π2)(σ) }
.filter { (_, σ: Σᐩ, arc) -> (arc.π2)(σ) }
.map { (A, σ, arc) -> "[${arc.π1}~$A~${arc.π3}]" to listOf(σ) }

fun CFG.postProcess() =
@@ -201,6 +201,8 @@ val CFG.lengthBounds: Map<Σᐩ, IntRange> by cache {
map
}

val CFG.lengthBoundsCache by cache { lengthBounds.let { lb -> nonterminals.map { lb[it] ?: 0..0 } } }

fun Π3A<STC>.isValidStateTriple(): Boolean {
fun Pair<Int, Int>.dominates(other: Pair<Int, Int>) =
first <= other.first && second <= other.second
@@ -262,6 +264,7 @@ fun FSA.obeys(a: STC, b: STC, nt: Int, parikhMap: ParikhMap): Bln {
val sl = levString.size <= max(a.second, b.second) // Part of the LA that handles extra

if (sl) return true
// y-difference between Levenshtein levels of a and b, i.e., relaxation in case we are outside Parikh bounds
val margin = (b.third - a.third).absoluteValue
val length = (b.second - a.second)
val range = (length - margin).coerceAtLeast(1)..(length + margin)
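The added comment is easier to read with concrete numbers. A quick sketch, assuming an STC carries (state id, position, Levenshtein level) in that order; the coordinates below are made up:

import kotlin.math.absoluteValue

// Hypothetical states: a at position 5 on level 1, b at position 9 on level 3.
val a = Triple(0, 5, 1)
val b = Triple(1, 9, 3)
val margin = (b.third - a.third).absoluteValue                     // 2 levels of slack
val length = b.second - a.second                                   // 4 tokens spanned in the original string
val range = (length - margin).coerceAtLeast(1)..(length + margin)  // 2..6 admissible span lengths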
@@ -17,8 +17,9 @@ typealias IProduction = Π2<Int, List<Int>>
typealias CFG = Set<Production>

val Production.LHS: Σᐩ get() = first
val Production.RHS: List<Σᐩ> get() =
second.let { if (it.size == 1) it.map(Σᐩ::stripEscapeChars) else it }
val Production.RHS: List<Σᐩ> get() = second
// Not sure why this was added, but we don't have time for it in production
// second.let { if (it.size == 1) it.map(Σᐩ::stripEscapeChars) else it }

/**
* "Freezes" the enclosed CFG, making it immutable and caching its hashCode for
@@ -53,6 +53,7 @@ fun CFG.levenshteinRepair(maxDist: Int, unparseable: List<Σᐩ>, solver: CJL.(L

fun makeLevFSA(str: Σᐩ, dist: Int): FSA = makeLevFSA(str.tokenizeByWhitespace(), dist)

/** Uses nominal arc predicates. See [NOM] for denominalization. */
fun makeLevFSA(
str: List<Σᐩ>,
dist: Int,
@@ -82,12 +83,12 @@ fun makeLevFSA(
private fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')

/**
TODO: upArcs and diagArcs are the most expensive operations taking ~O(2n|Σ|) to construct.
upArcs and diagArcs are the most expensive operations taking ~O(2n|Σ|) to construct.
Later, the Bar-Hillel construction creates a new production for every triple QxQxQ, so it
increases the size of generated grammar by (2n|Σ|)^3. To fix this, we must instead create
increases the size of generated grammar by (2n|Σ|)^3. To fix this, we instead create
a nominal or parametric CFG with arcs which denote infinite alphabets.
See: [ai.hypergraph.kaliningraph.repair.CEAProb]
See also: [ai.hypergraph.kaliningraph.repair.CEAProb]
*//*
References
- https://arxiv.org/pdf/1402.0897.pdf#section.7
@@ -100,10 +101,38 @@ private fun pd(i: Int, digits: Int) = i.toString().padStart(digits, '0')
(q_i,j−1 -s→ q_i,j)∈δ
*/
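To put rough numbers on the blowup described above (illustrative values only, not measurements from this repository):

// Back-of-the-envelope scale with made-up n and |Σ|: a 40-token string over 100 terminals.
val n = 40
val sigma = 100
val arcs = 2L * n * sigma       // ≈ 8,000 concrete arcs in the Levenshtein FSA
val naive = arcs * arcs * arcs  // ≈ 5.1e11 candidate productions under the (2n|Σ|)³ growth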

/*
Precision@All
=============
|σ|∈[0, 10): Top-1/total: 28 / 28 = 1.0
|σ|∈[10, 20): Top-1/total: 41 / 41 = 1.0
|σ|∈[20, 30): Top-1/total: 45 / 46 = 0.9782608695652174
|σ|∈[30, 40): Top-1/total: 41 / 41 = 1.0
|σ|∈[40, 50): Top-1/total: 9 / 11 = 0.8181818181818182
Δ(1)= Top-1/total: 57 / 58 = 0.9827586206896551
Δ(2)= Top-1/total: 57 / 58 = 0.9827586206896551
Δ(3)= Top-1/total: 50 / 51 = 0.9803921568627451
(|σ|∈[0, 10), Δ=1): Top-1/total: 11 / 11 = 1.0
(|σ|∈[0, 10), Δ=2): Top-1/total: 11 / 11 = 1.0
(|σ|∈[0, 10), Δ=3): Top-1/total: 6 / 6 = 1.0
(|σ|∈[10, 20), Δ=1): Top-1/total: 12 / 12 = 1.0
(|σ|∈[10, 20), Δ=2): Top-1/total: 11 / 11 = 1.0
(|σ|∈[10, 20), Δ=3): Top-1/total: 18 / 18 = 1.0
(|σ|∈[20, 30), Δ=1): Top-1/total: 18 / 18 = 1.0
(|σ|∈[20, 30), Δ=2): Top-1/total: 13 / 13 = 1.0
(|σ|∈[20, 30), Δ=3): Top-1/total: 14 / 15 = 0.9333333333333333
(|σ|∈[30, 40), Δ=1): Top-1/total: 11 / 11 = 1.0
(|σ|∈[30, 40), Δ=2): Top-1/total: 19 / 19 = 1.0
(|σ|∈[30, 40), Δ=3): Top-1/total: 11 / 11 = 1.0
(|σ|∈[40, 50), Δ=1): Top-1/total: 5 / 6 = 0.8333333333333334
(|σ|∈[40, 50), Δ=2): Top-1/total: 3 / 4 = 0.75
(|σ|∈[40, 50), Δ=3): Top-1/total: 1 / 1 = 1.0
*/

fun upArcs(str: List<Σᐩ>, dist: Int, digits: Int): TSA =
((0..<str.size + dist).toSet() * (1..dist).toSet())
((0..str.size).toSet() * (1..dist).toSet())
// .filter { (i, _, s) -> str.size <= i || str[i] != s }
.filter { (i, j) -> i <= str.size || i - str.size < j }
// .filter { (i, j) -> i <= str.size || i - str.size < j }
.map { (i, j) -> i to j to if (i < str.size) str[i] else "###" }
.map { (i, j, s) -> i to j - 1 to "[!=]$s" to i to j }
.postProc(digits)
@@ -115,10 +144,10 @@ fun upArcs(str: List<Σᐩ>, dist: Int, digits: Int): TSA =
*/

fun diagArcs(str: List<Σᐩ>, dist: Int, digits: Int): TSA =
((1..<str.size + dist).toSet() * (1..dist).toSet())
((1..str.size).toSet() * (1..dist).toSet())
// .filter { (i, _, s) -> str.size <= i - 1 || str[i - 1] != s }
.filter { (i, j) -> i <= str.size || i - str.size <= j }
.map { (i, j) -> i to j to if (str.size <= i - 1) "###" else str[i - 1] }
.map { (i, j) -> i to j to str[i - 1] }
.map { (i, j, s) -> i - 1 to j - 1 to "[!=]$s" to i to j }
.postProc(digits)
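To see what the tightened index ranges emit, here is a hand expansion for a hypothetical two-token string str = ["a", "b"] with dist = 1, written out as the intermediate (source i, source j, label, target i, target j) tuples before postProc renames the states:

// upArcs: i ∈ 0..2, j ∈ 1..1
//   (0, 0, "[!=]a",   0, 1)   // insertion before "a": any token except "a"
//   (1, 0, "[!=]b",   1, 1)   // insertion before "b": any token except "b"
//   (2, 0, "[!=]###", 2, 1)   // insertion past the end; "###" is a placeholder for out-of-range positions
// diagArcs: i ∈ 1..2, j ∈ 1..1
//   (0, 0, "[!=]a",   1, 1)   // substitution of "a"
//   (1, 0, "[!=]b",   2, 1)   // substitution of "b"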

@@ -167,9 +167,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {

// For each production A → BC in P, for every p, q, r ∈ Q,
// we have the production [p,A,r] → [p,B,q] [q,C,r] in P′.
val prods = nonterminalProductions
.map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet()
val lengthBoundsCache = lengthBounds.let { lb -> nonterminals.map { lb[it] ?: 0..0 } }
val prods = nonterminalProductions.map { (a, b) -> ntMap[a]!! to b.map { ntMap[it]!! } }.toSet()
val validTriples = fsa.validTriples.map { arrayOf(it.π11, it.π21, it.π31) }

val ctClock = TimeSource.Monotonic.markNow()
@@ -187,7 +185,7 @@

val states = fsa.stateLst
val allsym = ntLst
val counter = LongAdder()
var counter = 0
val lpClock = TimeSource.Monotonic.markNow()
val binaryProds =
prods.parallelStream().flatMap {
Expand All @@ -202,8 +200,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
// .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) }
.filter { it.checkCompatibility(trip, ct2) }
.map { (a, b, c) ->
counter.increment()
if (MAX_PRODS < counter.sum()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
if (MAX_PRODS < counter++) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
val (p, q, r) = states[a] to states[b] to states[c]
"[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
}
@@ -257,45 +254,48 @@ tailrec fun CFG.jvmElimVarUnitProds(

// TODO: Incomplete / untested
// Based on: https://zerobone.net/blog/cs/non-productive-cfg-rules/
// Precondition: The CFG must be binarized, i.e., almost CNF but may have useless productions
// Postcondition: The CFG is in Chomsky Normal Form (CNF)
fun CFG.jdvpNew(): CFG {
println("Total productions: $size")
val timer = TimeSource.Monotonic.markNow()
val counter = ConcurrentHashMap<Set<Σᐩ>, LongAdder>()

// Maps each nonterminal to the set of RHS sets that contain it
val UDEPS = ConcurrentHashMap<Σᐩ, ConcurrentLinkedQueue<Set<Σᐩ>>>(size)
// Maps the set of symbols on the RHS of a production to the production
val NDEPS = ConcurrentHashMap<Set<Σᐩ>, ConcurrentLinkedQueue<Production>>(size).apply {
put(emptySet(), ConcurrentLinkedQueue())
this@jdvpNew.asSequence().asStream().parallel().forEach {
val v = it.second.toSet()
val v = it.second.toSet() // RHS set, i.e., the set of NTs on the RHS of a production
// If |v| is 1, then the production must be a unit production, i.e., A -> a, since A -> B is not binarized
getOrPut(if(it.second.size == 1) emptySet() else v) { ConcurrentLinkedQueue() }.add(it)
v.forEach { s -> UDEPS.getOrPut(s) { ConcurrentLinkedQueue() }.add(v) }
if (v.size == 2) counter.putIfAbsent(v, LongAdder().apply { add(2L) })
}
}

println("Constructed dependency graph in ${timer.elapsedNow()}")
println("Built graph in ${timer.elapsedNow()}: ${counter.size} conjuncts, ${UDEPS.size + NDEPS.size} edges")

val visited = mutableSetOf<Production>()
val nextReachable: LinkedHashSet<Set<Σᐩ>> = LinkedHashSet<Set<Σᐩ>>().apply { add(emptySet()) }

val productive = mutableSetOf<Production>()
do {
println("Next reachable: ${nextReachable.size}, Visited: ${visited.size}, Productive: ${productive.size}")
// println("Next reachable: ${nextReachable.size}, Productive: ${productive.size}")
val q = nextReachable.removeFirst()
if (q.size == 2) { // Conjunct
if (counter[q]?.sum() == 0L || NDEPS[q]?.all { it in productive } == true) continue
else if (q.size == 2) { // Conjunct
val dec = counter[q]!!.apply { decrement() }
if (dec.sum() == 0L) { // Seen both
NDEPS[q]?.forEach {
visited += it
UDEPS[it.LHS]?.forEach { st -> NDEPS[st]?.forEach { if (it !in visited) nextReachable += st } }
productive.add(it)
UDEPS[it.LHS]?.forEach { st -> if (st !in productive) nextReachable.addLast(st) }
}
} else nextReachable += q // Always add back if sum not zero
} else nextReachable.addLast(q) // Always add back if sum not zero
} else {
NDEPS[q]?.forEach {
visited += it
UDEPS[it.LHS]?.forEach { st -> NDEPS[st]?.forEach { if (it !in visited) nextReachable += st } }
productive.add(it)
UDEPS[it.LHS]?.forEach { st -> if (st !in productive) nextReachable.addLast(st) }
}
}
} while (nextReachable.isNotEmpty())
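A toy grammar makes the intended postcondition concrete. The snippet below is a plain productive-rule fixpoint (the textbook formulation, not the concurrent worklist above) over hypothetical productions:

// Hypothetical grammar: S -> A B, A -> a, B -> B B. Only "A -> a" should survive:
// B never derives a terminal string, which in turn makes "S -> A B" non-productive.
val prods = listOf("S" to listOf("A", "B"), "A" to listOf("a"), "B" to listOf("B", "B"))
val terminals = setOf("a")
val productive = mutableSetOf<String>()
do {
  val grew = prods.filter { (lhs, rhs) ->
    lhs !in productive && rhs.all { it in terminals || it in productive }
  }.map { it.first }.also { productive += it }
} while (grew.isNotEmpty())
val kept = prods.filter { (_, rhs) -> rhs.all { it in terminals || it in productive } }
// kept == listOf("A" to listOf("a"))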
@@ -330,18 +330,19 @@ fun CFG.jdvpNew(): CFG {

fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark): CFG {
val start = clock.elapsedNow()
val counter = LongAdder()
var counter = 0
val nts: Set<Σᐩ> = asSequence().asStream().parallel().map { it.first }.collect(Collectors.toSet())
val rw = asSequence().asStream().parallel()
.filter { prod ->
counter.increment()
if (counter.toInt() % 10 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
if (counter++ % 1000 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
// Only keep productions whose RHS symbols are not synthetic or are in the set of NTs
prod.RHS.all { !(it.first() == '[' && 1 < it.length) || it in nts }
}
.collect(Collectors.toSet())
.also { println("Removed ${size - it.size} invalid productions in ${clock.elapsedNow() - start}") }
.freeze().jvmRemoveUselessSymbols()
.freeze()
.jvmRemoveUselessSymbols()
//.jdvpNew()

println("Removed ${size - rw.size} vestigial productions, resulting in ${rw.size} productions.")

@@ -375,22 +376,22 @@ private fun CFG.jvmReachSym(from: Σᐩ = START_SYMBOL): Set<Σᐩ> {
val allReachable: MutableSet<Σᐩ> = mutableSetOf(from)
val nextReachable: MutableSet<Σᐩ> = mutableSetOf(from)
val NDEPS =
ConcurrentHashMap<Σᐩ, ConcurrentSkipListSet<Σᐩ>>(size).apply {
ConcurrentHashMap<Σᐩ, MutableSet<Σᐩ>>(size).apply {
this@jvmReachSym.asSequence().asStream().parallel()
.forEach { (l, r) -> getOrPut(l) { ConcurrentSkipListSet() }.addAll(r) }
.forEach { (l, r) -> getOrPut(l) { ConcurrentHashMap.newKeySet() }.addAll(r) }
}
// [email protected]().asStream().parallel()
// .flatMap { (l, r) -> r.stream().map { l to it } }
// // List of second elements grouped by first element
// .collect(Collectors.groupingByConcurrent({ it.first }, Collectors.mapping({ it.second }, Collectors.toSet())))

do {
while (nextReachable.isNotEmpty()) {
val t = nextReachable.first()
nextReachable.remove(t)
allReachable += t
nextReachable += (NDEPS[t]?: emptyList())
.filter { it !in allReachable && it !in nextReachable }
} while (nextReachable.isNotEmpty())
}

// println("TERM: ${allReachable.any { it in terminals }} ${allReachable.size}")

@@ -406,22 +407,22 @@ private fun CFG.jvmGenSym(
val allGenerating: MutableSet<Σᐩ> = mutableSetOf()
val nextGenerating: MutableSet<Σᐩ> = from.toMutableSet()
val TDEPS =
ConcurrentHashMap<Σᐩ, ConcurrentSkipListSet<Σᐩ>>(size).apply {
ConcurrentHashMap<Σᐩ, MutableSet<Σᐩ>>(size).apply {
this@jvmGenSym.asSequence().asStream().parallel()
.forEach { (l, r) -> r.forEach { getOrPut(it) { ConcurrentSkipListSet() }.add(l) } }
.forEach { (l, r) -> r.forEach { getOrPut(it) { ConcurrentHashMap.newKeySet() }.add(l) } }
}
// [email protected]().asStream().parallel()
// .flatMap { (l, r) -> r.asSequence().asStream().map { it to l } }
// // List of second elements grouped by first element
// .collect(Collectors.groupingByConcurrent({ it.first }, Collectors.mapping({ it.second }, Collectors.toList())))

do {
while (nextGenerating.isNotEmpty()) {
val t = nextGenerating.first()
nextGenerating.remove(t)
allGenerating += t
nextGenerating += (TDEPS[t] ?: emptyList())
.filter { it !in allGenerating && it !in nextGenerating }
} while (nextGenerating.isNotEmpty())
}

// println("START: ${START_SYMBOL in allGenerating} ${allGenerating.size}")

