diff --git a/horus-check.cabal b/horus-check.cabal index 86b33512..ed0bdaa7 100644 --- a/horus-check.cabal +++ b/horus-check.cabal @@ -85,6 +85,7 @@ library Horus.Preprocessor Horus.Preprocessor.Runner Horus.Preprocessor.Solvers + Horus.SMTHygiene Horus.SW.Builtin Horus.SW.CairoType Horus.SW.CairoType.JSON diff --git a/src/Horus/CFGBuild.hs b/src/Horus/CFGBuild.hs index 4899a61c..49317aa6 100644 --- a/src/Horus/CFGBuild.hs +++ b/src/Horus/CFGBuild.hs @@ -13,6 +13,7 @@ module Horus.CFGBuild , mkInv , Vertex (..) , getVerts + , isOptimising , isPreCheckingVertex ) where @@ -81,6 +82,9 @@ data Vertex = Vertex } deriving (Show) +isOptimising :: Vertex -> Bool +isOptimising = isJust . v_preCheckedF + instance Eq Vertex where (==) lhs rhs = v_name lhs == v_name rhs @@ -273,7 +277,7 @@ addArcsFrom inlinables prog rows seg@(Segment s) vFrom -- in which Pre(G) is assumed to hold at the PC of the G call site, as it will have -- been checked by the module induced by the ghost vertex. ghostV <- addOptimisingVertex (nextSegmentLabel seg) callee - pre <- maybe (mkPre Expr.True) mkPre . fs'_pre <$> getFuncSpec callee + pre <- maybe (mkPre Expr.True) (mkPre . fst) . fs'_pre <$> getFuncSpec callee -- Important note on the way we deal with logical variables. These are @declare-d and -- their values can be bound in preconditions. They generate existentials which only occur @@ -333,7 +337,7 @@ addAssertions inlinables identifiers = do (Nothing, Nothing) -> when (fu_pc f `Set.notMember` Set.map sf_pc inlinables) $ for_ retVs (`addAssertion` mkPost Expr.True) - _ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) mkPost post) + _ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) (mkPost . fst) post) ILabel pc -> whenJustM (getInvariant idName) $ \inv -> getSalientVertex pc >>= (`addAssertion` mkInv inv) diff --git a/src/Horus/CairoSemantics.hs b/src/Horus/CairoSemantics.hs index a3f58856..de8fcf37 100644 --- a/src/Horus/CairoSemantics.hs +++ b/src/Horus/CairoSemantics.hs @@ -41,7 +41,7 @@ import Horus.Expr.Vars , pattern StorageVar ) import Horus.Expr.Vars qualified as Vars -import Horus.FunctionAnalysis (ScopedFunction (sf_pc)) +import Horus.FunctionAnalysis (ScopedFunction (sf_pc, sf_scopedName)) import Horus.Instruction ( ApUpdate (..) , Instruction (..) @@ -53,18 +53,20 @@ import Horus.Instruction , getNextPcInlinedWithFallback , isCall , isRet + , toSemiAsmUnsafe , uncheckedCallDestination ) import Horus.Label (Label (..), tShowLabel) import Horus.Module (Module (..), apEqualsFp, isPreChecking) import Horus.Program (ApTracking (..)) +import Horus.SMTHygiene (AssertionMisc, commentAbove, commentBelow, commentRight, emptyMisc, magicHygieneConstant) import Horus.SW.Builtin (Builtin, BuiltinOffsets (..)) import Horus.SW.Builtin qualified as Builtin (name) import Horus.SW.FuncSpec (FuncSpec (..), FuncSpec', toFuncSpec) import Horus.SW.ScopedName (ScopedName) import Horus.SW.ScopedName qualified as ScopedName (fromText) import Horus.SW.Storage (Storage) -import Horus.SW.Storage qualified as Storage (equivalenceExpr) +import Horus.SW.Storage qualified as Storage import Horus.Util (enumerate, safeLast, tShow, whenJust, whenJustM) data MemoryVariable = MemoryVariable @@ -78,10 +80,11 @@ data AssertionType = PreAssertion | PostAssertion | InstructionSemanticsAssertion + | ApConstraintAssertion deriving (Eq, Show) data CairoSemanticsF a - = Assert' (Expr TBool) AssertionType a + = Assert' (Expr TBool) AssertionType AssertionMisc a | Expect' (Expr TBool) AssertionType a | DeclareMem (Expr TFelt) (Expr TFelt -> a) | DeclareLocalMem (Expr TFelt) (MemoryVariable -> a) @@ -89,6 +92,7 @@ data CairoSemanticsF a | GetBuiltinOffsets Label Builtin (Maybe BuiltinOffsets -> a) | GetCallee LabeledInst (ScopedFunction -> a) | GetFuncSpec ScopedFunction (FuncSpec' -> a) + | GetCurrentF (ScopedFunction -> a) | GetFunPc Label (Label -> a) | GetInlinable (Set.Set ScopedFunction -> a) | GetMemVars ([MemoryVariable] -> a) @@ -109,8 +113,8 @@ deriving instance Functor CairoSemanticsF type CairoSemanticsL = F CairoSemanticsF -assert' :: AssertionType -> Expr TBool -> CairoSemanticsL () -assert' assType a = liftF (Assert' a assType ()) +assert' :: AssertionType -> AssertionMisc -> Expr TBool -> CairoSemanticsL () +assert' assType misc a = liftF (Assert' a assType misc ()) expect' :: Expr TBool -> CairoSemanticsL () expect' a = liftF (Expect' a PostAssertion ()) @@ -127,6 +131,9 @@ getBuiltinOffsets l b = liftF (GetBuiltinOffsets l b id) getCallee :: LabeledInst -> CairoSemanticsL ScopedFunction getCallee call = liftF (GetCallee call id) +getCurrentF :: CairoSemanticsL ScopedFunction +getCurrentF = liftF (GetCurrentF id) + getFuncSpec :: ScopedFunction -> CairoSemanticsL FuncSpec' getFuncSpec name = liftF (GetFuncSpec name id) @@ -162,10 +169,27 @@ getStorage :: CairoSemanticsL Storage getStorage = liftF (GetStorage id) assert :: Expr TBool -> CairoSemanticsL () -assert a = assert' InstructionSemanticsAssertion =<< memoryRemoval a +assert a = assert' InstructionSemanticsAssertion emptyMisc =<< memoryRemoval a + +_assertWithComment :: Expr TBool -> Text -> CairoSemanticsL () +_assertWithComment a text = + assert' InstructionSemanticsAssertion (commentRight text emptyMisc) =<< memoryRemoval a + +assertWithComment' :: Expr TBool -> AssertionMisc -> CairoSemanticsL () +assertWithComment' a misc = + assert' InstructionSemanticsAssertion misc =<< memoryRemoval a + +_assertPre :: Expr TBool -> CairoSemanticsL () +_assertPre = assertPreWithComment' emptyMisc -assertPre :: Expr TBool -> CairoSemanticsL () -assertPre a = assert' PreAssertion =<< memoryRemoval a +commentHere' :: Text -> CairoSemanticsL () +commentHere' = assertWithComment' magicHygieneConstant . flip commentRight emptyMisc + +assertApConstraint :: Expr TBool -> CairoSemanticsL () +assertApConstraint a = assert' ApConstraintAssertion emptyMisc =<< memoryRemoval a + +assertPreWithComment' :: AssertionMisc -> Expr TBool -> CairoSemanticsL () +assertPreWithComment' misc a = assert' PreAssertion misc =<< memoryRemoval a expect :: Expr TBool -> CairoSemanticsL () expect a = expect' =<< memoryRemoval a @@ -289,7 +313,7 @@ moduleEndAp mdl = contain a storage update. -} encodeModule :: Module -> CairoSemanticsL () -encodeModule mdl@(Module (FuncSpec pre post storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do +encodeModule mdl@(Module (FuncSpec (pre, preDesc) (post, _) storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do enableStorage fp <- getFp apEnd <- moduleEndAp mdl @@ -297,7 +321,9 @@ encodeModule mdl@(Module (FuncSpec pre post storage) instrs oracle _ _ mbPreChec apStart <- moduleStartAp mdl assert (fp .<= apStart) - assertPre =<< prepare' apStart fp (pre .&& apEqualsFp) + assertPreWithComment' (commentAbove ("Module pre: " <> preDesc) emptyMisc) + =<< prepare' apStart fp (pre .&& apEqualsFp) + -- The last FP might be influenced in the optimizing case, we need to grab it as propagated -- by the encoding of the semantics of the call. lastFp <- @@ -406,19 +432,23 @@ mkInstructionConstraints :: Map (NonEmpty Label, Label) Bool -> (Int, LabeledInst) -> CairoSemanticsL (Maybe (Expr TFelt)) -mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, Instruction{..})) = do +mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, inst@Instruction{..})) = do fp <- getFp dst <- prepare pc fp (memory (regToVar i_dstRegister + fromInteger i_dstOffset)) case i_opCode of Call -> mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack =<< getCallee lInst - AssertEqual -> getRes fp lInst >>= \res -> assert (res .== dst) $> Nothing + AssertEqual -> + getRes fp lInst >>= \res -> + assertWithComment' (res .== dst) (commentRight (toSemiAsmUnsafe inst) emptyMisc) $> Nothing Nop -> do stackTrace <- getOracle case jnzOracle Map.!? (stackTrace, pc) of Just False -> assert (dst .== 0) $> Nothing Just True -> assert (dst ./= 0) $> Nothing Nothing -> pure Nothing - Ret -> pop $> Nothing + Ret -> do + currentF <- getCurrentF + commentHere' ("Inlining ended: " <> tShow (sf_scopedName currentF)) >> pop $> Nothing where nextPc = getNextPcInlinedWithFallback instrs idx @@ -435,7 +465,10 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do nextAp <- prepare pc calleeFp (Vars.fp .== Vars.ap + 2) saveOldFp <- prepare pc fp (memory Vars.ap .== Vars.fp) setNextPc <- prepare pc fp (memory (Vars.ap + 1) .== fromIntegral (unLabel nextPc)) - assert (Expr.and [nextAp, saveOldFp, setNextPc]) + canInline <- isInlinable f + assertWithComment' + (Expr.and [nextAp, saveOldFp, setNextPc]) + (commentAbove (describeCall canInline <> tShow (sf_scopedName f)) emptyMisc) push stackFrame -- Considering we have already pushed the stackFrame by now, we need to make sure that either -- the function is inlinable and we'll encounter a 'ret', or we need to pop right away @@ -448,7 +481,7 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do -- of the function being invoked must not be considered. guardWith (isInlinable f) (pure Nothing) $ do -- An inlined function will have a 'ret' at some point, do not pop here. - (FuncSpec pre post storage) <- getFuncSpec f <&> toFuncSpec + (FuncSpec (pre, preDesc) (post, postDesc) storage) <- getFuncSpec f <&> toFuncSpec let pre' = suffixLogicalVariables lvarSuffix pre post' = suffixLogicalVariables lvarSuffix post preparedPre <- prepare nextPc calleeFp =<< storageRemoval pre' @@ -462,13 +495,22 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do -- However, pre will be checked in a separate 'optimising' module and we can therefore simply -- assert it holds here. If it does not, the corresponding pre-checking module will fail, -- thus failing the judgement for the entire function. - assert preparedPre - assert preparedPost + assertWithComment' + preparedPre + (commentAbove ("Pre of: " <> functionName <> " | " <> preDesc) emptyMisc) + assertWithComment' + preparedPost + ( commentBelow + ("Finished abstracting function: " <> functionName) + (commentAbove ("Post of: " <> functionName <> " | " <> postDesc) emptyMisc) + ) pure Nothing where lvarSuffix = "+" <> tShowLabel pc calleePc = sf_pc f stackFrame = (pc, calleePc) + functionName = tShow (sf_scopedName f) + describeCall inl = if inl then "Inlining function: " else "Abstracting function: " -- Determine whether the current function matches the function being optimised exactly - -- this necessitates comparing execution traces. isModuleCheckingPre = do @@ -498,15 +540,15 @@ mkApConstraints apEnd insts = do ap2 <- getAp pcNext fp <- getFp getApIncrement fp lInst >>= \case - Just apIncrement -> assert (ap1 + apIncrement .== ap2) - Nothing | not isNewStackframe -> assert (ap1 .< ap2) + Just apIncrement -> assertApConstraint (ap1 + apIncrement .== ap2) + Nothing | not isNewStackframe -> assertApConstraint (ap1 .< ap2) Nothing -> pure () lastAp <- getAp lastPc when (isRet lastInst) pop fp <- getFp getApIncrement fp lastLInst >>= \case - Just lastApIncrement -> assert (lastAp + lastApIncrement .== apEnd) - Nothing -> assert (lastAp .< apEnd) + Just lastApIncrement -> assertApConstraint (lastAp + lastApIncrement .== apEnd) + Nothing -> assertApConstraint (lastAp .< apEnd) where lastLInst@(lastPc, lastInst) = NonEmpty.last insts diff --git a/src/Horus/CairoSemantics/Runner.hs b/src/Horus/CairoSemantics/Runner.hs index 6ed768a8..c5578c8d 100644 --- a/src/Horus/CairoSemantics/Runner.hs +++ b/src/Horus/CairoSemantics/Runner.hs @@ -16,22 +16,23 @@ import Control.Monad.Reader , asks , runReaderT ) -import Control.Monad.State (MonadState (get), State, runState) +import Control.Monad.State (MonadState (get), State, gets, runState) import Data.Foldable (toList) import Data.Function ((&)) import Data.Functor (($>)) +import Data.List (partition) import Data.List qualified as List (find, tails) -import Data.Map qualified as Map (map, null, unionWith) +import Data.Map qualified as Map import Data.Maybe (mapMaybe) import Data.Singletons (sing) import Data.Some (foldSome) import Data.Text (Text) import Data.Text qualified as Text (intercalate) -import Lens.Micro (Lens', (%~), (<&>), (^.)) +import Lens.Micro (Lens', (%~), (<&>), (^.), _1, _2) import Lens.Micro.GHC () import Lens.Micro.Mtl (use, (%=), (.=), (<%=)) -import Horus.CairoSemantics (AssertionType (PreAssertion), CairoSemanticsF (..), CairoSemanticsL, MemoryVariable (..)) +import Horus.CairoSemantics (AssertionType (ApConstraintAssertion, PreAssertion), CairoSemanticsF (..), CairoSemanticsL, MemoryVariable (..)) import Horus.CallStack (CallStack, digestOfCallStack, pop, push, reset, stackTrace, top) import Horus.Command.SMT qualified as Command import Horus.ContractInfo (ContractInfo (..)) @@ -43,6 +44,7 @@ import Horus.Expr.Type (STy (..)) import Horus.Expr.Util (gatherNonStdFunctions) import Horus.Expr.Vars (prime, rcBound) import Horus.FunctionAnalysis (ScopedFunction (sf_scopedName)) +import Horus.SMTHygiene (AssertionMisc, commentBelow, encodeRestriction, withEmptyMisc) import Horus.SW.Builtin qualified as Builtin (rcBound) import Horus.SW.Storage (Storage) import Horus.SW.Storage qualified as Storage (read) @@ -61,7 +63,7 @@ builderToAss mv (ExistentialAss f) = f mv -- inlined. data ConstraintsState = ConstraintsState { cs_memoryVariables :: [MemoryVariable] - , cs_asserts :: [(AssertionBuilder, AssertionType)] + , cs_asserts :: [(AssertionBuilder, AssertionType, AssertionMisc)] , cs_expects :: [(Expr TBool, AssertionType)] , cs_nameCounter :: Int , cs_callStack :: CallStack @@ -70,7 +72,7 @@ data ConstraintsState = ConstraintsState csMemoryVariables :: Lens' ConstraintsState [MemoryVariable] csMemoryVariables lMod g = fmap (\x -> g{cs_memoryVariables = x}) (lMod (cs_memoryVariables g)) -csAsserts :: Lens' ConstraintsState [(AssertionBuilder, AssertionType)] +csAsserts :: Lens' ConstraintsState [(AssertionBuilder, AssertionType, AssertionMisc)] csAsserts lMod g = fmap (\x -> g{cs_asserts = x}) (lMod (cs_asserts g)) csExpects :: Lens' ConstraintsState [(Expr TBool, AssertionType)] @@ -121,7 +123,7 @@ interpret :: forall a. CairoSemanticsL a -> Impl a interpret = iterM exec where exec :: CairoSemanticsF (Impl a) -> Impl a - exec (Assert' a assType cont) = eConstraints . csAsserts %= ((QFAss a, assType) :) >> cont + exec (Assert' a assType misc cont) = eConstraints . csAsserts %= ((QFAss a, assType, misc) :) >> cont exec (Expect' a assType cont) = eConstraints . csExpects %= ((a, assType) :) >> cont exec (DeclareMem address cont) = do memVars <- use (eConstraints . csMemoryVariables) @@ -151,6 +153,10 @@ interpret = iterM exec exec (GetCallee inst cont) = do ci <- ask ci_getCallee ci inst >>= cont + exec (GetCurrentF cont) = do + (_, calledF) <- gets $ top . (^. csCallStack) . e_constraints + ci <- ask + cont (ci_functions ci Map.! calledF) exec (GetFuncSpec name cont) = do ci <- ask ci_getFuncSpec ci name & cont @@ -200,9 +206,9 @@ debugFriendlyModel ConstraintsState{..} = [ ["# Memory"] , memoryPairs , ["# Assert"] - , map (pprExpr . builderToAss cs_memoryVariables . fst) cs_asserts + , map (pprExpr . builderToAss cs_memoryVariables . (^. _1)) cs_asserts , ["# Expect"] - , map (pprExpr . fst) cs_expects + , map (pprExpr . (^. _1)) cs_expects ] where memoryPairs = @@ -222,10 +228,10 @@ restrictMemTail (mv0 : rest) = makeModel :: Bool -> ConstraintsState -> Integer -> Text makeModel checkPreOnly ConstraintsState{..} fPrime = - Text.intercalate "\n" (decls <> map (Command.assert fPrime) restrictions) + Text.intercalate "\n" (decls <> map (encodeRestriction fPrime) restrictions) where functions = - toList (foldMap gatherNonStdFunctions generalRestrictions <> gatherNonStdFunctions prime) + toList (foldMap (gatherNonStdFunctions . fst) generalRestrictions <> gatherNonStdFunctions prime) decls = map (foldSome Command.declare) functions rangeRestrictions = mapMaybe (foldSome restrictRange) functions memRestrictions = concatMap restrictMemTail (List.tails cs_memoryVariables) @@ -233,17 +239,32 @@ makeModel checkPreOnly ConstraintsState{..} fPrime = [Expr.const mv_addrName .== mv_addrExpr | MemoryVariable{..} <- cs_memoryVariables] -- If checking @pre only, only take `PreAssertion`s, no postconditions. - allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . snd) cs_asserts else cs_asserts + allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . (^. _2)) cs_asserts else cs_asserts allowedExpects = if checkPreOnly then [] else cs_expects + allowedAssertsWithApRegion = delineateApAssertions allowedAsserts + generalRestrictions = concat - [ memRestrictions - , addrRestrictions - , map (builderToAss cs_memoryVariables . fst) allowedAsserts - , [Expr.not (Expr.and (map fst allowedExpects)) | not (null allowedExpects)] + [ sansMiscInfo memRestrictions + , sansMiscInfo addrRestrictions + , [(builderToAss cs_memoryVariables builder, misc) | (builder, _, misc) <- allowedAssertsWithApRegion] + , sansMiscInfo [Expr.not (Expr.and (map (^. _1) allowedExpects)) | not (null allowedExpects)] ] - restrictions = rangeRestrictions <> generalRestrictions + restrictions = sansMiscInfo rangeRestrictions <> generalRestrictions + + sansMiscInfo = map withEmptyMisc + + delineateApAssertions asserts = + let (anteAp, postAp) = partition ((/= ApConstraintAssertion) . (^. _2)) asserts + in commentRegion anteAp "Begin AP constraints." ++ commentRegion postAp "End AP constraints." + where + commentRegion region msg = + if null region + then region + else + let (ass, assType, assMisc) = last region + in init region ++ [(ass, assType, commentBelow msg assMisc)] restrictRange :: forall ty. Function ty -> Maybe (Expr TBool) restrictRange (Function name) = case sing @ty of diff --git a/src/Horus/Command/SMT.hs b/src/Horus/Command/SMT.hs index 9da35af2..552be86c 100644 --- a/src/Horus/Command/SMT.hs +++ b/src/Horus/Command/SMT.hs @@ -1,7 +1,7 @@ -module Horus.Command.SMT (declare, assert) where +module Horus.Command.SMT (declare, assert, comment) where import Data.List.NonEmpty (NonEmpty (..)) -import Data.Text (Text, pack) +import Data.Text (Text, pack, unpack) import SimpleSMT qualified as SMT import Text.Printf (printf) @@ -10,6 +10,7 @@ import Horus.Expr.SMT qualified as Expr (toSMT) import Horus.Expr.Std (Function (..)) import Horus.Expr.Type (Ty (..)) import Horus.Expr.Type.SMT qualified as Ty (toSMT) +import SimpleSMT (SExpr (Atom)) declare :: forall ty. Function ty -> Text declare (Function name) = pack (printf "(declare-fun %s (%s) %s)" name args res) @@ -20,3 +21,6 @@ declare (Function name) = pack (printf "(declare-fun %s (%s) %s)" name args res) assert :: Integer -> Expr TBool -> Text assert fPrime e = pack (printf "(assert %s)" (SMT.showsSExpr (Expr.toSMT fPrime e) "")) + +comment :: Text -> Text +comment text = pack (printf "; %s" (SMT.showsSExpr (Atom (unpack text)) "")) diff --git a/src/Horus/Expr.hs b/src/Horus/Expr.hs index da4e9599..cc935cf9 100644 --- a/src/Horus/Expr.hs +++ b/src/Horus/Expr.hs @@ -343,9 +343,7 @@ a .>= b = function ">=" a b infix 4 .== (.==) :: Expr TFelt -> Expr TFelt -> Expr TBool -a .== b - | a == b = True - | otherwise = function "=" a b +a .== b = function "=" a b infix 4 ./= (./=) :: Expr TFelt -> Expr TFelt -> Expr TBool diff --git a/src/Horus/Global.hs b/src/Horus/Global.hs index dadecf6e..d65df13f 100644 --- a/src/Horus/Global.hs +++ b/src/Horus/Global.hs @@ -274,7 +274,7 @@ removeMathSAT m run = do instUsesLvars i = falseIfError $ do callee <- getCallee i spec <- getFuncSpec callee - let lvars = gatherLogicalVariables (fromMaybe Expr.True (fs'_pre spec)) + let lvars = gatherLogicalVariables . fst $ fromMaybe (Expr.True, "True") (fs'_pre spec) pure (not (null lvars)) falseIfError a = a `catchError` const (pure False) diff --git a/src/Horus/Instruction.hs b/src/Horus/Instruction.hs index 973f8fb8..1273c2f9 100644 --- a/src/Horus/Instruction.hs +++ b/src/Horus/Instruction.hs @@ -245,7 +245,7 @@ toSemiAsm Instruction{..} = do withRes f = fmap f getRes dst = mem (printReg i_dstRegister `add` i_dstOffset) mbApPP = case i_apUpdate of - Add1 -> "; ap++" + Add1 -> ", ap++" _ -> "" getRes = case i_resLogic of Op1 -> pure op1 diff --git a/src/Horus/JSON/Util.hs b/src/Horus/JSON/Util.hs index e21b1550..7e9f003b 100644 --- a/src/Horus/JSON/Util.hs +++ b/src/Horus/JSON/Util.hs @@ -31,5 +31,5 @@ instance Typeable a => FromJSON (HSourcedSExpr a) where HSourcedSExpr <$> do sourceLst <- v .: "source" - pure $ Text.intercalate "\n" sourceLst + pure $ Text.intercalate " | " sourceLst <*> v .: "sexpr" diff --git a/src/Horus/Module.hs b/src/Horus/Module.hs index 900463c9..fab39bba 100644 --- a/src/Horus/Module.hs +++ b/src/Horus/Module.hs @@ -189,7 +189,7 @@ descrOfOracle oracle = getModuleNameParts :: Identifiers -> Module -> (Text, Text, Text, Text) getModuleNameParts idents (Module spec prog oracle calledF _ mbPreCheckedFuncAndCallStack) = case beginOfModule prog of - Nothing -> ("", "empty: " <> pprExpr post, "", "") + Nothing -> ("", "empty: " <> pprExpr (fst post), "", "") Just label -> let scopedNames = labelNamesOfPc idents label isFloatingLabel = label /= calledF @@ -248,12 +248,12 @@ throw t = liftF' (Throw t) catch :: ModuleL a -> (Error -> ModuleL a) -> ModuleL a catch m h = liftF' (Catch m h id) -data SpecBuilder = SBRich | SBPlain (Expr TBool) +data SpecBuilder = SBRich | SBPlain (Expr TBool, Text) extractPlainBuilder :: FuncSpec -> ModuleL SpecBuilder -extractPlainBuilder (FuncSpec pre _ storage) +extractPlainBuilder (FuncSpec (pre, preDesc) _ storage) | not (null storage) = throwError EInvariantWithSVarUpdateSpec - | otherwise = pure (SBPlain (pre .&& (ap .== fp))) + | otherwise = pure (SBPlain (pre .&& (ap .== fp), preDesc)) gatherModules :: CFG -> [(Function, ScopedName, FuncSpec)] -> ModuleL () gatherModules cfg = traverse_ $ \(f, _, spec) -> gatherFromSource cfg f spec @@ -323,18 +323,18 @@ visit cfg fSpec function oracle callstack acc builder v@(Vertex _ label preCheck visitLoop SBRich = extractPlainBuilder fSpec >>= visitLoop visitLoop (SBPlain pre) | null assertions = throwError (ELoopNoInvariant label) - | otherwise = emit pre (Expr.and assertions) + | otherwise = emit pre (Expr.and assertions, "") visitLinear :: SpecBuilder -> ModuleL () visitLinear SBRich - | onFinalNode = emit (fs_pre fSpec) (Expr.and $ map snd (cfg_assertions cfg ^. ix v)) + | onFinalNode = emit (fs_pre fSpec) (Expr.and $ map snd (cfg_assertions cfg ^. ix v), "") | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v | otherwise = extractPlainBuilder fSpec >>= visitLinear visitLinear (SBPlain pre) | null assertions = visitArcs cfg fSpec function callstack' oracle' acc builder v | otherwise = do - emit pre (Expr.and assertions) - visitArcs cfg fSpec function callstack' Map.empty [] (SBPlain (Expr.and assertions)) v + emit pre (Expr.and assertions, "") + visitArcs cfg fSpec function callstack' Map.empty [] (SBPlain (Expr.and assertions, "")) v callstack' = case f of Nothing -> callstack @@ -349,7 +349,7 @@ visit cfg fSpec function oracle callstack acc builder v@(Vertex _ label preCheck preCheckingStackFrame = (fCallerPc, uncheckedCallDestination labelledCall) preCheckingContext = (push preCheckingStackFrame callstack',) <$> preCheckedF - emit :: Expr TBool -> Expr TBool -> ModuleL () + emit :: (Expr TBool, Text) -> (Expr TBool, Text) -> ModuleL () emit pre post = emitModule (Module spec acc oracle' pc (callstack', label) preCheckingContext) where pc = fu_pc function diff --git a/src/Horus/SMTHygiene.hs b/src/Horus/SMTHygiene.hs new file mode 100644 index 00000000..e505817e --- /dev/null +++ b/src/Horus/SMTHygiene.hs @@ -0,0 +1,54 @@ +module Horus.SMTHygiene + ( AssertionMisc (..) + , emptyMisc + , commentAbove + , commentBelow + , commentRight + , encodeRestriction + , withEmptyMisc + , magicHygieneConstant + ) +where + +import Data.Text (Text, intercalate) +import Horus.Command.SMT qualified as Command +import Horus.Expr + +data AssertionMisc = AssertionMisc + { am_textAbove :: [Text] + , am_textBelow :: [Text] + , am_textRight :: [Text] + } + +emptyMisc :: AssertionMisc +emptyMisc = AssertionMisc [] [] [] + +withEmptyMisc :: Expr a -> (Expr a, AssertionMisc) +withEmptyMisc = (,emptyMisc) + +commentAbove :: Text -> AssertionMisc -> AssertionMisc +commentAbove comment am = am{am_textAbove = am_textAbove am ++ [comment]} + +commentBelow :: Text -> AssertionMisc -> AssertionMisc +commentBelow comment am = am{am_textBelow = am_textBelow am ++ [comment]} + +commentRight :: Text -> AssertionMisc -> AssertionMisc +commentRight comment am = am{am_textRight = am_textRight am ++ [comment]} + +-- We need a unique-enough Expr.True to identify assertions that only exists to carry comments +magicHygieneConstant :: Expr TBool +magicHygieneConstant = 24601 .== 24601 + +encodeRestriction :: Integer -> (Expr TBool, AssertionMisc) -> Text +encodeRestriction prime (expr, AssertionMisc{..}) = + indentGroup (map Command.comment am_textAbove) + <> indentNewlinePretty am_textAbove + <> assertNotMagic expr + <> indentNear (map Command.comment am_textRight) + <> indentNewlinePretty am_textBelow + <> indentGroup (map Command.comment am_textBelow) + where + indentGroup = intercalate "\n" + indentNear = intercalate " | " + indentNewlinePretty comments = if null comments then "" else "\n" + assertNotMagic e = if e == magicHygieneConstant then "" else Command.assert prime expr diff --git a/src/Horus/SW/FuncSpec.hs b/src/Horus/SW/FuncSpec.hs index 7a64c5ca..b1169fa1 100644 --- a/src/Horus/SW/FuncSpec.hs +++ b/src/Horus/SW/FuncSpec.hs @@ -3,6 +3,7 @@ module Horus.SW.FuncSpec (FuncSpec (..), emptyFuncSpec, emptyFuncSpec', FuncSpec import Data.Aeson (FromJSON (..), withObject, (.:)) import Data.Coerce (coerce) import Data.Maybe (fromMaybe) +import Data.Text (Text) import Horus.Expr (Expr, Ty (..)) import Horus.Expr qualified as Expr @@ -11,14 +12,19 @@ import Horus.SW.Storage (Storage) import Horus.SW.Storage qualified as Storage (parse) data FuncSpec = FuncSpec - { fs_pre :: Expr TBool - , fs_post :: Expr TBool + { fs_pre :: (Expr TBool, Text) + , fs_post :: (Expr TBool, Text) , fs_storage :: Storage } deriving stock (Show) emptyFuncSpec :: FuncSpec -emptyFuncSpec = FuncSpec{fs_pre = Expr.True, fs_post = Expr.True, fs_storage = mempty} +emptyFuncSpec = + FuncSpec + { fs_pre = (Expr.True, "True") + , fs_post = (Expr.True, "True") + , fs_storage = mempty + } {- | A version of `FuncSpec` that distinguishes omitted preconditions and postconditions from trivial ones. @@ -28,8 +34,8 @@ emptyFuncSpec = FuncSpec{fs_pre = Expr.True, fs_post = Expr.True, fs_storage = m `Nothing`. -} data FuncSpec' = FuncSpec' - { fs'_pre :: Maybe (Expr TBool) - , fs'_post :: Maybe (Expr TBool) + { fs'_pre :: Maybe (Expr TBool, Text) + , fs'_post :: Maybe (Expr TBool, Text) , fs'_storage :: Storage } @@ -39,17 +45,20 @@ emptyFuncSpec' = FuncSpec'{fs'_pre = Nothing, fs'_post = Nothing, fs'_storage = toFuncSpec :: FuncSpec' -> FuncSpec toFuncSpec FuncSpec'{..} = FuncSpec - { fs_pre = fromMaybe Expr.True fs'_pre - , fs_post = fromMaybe Expr.True fs'_post + { fs_pre = fromMaybe (Expr.True, "True") fs'_pre + , fs_post = fromMaybe (Expr.True, "True") fs'_post , fs_storage = fs'_storage } instance FromJSON FuncSpec where parseJSON = withObject "FuncSpec" $ \v -> FuncSpec - <$> fmap (elimHSExpr . hss_hsexpr) (v .: "pre") - <*> fmap (elimHSExpr . hss_hsexpr) (v .: "post") + <$> fmap (\x -> (elimHSExpr (hss_hsexpr x), normalize (hss_source x))) (v .: "pre") + <*> fmap (\x -> (elimHSExpr (hss_hsexpr x), normalize (hss_source x))) (v .: "post") <*> (Storage.parse =<< (v .: "storage_update")) + where + normalize "" = "True" + normalize x = x elimHSExpr :: HSExpr a -> Expr a elimHSExpr = coerce diff --git a/src/Horus/SW/Std.hs b/src/Horus/SW/Std.hs index 90bc356b..9b68d442 100644 --- a/src/Horus/SW/Std.hs +++ b/src/Horus/SW/Std.hs @@ -17,7 +17,7 @@ stdSpecs :: Map ScopedName FuncSpec stdSpecs = Map.fromList stdSpecsList mkReadSpec :: ScopedName -> Int -> FuncSpec -mkReadSpec name arity = emptyFuncSpec{fs_post = memory (ap - 1) .== var} +mkReadSpec name arity = emptyFuncSpec{fs_post = (memory (ap - 1) .== var, "Reading: " <> tShow name)} where offsets = [-3 - arity + 1 .. -3] args = [memory (fp + fromIntegral offset) | offset <- offsets] @@ -40,6 +40,8 @@ trustedStdFuncs = , "starkware.cairo.common.math.assert_le_felt" ] +-- TODO(chore): Fill in 'Text' representation of the specs. + {- | A lexicographically sorted by fs_name list of specifications of standard library functions. @@ -54,21 +56,21 @@ stdSpecsList = , emptyFuncSpec { fs_post = let diff = memory (fp - 3) - memory (fp - 4) - in 0 .<= diff .&& diff .< rcBound + in (0 .<= diff .&& diff .< rcBound, "") } ) , ( "starkware.cairo.common.math.assert_le_felt" - , emptyFuncSpec{fs_post = memory (fp - 4) .<= memory (fp - 3)} + , emptyFuncSpec{fs_post = (memory (fp - 4) .<= memory (fp - 3), "")} ) , ( "starkware.cairo.common.math.assert_nn" - , emptyFuncSpec{fs_post = 0 .<= memory (fp - 3) .&& memory (fp - 3) .< rcBound} + , emptyFuncSpec{fs_post = (0 .<= memory (fp - 3) .&& memory (fp - 3) .< rcBound, "")} ) , ( "starkware.cairo.common.math.assert_nn_le" , emptyFuncSpec - { fs_post = 0 .<= memory (fp - 4) .&& memory (fp - 4) .<= memory (fp - 3) + { fs_post = (0 .<= memory (fp - 4) .&& memory (fp - 4) .<= memory (fp - 3), "") } ) , @@ -78,7 +80,7 @@ stdSpecsList = let low = memory (ap - 1) high = memory (ap - 2) v = memory (fp - 3) - in low .== v `Expr.mod` rcBound .&& high .== v `Expr.div` rcBound + in (low .== v `Expr.mod` rcBound .&& high .== v `Expr.div` rcBound, "") } ) , @@ -86,13 +88,15 @@ stdSpecsList = , let (value, div') = (memory (fp - 4), memory (fp - 3)) (q, r) = (memory (ap - 2), memory (ap - 1)) in emptyFuncSpec - { fs_pre = ExitField (0 .< div' .&& div' * rcBound .<= prime) + { fs_pre = (ExitField (0 .< div' .&& div' * rcBound .<= prime), "") , fs_post = - Expr.and - [ 0 .<= q .&& q .< rcBound - , 0 .<= r .&& r .< div' - , value .== q * div' + r - ] + ( Expr.and + [ 0 .<= q .&& q .< rcBound + , 0 .<= r .&& r .< div' + , value .== q * div' + r + ] + , "" + ) } ) , @@ -101,48 +105,52 @@ stdSpecsList = { fs_post = let diff = memory (fp - 3) - memory (fp - 4) res = memory (ap - 1) - in Expr.ite - (0 .<= diff .&& diff .< rcBound) - (res .== 1) - (res .== 0) + in ( Expr.ite + (0 .<= diff .&& diff .< rcBound) + (res .== 1) + (res .== 0) + , "" + ) } ) , ( "starkware.cairo.common.math_cmp.is_nn" , emptyFuncSpec { fs_post = - Expr.ite - (0 .<= memory (fp - 3) .&& memory (fp - 3) .< rcBound) - (memory (ap - 1) .== 1) - (memory (ap - 1) .== 0) + ( Expr.ite + (0 .<= memory (fp - 3) .&& memory (fp - 3) .< rcBound) + (memory (ap - 1) .== 1) + (memory (ap - 1) .== 0) + , "" + ) } ) , ( "starkware.cairo.lang.compiler.lib.registers.get_ap" - , emptyFuncSpec{fs_post = memory (ap - 1) .== fp - 2} + , emptyFuncSpec{fs_post = (memory (ap - 1) .== fp - 2, "")} ) , ( "starkware.cairo.lang.compiler.lib.registers.get_fp_and_pc" , emptyFuncSpec - { fs_post = memory (ap - 2) .== memory (fp - 2) .&& memory (ap - 1) .== memory (fp - 1) + { fs_post = (memory (ap - 2) .== memory (fp - 2) .&& memory (ap - 1) .== memory (fp - 1), "") } ) , ( "starkware.starknet.common.syscalls.get_block_timestamp" , emptyFuncSpec - { fs_post = memory (ap - 1) .== blockTimestamp + { fs_post = (memory (ap - 1) .== blockTimestamp, "") } ) , ( "starkware.starknet.common.syscalls.get_caller_address" , emptyFuncSpec - { fs_post = memory (ap - 1) .== callerAddress + { fs_post = (memory (ap - 1) .== callerAddress, "") } ) , ( "starkware.starknet.common.syscalls.get_contract_address" , emptyFuncSpec - { fs_post = memory (ap - 1) .== contractAddress + { fs_post = (memory (ap - 1) .== contractAddress, "") } ) ] diff --git a/tests/resources/golden/pre_fancy.cairo b/tests/resources/golden/pre_fancy.cairo new file mode 100644 index 00000000..e1954d79 --- /dev/null +++ b/tests/resources/golden/pre_fancy.cairo @@ -0,0 +1,17 @@ +// @pre token == 0 or token == 1 +// @post (token == 0 and $Return.t == 1) or (token == 1 and $Return.t == 0) +func get_opposite_token(token: felt) -> (t: felt) { + if (token == 0) { + return (t=1); + } else { + return (t=0); + } +} + +// @pre token == 0 or token == 1 +// @post $Return.res == token +func bar(token) -> (res: felt) { +let (a) = get_opposite_token(token); +let (b) = get_opposite_token(a); +return (res=b); +} \ No newline at end of file diff --git a/tests/resources/golden/pre_fancy.gold b/tests/resources/golden/pre_fancy.gold new file mode 100644 index 00000000..f46c6c2e --- /dev/null +++ b/tests/resources/golden/pre_fancy.gold @@ -0,0 +1,7 @@ + +bar +Verified + +get_opposite_token +Verified +