Skip to content

Commit

Permalink
try out frequency-based binarization
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jul 2, 2024
1 parent 5dd0c4a commit 282b1fc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fun CFG.transformIntoCNF(): CFG =
addEpsilonProduction()
.refactorEpsilonProds()
.elimVarUnitProds()
.refactorRHS()
.binarizeRHSByFrequency()
.terminalsToUnitProds()
.removeUselessSymbols()

Expand Down Expand Up @@ -272,14 +272,39 @@ private tailrec fun CFG.elimVarUnitProds(
.elimVarUnitProds(toVisit.drop(1).toSet(), vars)
}

// Counts the number of times a pair of adjacent symbols appears in the RHS of a production
private fun CFG.countPairFreqs() =
flatMap { it.RHS.windowed(2, 1) }.groupingBy { it }.eachCount()

// TODO: try different heuristics from https://pages.cs.wisc.edu/~sding/paper/EMNLP2008.pdf
// Refactors long productions, e.g., (A -> BCD) -> (A -> BE, E -> CD)
private tailrec fun CFG.refactorRHS(): CFG {
private tailrec fun CFG.binarizeRHSByFrequency(): CFG {
val histogram: Map<List<Σᐩ>, Int> = countPairFreqs()
// Greedily chooses the production with the RHS pair that appears most frequently
val eligibleProds = filter { it.RHS.size > 2 }.maxByOrNull { longProd ->
longProd.RHS.windowed(2, 1).maxOfOrNull { histogram[it]!! } ?: 0
} ?: return this.elimVarUnitProds()
val mostFreqPair = eligibleProds.RHS.windowed(2, 1).mapIndexed { i, it -> i to it }.toSet()
.maxByOrNull { histogram[it.second]!! }!!
val freshName = mostFreqPair.second.joinToString(".")
val newProd = freshName to mostFreqPair.second
// Replace frequent pair of adjacent symbols in RHS with freshName
val allProdsWithPair = filter { mostFreqPair.second in it.RHS.windowed(2) }
val spProds = allProdsWithPair.map {
val idx = it.RHS.windowed(2).indexOfFirst { it == mostFreqPair.second }
it.LHS to (it.RHS.subList(0, idx) + freshName + it.RHS.subList(idx + 2, it.RHS.size))
}
val newGrammar = (this - allProdsWithPair) + spProds + newProd
return if (this == newGrammar) this.elimVarUnitProds() else newGrammar.binarizeRHSByFrequency()
}

private tailrec fun CFG.binarizeRHSByRightmost(): CFG {
val longProd = firstOrNull { it.RHS.size > 2 } ?: return this
val freshName = longProd.RHS.takeLast(2).joinToString(".")
val newProd = freshName to longProd.RHS.takeLast(2)
val shortProd = longProd.LHS to (longProd.RHS.dropLast(2) + freshName)
val newGrammar = this - longProd + shortProd + newProd
return if (this == newGrammar) this else newGrammar.refactorRHS()
return if (this == newGrammar) this else newGrammar.binarizeRHSByRightmost()
}

// Replaces terminals in non-unit productions, e.g., (A -> bC) -> (A -> BC, B -> b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object Grammars {
""".trimIndent().parseCFG().noNonterminalStubs

val shortS2PParikhMap by lazy { ParikhMap(seq2parsePythonCFG, 20) }
val seq2parsePythonCFG: CFG = """
val seq2parsePythonCFGStr = """
START -> Stmts_Or_Newlines
Stmts_Or_Newlines -> Stmt_Or_Newline | Stmt_Or_Newline Stmts_Or_Newlines
Stmt_Or_Newline -> Stmt | Newline
Expand Down Expand Up @@ -308,8 +308,10 @@ object Grammars {
Yield_Expr -> Yield_Keyword | Yield_Keyword Yield_Arg
Yield_Arg -> From_Keyword Test | Testlist_Endcomma
""".parseCFG().noNonterminalStubs
"""

val seq2parsePythonCFG: CFG = seq2parsePythonCFGStr.parseCFG().noNonterminalStubs
val seq2parsePythonVanillaCFG: CFG = seq2parsePythonCFGStr.parseCFG().noEpsilonOrNonterminalStubs

val checkedArithCFG = """
START -> S
Expand Down

0 comments on commit 282b1fc

Please sign in to comment.