Skip to content

Commit

Permalink
Check pattern consistency before starting rewriting in "execute"
Browse files Browse the repository at this point in the history
  • Loading branch information
geo2a committed Aug 7, 2024
1 parent 8eeb9ce commit 2025987
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 33 deletions.
62 changes: 42 additions & 20 deletions booster/library/Booster/JsonRpc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,26 +150,48 @@ respond stateVar request =

solver <- maybe (SMT.noSolver) (SMT.initSolver def) mSMTOptions

logger <- getLogger
prettyModifiers <- getPrettyModifiers
let rewriteConfig =
RewriteConfig
{ definition = def
, llvmApi = mLlvmLibrary
, smtSolver = solver
, varsToAvoid = substVars
, doTracing
, logger
, prettyModifiers
, mbMaxDepth = mbDepth
, mbSimplify = rewriteOpts.interimSimplification
, cutLabels = cutPoints
, terminalLabels = terminals
}
result <-
performRewrite rewriteConfig substPat
SMT.finaliseSolver solver
pure $ execResponse req result substitution unsupported
-- check input pattern's consistency before starting rewriting
evaluatedInitialPattern <-
ApplyEquations.evaluatePattern
def
mLlvmLibrary
solver
mempty
substPat

case evaluatedInitialPattern of
(Left ApplyEquations.SideConditionFalse{}, _) -> do
-- input pattern's constraints are Bottom, return Vacuous
pure $
execResponse
req
(0, mempty, RewriteTrivial substPat)
substitution
unsupported
(Left other, _) ->
pure . Left . RpcError.backendError $ RpcError.Aborted (Text.pack . constructorName $ other)
(Right newPattern, simplifierCache) -> do
logger <- getLogger
prettyModifiers <- getPrettyModifiers
let rewriteConfig =
RewriteConfig
{ definition = def
, llvmApi = mLlvmLibrary
, smtSolver = solver
, varsToAvoid = substVars
, doTracing
, logger
, prettyModifiers
, mbMaxDepth = mbDepth
, mbSimplify = rewriteOpts.interimSimplification
, cutLabels = cutPoints
, terminalLabels = terminals
}

result <-
performRewrite rewriteConfig simplifierCache newPattern
SMT.finaliseSolver solver
pure $ execResponse req result substitution unsupported
RpcTypes.AddModule RpcTypes.AddModuleRequest{_module, nameAsId = nameAsId'} -> Booster.Log.withContext CtxAddModule $ runExceptT $ do
-- block other request executions while modifying the server state
state <- liftIO $ takeMVar stateVar
Expand Down
19 changes: 10 additions & 9 deletions booster/library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,10 @@ performRewrite ::
forall io.
LoggerMIO io =>
RewriteConfig ->
SimplifierCache ->
Pattern ->
io (Natural, Seq (RewriteTrace ()), RewriteResult Pattern)
performRewrite rewriteConfig pat = do
performRewrite rewriteConfig initialCache pat = do
(rr, RewriteStepsState{counter, traces}) <-
flip runStateT rewriteStart $ doSteps False pat
pure (counter, traces, rr)
Expand All @@ -710,6 +711,14 @@ performRewrite rewriteConfig pat = do
, terminalLabels
} = rewriteConfig

rewriteStart :: RewriteStepsState
rewriteStart =
RewriteStepsState
{ counter = 0
, traces = mempty
, simplifierCache = initialCache
}

logDepth = withContext CtxDepth . logMessage

depthReached n = maybe False (n >=) mbMaxDepth
Expand Down Expand Up @@ -907,11 +916,3 @@ data RewriteStepsState = RewriteStepsState
, traces :: !(Seq (RewriteTrace ()))
, simplifierCache :: SimplifierCache
}

rewriteStart :: RewriteStepsState
rewriteStart =
RewriteStepsState
{ counter = 0
, traces = mempty
, simplifierCache = mempty
}
8 changes: 4 additions & 4 deletions booster/unit-tests/Test/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ runRewrite t = do
conf <- testConf
(counter, _, res) <-
runNoLoggingT $
performRewrite conf $
performRewrite conf mempty $
Pattern_ t
pure (counter, fmap (.term) res)

Expand Down Expand Up @@ -438,7 +438,7 @@ supportsDepthControl =
rewritesToDepth (MaxDepth depth) (Steps n) t t' f = do
conf <- testConf
(counter, _, res) <-
runNoLoggingT $ performRewrite conf{mbMaxDepth = Just depth} $ Pattern_ t
runNoLoggingT $ performRewrite conf{mbMaxDepth = Just depth} mempty $ Pattern_ t
(counter, fmap (.term) res) @?= (n, f t')

supportsCutPoints :: TestTree
Expand Down Expand Up @@ -492,7 +492,7 @@ supportsCutPoints =
conf <- testConf
(counter, _, res) <-
runNoLoggingT $
performRewrite conf{cutLabels = [lbl]} $
performRewrite conf{cutLabels = [lbl]} mempty $
Pattern_ t
(counter, fmap (.term) res) @?= (n, f t')

Expand Down Expand Up @@ -524,5 +524,5 @@ supportsTerminalRules =
rewritesToTerminal lbl (Steps n) t t' f = do
conf <- testConf
(counter, _, res) <-
runNoLoggingT $ performRewrite conf{terminalLabels = [lbl]} $ Pattern_ t
runNoLoggingT $ performRewrite conf{terminalLabels = [lbl]} mempty $ Pattern_ t
(counter, fmap (.term) res) @?= (n, f t')

0 comments on commit 2025987

Please sign in to comment.