Skip to content

Commit

Permalink
evaluate PCFG NLL reranker
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jul 7, 2024
1 parent 4a754aa commit b9134ff
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) =

// http://www.cs.umd.edu/~gasarch/BLOGPAPERS/cfg.pdf#page=2
// https://browse.arxiv.org/pdf/2209.06809.pdf#page=5
private fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
var clock = TimeSource.Monotonic.markNow()
val nts = mutableSetOf("START")
fun Σᐩ.isSyntheticNT() =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ fun CFG.transformIntoCNF(): CFG =
addEpsilonProduction()
.refactorEpsilonProds()
.elimVarUnitProds()
.binarizeRHSByFrequency()
// .binarizeRHSByFrequency()
.binarizeRHSByRightmost()
.terminalsToUnitProds()
.removeUselessSymbols()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*
import com.ionspin.kotlin.bignum.integer.*
import kotlin.jvm.JvmName
import kotlin.math.ln
import kotlin.random.*
import kotlin.time.measureTimedValue


// Indexes a set of PTrees by their roots
typealias PForest = Map<String, PTree> // ℙ₃
// Algebraic data type / polynomial functor for parse forests (ℙ₂)
Expand Down Expand Up @@ -108,6 +108,20 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
return if (left.isEmpty()) right else if (right.isEmpty()) left else "$left $right"
}

private fun newDecoderWithProb(i: BigInteger, pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>): Pair<String, Double> {
if (branches.isEmpty()) return epsStr to 0.0
val t = ranges.indexOfFirst { it.first <= i && i <= it.second }
val (l, r) = branches[t]
val q = i - ranges[t].first
val (iLeft, iRight) = q.divrem(r.totalTrees)
val (lroot, rroot) = l.rootName to r.rootName
val (left, leftScore) = l.newDecoderWithProb(iLeft, pcfgMap, pcfgNorm)
val (right, rightScore) = r.newDecoderWithProb(iRight, pcfgMap, pcfgNorm)
val myScore = ln((pcfgMap[root to lroot to rroot]?.toDouble() ?: 0.00001) / (pcfgNorm[root]?.toDouble() ?: 1.0)) +
leftScore + rightScore
return (if (left.isEmpty()) right else if (right.isEmpty()) left else "$left $right") to myScore
}

// Average time: 436.96ms, total time 43696.959ms (testRandomCFG)
private fun decodeString(i: BigInteger): Pair<String, BigInteger> {
if (branches.isEmpty()) return epsStr to i
Expand Down Expand Up @@ -154,6 +168,20 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
while (i < totalTrees) { yield(newDecoder(i)); i++}
}

// Returns trees WoR from the CFG and scores the strings with a PCFG-based log-likelihood
fun sampleStrWithoutReplacementAndScore(
stride: Int = 1, offset: Int = 0,
pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>
): Sequence<Π2<String, Double>> =
if (6 < totalTrees.bitLength())
bigLFSRSequence(totalTrees).mapIndexedNotNull { index, i ->
if (index % stride == offset) newDecoderWithProb(i, pcfgMap, pcfgNorm) else null
}
else sequence {
var i = BigInteger.ZERO
while (i < totalTrees) { yield(newDecoderWithProb(i, pcfgMap, pcfgNorm)); i++}
}

fun sampleStrWithPCFG5(pcfgTable: Map<Int, Int>): Sequence<String> =
sequence { while (true) yield(samplePCFG5(pcfgTable)) }

Expand Down Expand Up @@ -186,6 +214,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}

/** See [intersectLevFSAP], extracts original NT name from a synthetic ∩-NT. */
fun Σᐩ.name() = if ('~' in this) split('~')[1] else this
val triples : List<Π2A<Int>> by lazy { branches.map { it.first.ntIdx to it.second.ntIdx } }
val rootName by lazy { root.name() }
Expand All @@ -199,7 +228,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
// .also { if(Random.nextInt(10000) == 3) if (it == 1) println("$hash Miss"); else println("$hash Hit") }
+ 1 }
val cdf = probs.runningReduce { acc, i -> acc + i }
val rnd = Random.nextInt(probs.sum())
val rnd = Random.nextInt(cdf.last())
val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it }
val (l, r) = branches[childIdx]
val (lr, rr) = l.ntIdx to r.ntIdx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ai.hypergraph.kaliningraph.graphs.LGVertex
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.tensor.FreeMatrix
import ai.hypergraph.kaliningraph.types.*
import kotlin.math.ln

typealias TreeMatrix = FreeMatrix<Forest>
typealias Forest = Set<Tree>
Expand Down Expand Up @@ -39,10 +40,10 @@ class Tree constructor(
children[0].quintuples(root, children[0].root + "*", children[1].root) +
children[1].quintuples(root, children[0].root, children[1].root + "*")

fun logProb(pcfgMap: Map<Π3A<Σᐩ>, Int>): Double =
fun logProb(pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>): Double =
if (children.isEmpty()) 0.0
else (pcfgMap[root to children[0].root to children[1].root]?.toDouble() ?: 0.0) +
children.sumOf { it.logProb(pcfgMap) }
else ln((pcfgMap[root to children[0].root to children[1].root]?.toDouble() ?: 0.00001) / (pcfgNorm[root]?.toDouble() ?: 1.0)) +
children.sumOf { it.logProb(pcfgMap, pcfgNorm) }

fun toGraph(j: Σᐩ = "0"): LabeledGraph =
LabeledGraph { LGVertex(root, "$root.$j").let { it - it } } +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ fun PTree.sampleDirectlyWOR(
.asStream()
}

fun PTree.sampleDirectlyWORAndScore(
cores: Int = NUM_CORES,
stoppingCriterion: () -> Boolean = { true },
pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>
): Stream<Π2<String, Double>> =
(0..<cores).toList().parallelStream().flatMap { i ->
sampleStrWithoutReplacementAndScore(cores, i, pcfgMap, pcfgNorm)
.takeWhile { stoppingCriterion() }
.distinctBy { it.first }
.asStream()
}

fun CFG.parallelEnumListWR(
prompt: List<String>,
cores: Int = NUM_CORES,
Expand Down

0 comments on commit b9134ff

Please sign in to comment.