Skip to content

Commit

Permalink
Remove more lens
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimirlogachev committed Sep 27, 2024
1 parent 768b120 commit 6690b38
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package com.wavesplatform.lang.v1.estimator.v2


import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.compiler.Terms.FUNC
import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM

private[v2] case class EstimatorContext(
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
)

private[v2] object EstimatorContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object ScriptEstimatorV2 extends ScriptEstimator {
local {
for {
_ <- checkFuncCtx(func)
_ <- update(ec => ec.copy(userFuncs = ec.userFuncs + (FunctionHeader.User(func.name) -> func)) )
_ <- update(ec => ec.copy(userFuncs = ec.userFuncs + (FunctionHeader.User(func.name) -> func)))
r <- evalExpr(inner)
} yield r + 5
}
Expand Down Expand Up @@ -107,18 +107,22 @@ object ScriptEstimatorV2 extends ScriptEstimator {
_ <- update(ec => ec.copy(letDefs = ec.letDefs ++ ctx.overlappedRefs))
overlapped = func.args.flatMap(arg => ctx.letDefs.get(arg).map((arg, _))).toMap
ctxArgs = func.args.map((_, (false, const(1)))).toMap
_ <- update(ec => ec.copy(
letDefs = ec.letDefs ++ ctxArgs,
overlappedRefs = ec.overlappedRefs ++ overlapped
))
_ <- update(ec =>
ec.copy(
letDefs = ec.letDefs ++ ctxArgs,
overlappedRefs = ec.overlappedRefs ++ overlapped
)
)

bodyComplexity <- evalExpr(func.body).map(_ + func.args.size * 5)
evaluatedCtx <- get[Id, EstimatorContext, EstimationError]
overlappedChanges = overlapped.map { case ref @ (name, _) => evaluatedCtx.letDefs.get(name).map((name, _)).getOrElse(ref) }
_ <- update(ec => ec.copy(
letDefs = ec.letDefs -- ctxArgs.keys ++ overlapped,
overlappedRefs = ec.overlappedRefs ++ overlappedChanges
))
_ <- update(ec =>
ec.copy(
letDefs = ec.letDefs -- ctxArgs.keys ++ overlapped,
overlappedRefs = ec.overlappedRefs ++ overlappedChanges
)
)
} yield bodyComplexity + argsComplexity

private def evalFuncArgs(args: List[EXPR]): EvalM[Long] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM
import monix.eval.Coeval
import shapeless.{Lens, lens}

private[v3] case class EstimatorContext(
funcs: Map[FunctionHeader, (Coeval[Long], Set[String])],
Expand All @@ -18,9 +17,4 @@ private[v3] case class EstimatorContext(

private[v3] object EstimatorContext {
type EvalM[A] = TaskM[EstimatorContext, EstimationError, A]

object Lenses {
val funcs: Lens[EstimatorContext, Map[FunctionHeader, (Coeval[Long], Set[String])]] = lens[EstimatorContext] >> Symbol("funcs")
val usedRefs: Lens[EstimatorContext, Set[String]] = lens[EstimatorContext] >> Symbol("usedRefs")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.FunctionHeader.User
import com.wavesplatform.lang.v1.compiler.Terms.*
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.Lenses.*
import com.wavesplatform.lang.v1.estimator.{EstimationError, ScriptEstimator}
import com.wavesplatform.lang.v1.task.imports.*
import monix.eval.Coeval
Expand Down Expand Up @@ -82,7 +81,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
letCosts <- usedRefs.toSeq.traverse { ref =>
local {
for {
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ec => ec.copy(funcs = startCtx.funcs))
cost <- ctx.globalLetEvals.getOrElse(ref, zero)
} yield cost
}
Expand All @@ -100,22 +99,18 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
}

private def beforeNextExprEval(let: LET, eval: EvalM[Long]): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(_ - let.name)
.copy(refsCosts = ctx.refsCosts + (let.name -> local(eval)))
)
update(ctx => ctx.copy(usedRefs = ctx.usedRefs - let.name, refsCosts = ctx.refsCosts + (let.name -> local(eval))))

private def afterNextExprEval(let: LET, startCtx: EstimatorContext): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(r => if (startCtx.usedRefs.contains(let.name)) r + let.name else r - let.name)
.copy(refsCosts =
ctx.copy(
usedRefs = if (startCtx.usedRefs.contains(let.name)) ctx.usedRefs + let.name else ctx.usedRefs - let.name,
refsCosts =
if (startCtx.refsCosts.contains(let.name))
ctx.refsCosts + (let.name -> startCtx.refsCosts(let.name))
else
ctx.refsCosts - let.name
)
)
)

private def evalFuncBlock(func: FUNC, nextExpr: EXPR, activeFuncArgs: Set[String], globalDeclarationsMode: Boolean): EvalM[Long] =
Expand All @@ -142,14 +137,12 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
_ <- set[Id, EstimatorContext, EstimationError](ctx.copy(globalFunctionsCosts = ctx.globalFunctionsCosts + (name -> totalCost)))
} yield ()

private def handleUsedRefs(name: String, cost: Long, ctx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, _) =>
(
funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
ctx.usedRefs
)
}
private def handleUsedRefs(name: String, cost: Long, startCtx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(ec =>
ec.copy(
funcs = ec.funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
usedRefs = startCtx.usedRefs
)
)

private def evalIF(cond: EXPR, ifTrue: EXPR, ifFalse: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
Expand All @@ -165,7 +158,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
if (activeFuncArgs.contains(key) && letFixes)
const(overheadCost)
else
update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
update(ec => ec.copy(usedRefs = ec.usedRefs + key)).map(_ => overheadCost)

private def evalGetter(expr: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
evalExpr(expr, activeFuncArgs).flatMap(sum(_, overheadCost))
Expand All @@ -187,18 +180,15 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
} yield result

private def setFuncToCtx(header: FunctionHeader, bodyCost: Coeval[Long], bodyUsedRefs: Set[EstimationError]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
(
funcs + (header -> (bodyCost, Set())),
usedRefs ++ bodyUsedRefs
)
}
update(ec =>
ec.copy(
funcs = ec.funcs + (header -> (bodyCost, Set())),
usedRefs = ec.usedRefs ++ bodyUsedRefs
)
)

private def getFuncCost(header: FunctionHeader, ctx: EstimatorContext): EvalM[(Coeval[Long], Set[EstimationError])] =
funcs
.get(ctx)
ctx.funcs
.get(header)
.map(const)
.getOrElse(
Expand All @@ -217,9 +207,9 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
): EvalM[Long] =
for {
startCtx <- get[Id, EstimatorContext, EstimationError]
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(funcs.set(_)(ctxFuncs)))
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(ec => ec.copy(funcs = ctxFuncs)))
cost <- evalExpr(expr, activeFuncArgs)
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ec => ec.copy(funcs = startCtx.funcs))
} yield cost

private def withUsedRefs[A](eval: EvalM[A]): EvalM[(A, Set[String])] =
Expand Down

0 comments on commit 6690b38

Please sign in to comment.