Skip to content

Commit

Permalink
add memoization and test
Browse files Browse the repository at this point in the history
  • Loading branch information
winitzki committed Jul 11, 2024
1 parent b19a4d6 commit 471c411
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 21 deletions.
50 changes: 30 additions & 20 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,30 @@ val supportedScalaVersions = Seq(scala2V, scala3V)

def munitFramework = new TestFramework("munit.Framework")

val munitTest = "org.scalameta" %% "munit" % "0.7.29" % Test
val munitTest = "org.scalameta" %% "munit" % "1.0.0" % Test
val assertVerboseTest = "com.eed3si9n.expecty" %% "expecty" % "0.16.0" % Test

val fastparse = "com.lihaoyi" %% "fastparse" % "3.0.2"
val fastparse = "com.lihaoyi" %% "fastparse" % "3.1.1"
val antlr4 = "org.antlr" % "antlr4-runtime" % "4.13.1"
val anltr4_formatter = "com.khubla.antlr4formatter" % "antlr4-formatter-standalone" % "1.2.1" % Provided

val os_lib = "com.lihaoyi" %% "os-lib" % "0.9.2"
val httpRequest = "com.lihaoyi" %% "requests" % "0.8.0"
val enumeratum = "com.beachape" %% "enumeratum" % "1.7.3"
val izumi_reflect = "dev.zio" %% "izumi-reflect" % "2.3.8"
val zio_schema = "dev.zio" %% "zio-schema" % "1.2.1"
val zio_schema_deriving = "dev.zio" %% "zio-schema-derivation" % "1.1.1"
val kindProjector = "org.typelevel" % "kind-projector" % "0.13.3" cross CrossVersion.full
val jnr_posix = "com.github.jnr" % "jnr-posix" % "3.1.19"
val cbor1 = "co.nstant.in" % "cbor" % "0.9"
val cbor2 = "com.upokecenter" % "cbor" % "4.5.3"
val reflections = "org.reflections" % "reflections" % "0.10.2"
val mainargs = "com.lihaoyi" %% "mainargs" % "0.7.0"

val os_lib = "com.lihaoyi" %% "os-lib" % "0.10.2"
val httpRequest = "com.lihaoyi" %% "requests" % "0.8.0"
val enumeratum = "com.beachape" %% "enumeratum" % "1.7.3"
val izumi_reflect = "dev.zio" %% "izumi-reflect" % "2.3.8"
val zio_schema = "dev.zio" %% "zio-schema" % "1.2.1"
val zio_schema_deriving = "dev.zio" %% "zio-schema-derivation" % "1.2.2"
val kindProjector = "org.typelevel" % "kind-projector" % "0.13.3" cross CrossVersion.full
val jnr_posix = "com.github.jnr" % "jnr-posix" % "3.1.19"
val cbor1 = "co.nstant.in" % "cbor" % "0.9"
val cbor2 = "com.upokecenter" % "cbor" % "4.5.3"
val reflections = "org.reflections" % "reflections" % "0.10.2"
val mainargs = "com.lihaoyi" %% "mainargs" % "0.7.0"
val sourcecode = "com.lihaoyi" %% "sourcecode" % "0.4.2"
// Not used now:
val flatlaf = "com.formdev" % "flatlaf" % "3.2.2"
val cbor3 = "io.bullet" %% "borer-core" % "1.8.0"
val scalahashing = "com.desmondyeung.hashing" %% "scala-hashing" % "0.1.0"
val flatlaf = "com.formdev" % "flatlaf" % "3.4.1"
val cbor3 = "io.bullet" %% "borer-core" % "1.8.0"
val scalahashing = "com.desmondyeung.hashing" %% "scala-hashing" % "0.1.0"

val kindProjectorPlugin = compilerPlugin(kindProjector)

Expand Down Expand Up @@ -113,7 +113,17 @@ lazy val nano_dhall = (project in file("nano-dhall"))
httpRequest,
os_lib % Test,
),
).dependsOn(scall_testutils % "test->compile", scall_typeclasses)
).dependsOn(scall_testutils % "test->compile", scall_typeclasses, fastparse_memoize)

lazy val fastparse_memoize = (project in file("fastparse-memoize"))
.settings(publishingOptions)
.settings(
scalaVersion := scalaV,
crossScalaVersions := supportedScalaVersions,
testFrameworks += munitFramework,
Test / javaOptions ++= jdkModuleOptions,
libraryDependencies ++= Seq(fastparse, sourcecode, munitTest, assertVerboseTest),
).dependsOn(scall_testutils % "test->compile")

