diff --git a/library/Booster/Pattern/ApplyEquations.hs b/library/Booster/Pattern/ApplyEquations.hs index 0e4a9490f..06f6bf1ba 100644 --- a/library/Booster/Pattern/ApplyEquations.hs +++ b/library/Booster/Pattern/ApplyEquations.hs @@ -52,7 +52,7 @@ import Data.Foldable (toList, traverse_) import Data.List (partition) import Data.Map (Map) import Data.Map qualified as Map -import Data.Maybe (catMaybes, fromJust, fromMaybe, isJust) +import Data.Maybe (catMaybes, fromJust, fromMaybe) import Data.Sequence (Seq (..), pattern (:<|)) import Data.Sequence qualified as Seq import Data.Set (Set) @@ -155,7 +155,18 @@ data EquationState = EquationState , cache :: SimplifierCache } -type SimplifierCache = Map Term Term +data SimplifierCache = SimplifierCache {llvm, equations :: Map Term Term} + deriving stock (Show) + +instance Semigroup SimplifierCache where + cache1 <> cache2 = + SimplifierCache (cache1.llvm <> cache2.llvm) (cache1.equations <> cache2.equations) + +instance Monoid SimplifierCache where + mempty = SimplifierCache mempty mempty + +data CacheTag = LLVM | Equations + deriving stock (Show) data EquationMetadata = EquationMetadata { location :: Maybe Location @@ -246,7 +257,7 @@ isMatchFailure _ = False isSuccess EquationApplied{} = True isSuccess _ = False -startState :: Map Term Term -> EquationState +startState :: SimplifierCache -> EquationState startState cache = EquationState { termStack = mempty @@ -295,11 +306,19 @@ popRecursion = do then throw $ InternalError "Trying to pop an empty recursion stack" else eqState $ put s{recursionStack = tail s.recursionStack} -toCache :: MonadLoggerIO io => Term -> Term -> EquationT io () -toCache orig result = eqState . modify $ \s -> s{cache = Map.insert orig result s.cache} +toCache :: MonadLoggerIO io => CacheTag -> Term -> Term -> EquationT io () +toCache tag orig result = eqState . modify $ \s -> s{cache = updateCache tag s.cache} + where + insertInto = Map.insert orig result + updateCache LLVM cache = cache{llvm = insertInto cache.llvm} + updateCache Equations cache = cache{equations = insertInto cache.equations} -fromCache :: MonadLoggerIO io => Term -> EquationT io (Maybe Term) -fromCache t = eqState $ Map.lookup t <$> gets (.cache) +fromCache :: MonadLoggerIO io => CacheTag -> Term -> EquationT io (Maybe Term) +fromCache tag t = eqState $ Map.lookup t <$> gets (select tag . (.cache)) + where + select :: CacheTag -> SimplifierCache -> Map Term Term + select LLVM = (.llvm) + select Equations = (.equations) checkForLoop :: MonadLoggerIO io => Term -> EquationT io () checkForLoop t = do @@ -366,12 +385,86 @@ iterateEquations direction preference startTerm = do throw $ TooManyIterations currentCount startTerm currentTerm pushTerm currentTerm + -- simplify the term using the LLVM backend first + llvmResult <- llvmSimplify currentTerm + -- NB llvmSimplify is idempotent. No need to iterate if + -- the equation evaluation does not change the term any more. + resetChanged -- evaluate functions and simplify (recursively at each level) - newTerm <- applyTerm direction preference currentTerm + newTerm <- applyTerm direction preference llvmResult changeFlag <- getChanged if changeFlag then checkForLoop newTerm >> resetChanged >> go newTerm - else pure currentTerm + else pure llvmResult + +llvmSimplify :: forall io. MonadLoggerIO io => Term -> EquationT io Term +llvmSimplify t = do + config <- getConfig + case config.llvmApi of + Nothing -> pure t + Just api -> do + logOtherNS "booster" (LevelOther "Simplify") "Calling LLVM simplification" + llvmEval config.definition api t + +llvmEval :: MonadLoggerIO io => KoreDefinition -> LLVM.API -> Term -> EquationT io Term +llvmEval definition api = eval + where + eval t@(Term attributes _) + | attributes.isEvaluated = pure t + | otherwise = + fromCache LLVM t >>= \case + Nothing -> + eval' t + Just cached -> + when (cached /= t) setChanged >> pure cached + + eval' t@(Term attributes _) + | attributes.isEvaluated = pure t + | isConcrete t && attributes.canBeEvaluated = do + LLVM.simplifyTerm api definition t (sortOfTerm t) + >>= \case + Left e -> throw $ UndefinedTerm t e + Right result -> do + when (result /= t) $ do + setChanged + emitEquationTrace t Nothing (Just "LLVM") Nothing $ Success result + toCache LLVM t result + pure result + | otherwise = do + result <- descend t + toCache LLVM t result + pure result + descend = \case + dv@DomainValue{} -> pure dv + v@Var{} -> pure v + Injection src trg t -> + Injection src trg <$> eval t -- no injection simplification + AndTerm arg1 arg2 -> + AndTerm -- no \and simplification + <$> eval arg1 + <*> eval arg2 + SymbolApplication sym sorts args -> + SymbolApplication sym sorts <$> mapM eval args + KMap def keyVals rest -> + KMap def + <$> mapM (\(k, v) -> (,) <$> eval k <*> eval v) keyVals + <*> maybe (pure Nothing) ((Just <$>) . eval) rest + KList def heads rest -> + KList def + <$> mapM eval heads + <*> maybe + (pure Nothing) + ( (Just <$>) + . \(mid, tails) -> + (,) + <$> eval mid + <*> mapM eval tails + ) + rest + KSet def keyVals rest -> + KSet def + <$> mapM eval keyVals + <*> maybe (pure Nothing) ((Just <$>) . eval) rest ---------------------------------------- -- Interface functions @@ -451,6 +544,7 @@ applyTerm :: Term -> EquationT io Term applyTerm direction pref trm = do + logOtherNS "booster" (LevelOther "Simplify") "Calling equation-based simplifier" config <- getConfig -- avoid re-reading config at every node descend config trm where @@ -458,23 +552,10 @@ applyTerm direction pref trm = do descend config t@(Term attributes _) | attributes.isEvaluated = pure t | otherwise = - fromCache t >>= \case + fromCache Equations t >>= \case Nothing -> do - simplified <- - if isConcrete t && isJust config.llvmApi && attributes.canBeEvaluated - then -- LLVM simplification proceeds top-down and cuts the descent - - LLVM.simplifyTerm (fromJust config.llvmApi) config.definition t (sortOfTerm t) - >>= \case - Left e -> throw $ UndefinedTerm t e - Right result -> do - when (result /= t) $ do - setChanged - emitEquationTrace t Nothing (Just "LLVM") Nothing $ Success result - pure result - else -- use equations - apply config t - toCache t simplified + simplified <- apply config t + toCache Equations t simplified pure simplified Just cached -> do when (t /= cached) $ do