From ef22c8886a2164ae5d416520f860e219e3a26bf1 Mon Sep 17 00:00:00 2001 From: Jost Berthold Date: Sat, 9 Mar 2024 15:27:06 +1100 Subject: [PATCH] Separate llvm pass from equation evaluator (#543) This PR changes the `iterateEquations` procedure to first evaluate all suitable (concrete, unevaluated, evaluatable) subterms using the LLVM library (if one is provided), before simplifying and evaluating the term using equations. Two separate caches are necessary to store the prior results. Built-in hook evaluation and equation-based evaluation are still combined (trying hooks first) because both should be done bottom-up while the LLVM-based evaluation should traverse top-down (to cut the traversal). Related to #539 (and #517 ) --- library/Booster/Pattern/ApplyEquations.hs | 131 +++++++++++++++++----- 1 file changed, 106 insertions(+), 25 deletions(-) diff --git a/library/Booster/Pattern/ApplyEquations.hs b/library/Booster/Pattern/ApplyEquations.hs index 0e4a9490f..06f6bf1ba 100644 --- a/library/Booster/Pattern/ApplyEquations.hs +++ b/library/Booster/Pattern/ApplyEquations.hs @@ -52,7 +52,7 @@ import Data.Foldable (toList, traverse_) import Data.List (partition) import Data.Map (Map) import Data.Map qualified as Map -import Data.Maybe (catMaybes, fromJust, fromMaybe, isJust) +import Data.Maybe (catMaybes, fromJust, fromMaybe) import Data.Sequence (Seq (..), pattern (:<|)) import Data.Sequence qualified as Seq import Data.Set (Set) @@ -155,7 +155,18 @@ data EquationState = EquationState , cache :: SimplifierCache } -type SimplifierCache = Map Term Term +data SimplifierCache = SimplifierCache {llvm, equations :: Map Term Term} + deriving stock (Show) + +instance Semigroup SimplifierCache where + cache1 <> cache2 = + SimplifierCache (cache1.llvm <> cache2.llvm) (cache1.equations <> cache2.equations) + +instance Monoid SimplifierCache where + mempty = SimplifierCache mempty mempty + +data CacheTag = LLVM | Equations + deriving stock (Show) data EquationMetadata = EquationMetadata { location :: Maybe Location @@ -246,7 +257,7 @@ isMatchFailure _ = False isSuccess EquationApplied{} = True isSuccess _ = False -startState :: Map Term Term -> EquationState +startState :: SimplifierCache -> EquationState startState cache = EquationState { termStack = mempty @@ -295,11 +306,19 @@ popRecursion = do then throw $ InternalError "Trying to pop an empty recursion stack" else eqState $ put s{recursionStack = tail s.recursionStack} -toCache :: MonadLoggerIO io => Term -> Term -> EquationT io () -toCache orig result = eqState . modify $ \s -> s{cache = Map.insert orig result s.cache} +toCache :: MonadLoggerIO io => CacheTag -> Term -> Term -> EquationT io () +toCache tag orig result = eqState . modify $ \s -> s{cache = updateCache tag s.cache} + where + insertInto = Map.insert orig result + updateCache LLVM cache = cache{llvm = insertInto cache.llvm} + updateCache Equations cache = cache{equations = insertInto cache.equations} -fromCache :: MonadLoggerIO io => Term -> EquationT io (Maybe Term) -fromCache t = eqState $ Map.lookup t <$> gets (.cache) +fromCache :: MonadLoggerIO io => CacheTag -> Term -> EquationT io (Maybe Term) +fromCache tag t = eqState $ Map.lookup t <$> gets (select tag . (.cache)) + where + select :: CacheTag -> SimplifierCache -> Map Term Term + select LLVM = (.llvm) + select Equations = (.equations) checkForLoop :: MonadLoggerIO io => Term -> EquationT io () checkForLoop t = do @@ -366,12 +385,86 @@ iterateEquations direction preference startTerm = do throw $ TooManyIterations currentCount startTerm currentTerm pushTerm currentTerm + -- simplify the term using the LLVM backend first + llvmResult <- llvmSimplify currentTerm + -- NB llvmSimplify is idempotent. No need to iterate if + -- the equation evaluation does not change the term any more. + resetChanged -- evaluate functions and simplify (recursively at each level) - newTerm <- applyTerm direction preference currentTerm + newTerm <- applyTerm direction preference llvmResult changeFlag <- getChanged if changeFlag then checkForLoop newTerm >> resetChanged >> go newTerm - else pure currentTerm + else pure llvmResult + +llvmSimplify :: forall io. MonadLoggerIO io => Term -> EquationT io Term +llvmSimplify t = do + config <- getConfig + case config.llvmApi of + Nothing -> pure t + Just api -> do + logOtherNS "booster" (LevelOther "Simplify") "Calling LLVM simplification" + llvmEval config.definition api t + +llvmEval :: MonadLoggerIO io => KoreDefinition -> LLVM.API -> Term -> EquationT io Term +llvmEval definition api = eval + where + eval t@(Term attributes _) + | attributes.isEvaluated = pure t + | otherwise = + fromCache LLVM t >>= \case + Nothing -> + eval' t + Just cached -> + when (cached /= t) setChanged >> pure cached + + eval' t@(Term attributes _) + | attributes.isEvaluated = pure t + | isConcrete t && attributes.canBeEvaluated = do + LLVM.simplifyTerm api definition t (sortOfTerm t) + >>= \case + Left e -> throw $ UndefinedTerm t e + Right result -> do + when (result /= t) $ do + setChanged + emitEquationTrace t Nothing (Just "LLVM") Nothing $ Success result + toCache LLVM t result + pure result + | otherwise = do + result <- descend t + toCache LLVM t result + pure result + descend = \case + dv@DomainValue{} -> pure dv + v@Var{} -> pure v + Injection src trg t -> + Injection src trg <$> eval t -- no injection simplification + AndTerm arg1 arg2 -> + AndTerm -- no \and simplification + <$> eval arg1 + <*> eval arg2 + SymbolApplication sym sorts args -> + SymbolApplication sym sorts <$> mapM eval args + KMap def keyVals rest -> + KMap def + <$> mapM (\(k, v) -> (,) <$> eval k <*> eval v) keyVals + <*> maybe (pure Nothing) ((Just <$>) . eval) rest + KList def heads rest -> + KList def + <$> mapM eval heads + <*> maybe + (pure Nothing) + ( (Just <$>) + . \(mid, tails) -> + (,) + <$> eval mid + <*> mapM eval tails + ) + rest + KSet def keyVals rest -> + KSet def + <$> mapM eval keyVals + <*> maybe (pure Nothing) ((Just <$>) . eval) rest ---------------------------------------- -- Interface functions @@ -451,6 +544,7 @@ applyTerm :: Term -> EquationT io Term applyTerm direction pref trm = do + logOtherNS "booster" (LevelOther "Simplify") "Calling equation-based simplifier" config <- getConfig -- avoid re-reading config at every node descend config trm where @@ -458,23 +552,10 @@ applyTerm direction pref trm = do descend config t@(Term attributes _) | attributes.isEvaluated = pure t | otherwise = - fromCache t >>= \case + fromCache Equations t >>= \case Nothing -> do - simplified <- - if isConcrete t && isJust config.llvmApi && attributes.canBeEvaluated - then -- LLVM simplification proceeds top-down and cuts the descent - - LLVM.simplifyTerm (fromJust config.llvmApi) config.definition t (sortOfTerm t) - >>= \case - Left e -> throw $ UndefinedTerm t e - Right result -> do - when (result /= t) $ do - setChanged - emitEquationTrace t Nothing (Just "LLVM") Nothing $ Success result - pure result - else -- use equations - apply config t - toCache t simplified + simplified <- apply config t + toCache Equations t simplified pure simplified Just cached -> do when (t /= cached) $ do