lazy val scall_core = (project in file("scall-core"))
.settings(publishingOptions)
Expand Down Expand Up @@ -161,7 +171,7 @@ lazy val scall_core = (project in file("scall-core"))
httpRequest,
os_lib % Test,
),
).dependsOn(scall_testutils % "test->compile", scall_typeclasses)
).dependsOn(scall_testutils % "test->compile", scall_typeclasses, fastparse_memoize)

lazy val scall_testutils = (project in file("scall-testutils"))
.settings(publishingOptions)
Expand Down
128 changes: 128 additions & 0 deletions fastparse-memoize/src/main/scala/io/chymyst/fastparse/Memoize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package io.chymyst.fastparse

import fastparse.{P, Parsed, ParserInput, ParserInputSource, ParsingRun}
import fastparse.internal.{Instrument, Msgs}

import scala.collection.mutable

final case class PRunData( // Copy all the mutable data from ParsingRun.
terminalMsgs: Msgs,
aggregateMsgs: Msgs,
shortMsg: Msgs,
lastFailureMsg: Msgs,
failureStack: List[(String, Int)],
isSuccess: Boolean,
logDepth: Int,
index: Int,
cut: Boolean,
successValue: Any,
verboseFailures: Boolean,
noDropBuffer: Boolean,
misc: collection.mutable.Map[Any, Any],
) {
override def toString: String = {
s"ParsingRun(index=$index, isSuccess = $isSuccess, successValue = $successValue)"
}


}

object PRunData { // Copy all the mutable data from a parsing run into a PRunData value.
def ofParsingRun[T](pr: ParsingRun[T]): PRunData = PRunData(
pr.terminalMsgs,
pr.aggregateMsgs,
pr.shortMsg,
pr.lastFailureMsg,
pr.failureStack,
pr.isSuccess,
pr.logDepth,
pr.index,
pr.cut,
pr.successValue,
pr.verboseFailures,
pr.noDropBuffer,
mutable.Map.from(pr.misc),
)
}

object Memoize {
def assignToParsingRun[T](data: PRunData, pr: ParsingRun[T]): ParsingRun[T] = { // Assign the mutable data to a given ParsingRun value.
pr.terminalMsgs = data.terminalMsgs
pr.aggregateMsgs = data.aggregateMsgs
pr.shortMsg = data.shortMsg
pr.lastFailureMsg = data.lastFailureMsg
pr.failureStack = data.failureStack
pr.isSuccess = data.isSuccess
pr.logDepth = data.logDepth
pr.index = data.index
pr.cut = data.cut
pr.successValue = data.successValue
pr.verboseFailures = data.verboseFailures
pr.noDropBuffer = data.noDropBuffer
data.misc.foreach { case (k, v) => pr.misc.put(k, v) }
pr
}
private def cacheGrammar[R](cache: mutable.Map[Int, PRunData], parser: => P[_])(implicit p: P[_]): P[R] = {
// The `parser` has not yet been run! And it is mutable. Do not run it twice!
val cachedData: PRunData = cache.getOrElseUpdate(p.index, PRunData.ofParsingRun(parser))
// After the `parser` has been run on `p`, the value of `p` changes and becomes equal to the result of running the parser.
// If the result was cached, we need to assign it to the current value of `p`. This will imitate the side effect of running the parser again.
assignToParsingRun(cachedData, p).asInstanceOf[P[R]]
}

private val cache = new mutable.HashMap[(sourcecode.File, sourcecode.Line), mutable.Map[Int, PRunData]]

private def getOrCreateCache(file : sourcecode.File, line: sourcecode.Line): mutable.Map[Int, PRunData] = {
cache.getOrElseUpdate((file, line), new mutable.HashMap[Int, PRunData])
}

implicit class MemoizeParser[A](parser: => P[A]) {
def memoize(implicit file : sourcecode.File, line: sourcecode.Line, p: P[_]): P[A] = {
val cache: mutable.Map[Int, PRunData] = getOrCreateCache(file, line)
cacheGrammar(cache, parser)
}
}

def clearAll(): Unit = cache.values.foreach(_.clear())

def statistics: String = cache.map {case ((file, line), c) => s"$file#$line: ${c.size} entries"}.mkString("\n")

def parse[T](input: ParserInputSource,
parser: P[_] => P[T],
verboseFailures: Boolean = false,
startIndex: Int = 0,
instrument: Instrument = null): Parsed[T] = {
clearAll()
fastparse.parse(input, parser, verboseFailures, startIndex, instrument)
}

def parseInputRaw[T](input: ParserInput,
parser: P[_] => P[T],
verboseFailures: Boolean = false,
startIndex: Int = 0,
traceIndex: Int = -1,
instrument: Instrument = null,
enableLogging: Boolean = true): ParsingRun[T] = {
clearAll()
fastparse.parseInputRaw(input, parser, verboseFailures, startIndex, traceIndex, instrument, enableLogging)
}

}

