Skip to content

Commit

Permalink
sketch O(n) postprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 29, 2024
1 parent 71ee648 commit bf6bf91
Showing 1 changed file with 84 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import ai.hypergraph.kaliningraph.repair.minimizeFix
import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.types.times
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger
import java.util.stream.*
import kotlin.streams.*
import kotlin.time.Duration.Companion.minutes
import kotlin.time.TimeSource
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.LongAdder

fun CFG.parallelEnumSeqMinimalWOR(
prompt: List<String>,
Expand Down Expand Up @@ -147,13 +148,11 @@ val MINFREEMEM = 1000000000L
val MAX_NTS = 4_000_000 // Gives each nonterminal about ~35kb of memory on Xmx=150GB
val MAX_PRODS = 200_000_000

val maxNTsSeen = AtomicInteger(0)

private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
// if (fsa.Q.size < 650) throw Exception("FSA size was out of bounds")
var clock = TimeSource.Monotonic.markNow()

val nts = ConcurrentSkipListSet(setOf("START"))
val nts = ConcurrentHashMap.newKeySet<Σᐩ>().apply { add("START") }

val initFinal =
(fsa.init * fsa.final).map { (q, r) -> "START" to listOf("[$q~START~$r]") }
Expand Down Expand Up @@ -188,7 +187,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {

val states = fsa.stateLst
val allsym = ntLst
val counter = AtomicInteger(0)
val counter = LongAdder()
val lpClock = TimeSource.Monotonic.markNow()
val binaryProds =
prods.parallelStream().flatMap {
Expand All @@ -203,7 +202,8 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
// .filter { it.obeysLevenshteinParikhBounds(A to B to C, fsa, parikhMap) }
.filter { it.checkCompatibility(trip, ct2) }
.map { (a, b, c) ->
if (MAX_PRODS < counter.incrementAndGet()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
counter.increment()
if (MAX_PRODS < counter.sum()) throw Exception("∩-grammar has too many productions! (>$MAX_PRODS)")
val (p, q, r) = states[a] to states[b] to states[c]
"[$p~${allsym[A]}~$r]".also { nts.add(it) } to listOf("[$p~${allsym[B]}~$q]", "[$q~${allsym[C]}~$r]")
}
Expand All @@ -212,6 +212,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
val elimCounter = (validTriples.size * prods.size) - binaryProds.size
println("Levenshtein-Parikh constraints eliminated $elimCounter productions in ${lpClock.elapsedNow()}")

// !isSyntheticNT() === is START or a terminal
fun Σᐩ.isSyntheticNT() =
first() == '[' && length > 1 // && last() == ']' && count { it == '~' } == 2

Expand All @@ -224,6 +225,7 @@ private fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap): CFG {
.collect(Collectors.toSet())
.also { println("Eliminated ${totalProds - it.size} extra productions before normalization") }
.jvmPostProcess(clock)
// .jdvpNew()
}

// Parallel streaming doesn't seem to be that much faster (yet)?
Expand Down Expand Up @@ -253,13 +255,87 @@ tailrec fun CFG.jvmElimVarUnitProds(
.jvmElimVarUnitProds(toVisit.drop(1).toSet(), vars)
}

// TODO: Incomplete / untested
// Based on: https://zerobone.net/blog/cs/non-productive-cfg-rules/
fun CFG.jdvpNew(): CFG {
println("Total productions: $size")
val timer = TimeSource.Monotonic.markNow()
val counter = ConcurrentHashMap<Set<Σᐩ>, LongAdder>()

val UDEPS = ConcurrentHashMap<Σᐩ, ConcurrentLinkedQueue<Set<Σᐩ>>>(size)
val NDEPS = ConcurrentHashMap<Set<Σᐩ>, ConcurrentLinkedQueue<Production>>(size).apply {
put(emptySet(), ConcurrentLinkedQueue())
this@jdvpNew.asSequence().asStream().parallel().forEach {
val v = it.second.toSet()
getOrPut(if(it.second.size == 1) emptySet() else v) { ConcurrentLinkedQueue() }.add(it)
v.forEach { s -> UDEPS.getOrPut(s) { ConcurrentLinkedQueue() }.add(v) }
if (v.size == 2) counter.putIfAbsent(v, LongAdder().apply { add(2L) })
}
}

println("Constructed dependency graph in ${timer.elapsedNow()}")

val visited = mutableSetOf<Production>()
val nextReachable: LinkedHashSet<Set<Σᐩ>> = LinkedHashSet<Set<Σᐩ>>().apply { add(emptySet()) }

val productive = mutableSetOf<Production>()
do {
println("Next reachable: ${nextReachable.size}, Visited: ${visited.size}, Productive: ${productive.size}")
val q = nextReachable.removeFirst()
if (q.size == 2) { // Conjunct
val dec = counter[q]!!.apply { decrement() }
if (dec.sum() == 0L) { // Seen both
NDEPS[q]?.forEach {
visited += it
UDEPS[it.LHS]?.forEach { st -> NDEPS[st]?.forEach { if (it !in visited) nextReachable += st } }
productive.add(it)
}
} else nextReachable += q // Always add back if sum not zero
} else {
NDEPS[q]?.forEach {
visited += it
UDEPS[it.LHS]?.forEach { st -> NDEPS[st]?.forEach { if (it !in visited) nextReachable += st } }
productive.add(it)
}
}
} while (nextReachable.isNotEmpty())

println("Eliminated ${size - productive.size} unproductive productions in ${timer.elapsedNow()}")
println("Resulting in ${productive.size} productions.")

val QDEPS =
ConcurrentHashMap<Σᐩ, ConcurrentLinkedQueue<Production>>(size).apply {
productive.asSequence().asStream().parallel().forEach {
getOrPut(it.LHS) { ConcurrentLinkedQueue() }.add(it)
}
}

val done = mutableSetOf(START_SYMBOL)
val nextProd: MutableList<Σᐩ> = mutableListOf(START_SYMBOL)
val productiveAndReachable = mutableSetOf<Production>()

do {
val q = nextProd.removeFirst().also { done += it }
QDEPS[q]?.forEach { it ->
productiveAndReachable.add(it)
it.RHS.forEach { if (it !in done) nextProd += it }
}
} while (nextProd.isNotEmpty())

println("Eliminated ${productive.size - productiveAndReachable.size} unreachable productions in ${timer.elapsedNow()}")
println("Resulting in ${productiveAndReachable.size} productions.")

return productiveAndReachable.freeze()
}

fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark): CFG {
val start = clock.elapsedNow()
val counter = AtomicInteger(0)
val counter = LongAdder()
val nts: Set<Σᐩ> = asSequence().asStream().parallel().map { it.first }.collect(Collectors.toSet())
val rw = asSequence().asStream().parallel()
.filter { prod ->
if (counter.incrementAndGet() % 10 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
counter.increment()
if (counter.toInt() % 10 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
// Only keep productions whose RHS symbols are not synthetic or are in the set of NTs
prod.RHS.all { !(it.first() == '[' && 1 < it.length) || it in nts }
}
Expand All @@ -286,7 +362,6 @@ fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark):
* A useful symbol is both generating and reachable.
*/

// TODO: https://zerobone.net/blog/cs/non-productive-cfg-rules/
fun CFG.jvmRemoveUselessSymbols(
generating: Set<Σᐩ> = jvmGenSym(),
reachable: Set<Σᐩ> = jvmReachSym()
Expand Down

0 comments on commit bf6bf91

Please sign in to comment.