/* See discussion in https://github.com/com-lihaoyi/fastparse/discussions/301
//... other rules of the grammar as above. The changes are only for `x_times` and `x_other`:
def x_times[$: P]: P[R] = P(x_other_cached ~ ("*" ~ x_other_cached).rep).map { case (i, is) => i * is.product }
def x_other[$: P]: P[R] = P(number | ("(" ~ expr ~ ")"))
def x_other_cached[$](implicit p: P[$]): P[R] = cachedParser(cache_other, x_other)
// Need a separate cache for every memoized parser.
val cache_other = mutable.Map[Int, PRunData]()
// Need to do cache_other.clear() between different calls to parse an expression.
val n = 500
cache_other.clear()
assert(parse("(" * (n - 1) + "1" + ")" * (n - 1), program(_)).get.value == 1)
cache_other.clear()
assert(parse("1+" + "(1+" * (n - 1) + "1" + ")" * (n - 1), program(_)).get.value == n + 1)
*/
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.chymyst.fastparse.unit

import com.eed3si9n.expecty.Expecty.expect
import fastparse.NoWhitespace._
import fastparse._
import io.chymyst.fastparse.Memoize
import io.chymyst.test.TestTimings
import munit.FunSuite

class MemoizeTest extends FunSuite with TestTimings {

test("slow grammar becomes faster after memoization") {
// Integer calculator program: 1+2*3-(4-5)*6 and so on. No spaces, for simplicity.
def program1[$: P]: P[Int] = P(expr1 ~ End)
def expr1[$: P]: P[Int] = P(minus1 | plus1)
def minus1[$: P] = P(times1 ~ "-" ~ expr1)
.map { case (x, y) => x - y }
def plus1[$: P] = P(times1 ~ ("+" ~ expr1).rep)
.map { case (i, is) => i + is.sum }
def times1[$: P] = P(other1 ~ ("*" ~ other1).rep)
.map { case (i, is) => i * is.product }
def other1[$: P]: P[Int] = P(number | ("(" ~ expr1 ~ ")"))
def number[$: P] = P(CharIn("0-9").rep(1))
.!.map(_.toInt)
// Verify that this works as expected.
assert(fastparse.parse("123*(1+1)", program1(_)).get.value == 246)
assert(fastparse.parse("123*1+1", program1(_)).get.value == 124)
assert(fastparse.parse("123*1-1", program1(_)).get.value == 122)
assert(fastparse.parse("123*(1-1)", program1(_)).get.value == 0)

// Parse an expression of the form `(((((...(1)...)))))`.
val n = 23
val (result1, elapsed1) = elapsedNanos(fastparse.parse("(" * (n - 1) + "1" + ")" * (n - 1), program1(_)))
assert(result1.get.value == 1)

// The same parsing after memoization.
import io.chymyst.fastparse.Memoize.MemoizeParser
def program2[$: P]: P[Int] = P(expr2 ~ End)
def expr2[$: P]: P[Int] = P(minus2 | plus2)
def minus2[$: P] = P(times2 ~ "-" ~ expr2)
.map { case (x, y) => x - y }
def plus2[$: P] = P(times2 ~ ("+" ~ expr2).rep)
.map { case (i, is) => i + is.sum }
def times2[$: P] = P(other2 ~ ("*" ~ other2).rep)
.map { case (i, is) => i * is.product }
def other2[$: P]: P[Int] = P(number | ("(" ~ expr2 ~ ")")).memoize

val (result2, elapsed2) = elapsedNanos(Memoize.parse("(" * (n - 1) + "1" + ")" * (n - 1), program2(_)))
assert(result2.get.value == 1)
// Verify that the memoized parser works as expected.
assert(Memoize.parse("123*(1+1)", program2(_)).get.value == 246)
assert(Memoize.parse("123*1+1", program2(_)).get.value == 124)
assert(Memoize.parse("123*1-1", program2(_)).get.value == 122)
assert(Memoize.parse("123*(1-1)", program2(_)).get.value == 0)

println(s"before memoization: ${elapsed1/1e9}, after memoization: ${elapsed2/1e9}, statistics: ${Memoize.statistics}")
// Memoization should speed up at least 100 times in this example.
expect(elapsed1 > elapsed2 * 100)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import java.io.{ByteArrayOutputStream, FileInputStream}
import java.nio.file.{Files, Paths}

class PerfTest extends FunSuite with ResourceFiles with TestTimings {

import sourcecode._
test("create yaml from realistic example 1") {
val file = resourceAsFile("yaml-perftest/create_yaml.dhall").get
val options = YamlOptions()
Expand Down

0 comments on commit 471c411

Please sign in to comment.