Skip to content

Commit

Permalink
Implement comment injections
Browse files Browse the repository at this point in the history
  • Loading branch information
Ferinko authored and langfield committed Mar 28, 2023
1 parent db57ebe commit f41b3c6
Show file tree
Hide file tree
Showing 15 changed files with 256 additions and 91 deletions.
1 change: 1 addition & 0 deletions horus-check.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ library
Horus.Preprocessor
Horus.Preprocessor.Runner
Horus.Preprocessor.Solvers
Horus.SMTHygiene
Horus.SW.Builtin
Horus.SW.CairoType
Horus.SW.CairoType.JSON
Expand Down
8 changes: 6 additions & 2 deletions src/Horus/CFGBuild.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Horus.CFGBuild
, mkInv
, Vertex (..)
, getVerts
, isOptimising
, isPreCheckingVertex
)
where
Expand Down Expand Up @@ -81,6 +82,9 @@ data Vertex = Vertex
}
deriving (Show)

isOptimising :: Vertex -> Bool
isOptimising = isJust . v_preCheckedF

instance Eq Vertex where
(==) lhs rhs = v_name lhs == v_name rhs

Expand Down Expand Up @@ -273,7 +277,7 @@ addArcsFrom inlinables prog rows seg@(Segment s) vFrom
-- in which Pre(G) is assumed to hold at the PC of the G call site, as it will have
-- been checked by the module induced by the ghost vertex.
ghostV <- addOptimisingVertex (nextSegmentLabel seg) callee
pre <- maybe (mkPre Expr.True) mkPre . fs'_pre <$> getFuncSpec callee
pre <- maybe (mkPre Expr.True) (mkPre . fst) . fs'_pre <$> getFuncSpec callee

-- Important note on the way we deal with logical variables. These are @declare-d and
-- their values can be bound in preconditions. They generate existentials which only occur
Expand Down Expand Up @@ -333,7 +337,7 @@ addAssertions inlinables identifiers = do
(Nothing, Nothing) ->
when (fu_pc f `Set.notMember` Set.map sf_pc inlinables) $
for_ retVs (`addAssertion` mkPost Expr.True)
_ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) mkPost post)
_ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) (mkPost . fst) post)
ILabel pc ->
whenJustM (getInvariant idName) $ \inv ->
getSalientVertex pc >>= (`addAssertion` mkInv inv)
Expand Down
84 changes: 63 additions & 21 deletions src/Horus/CairoSemantics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import Horus.Expr.Vars
, pattern StorageVar
)
import Horus.Expr.Vars qualified as Vars
import Horus.FunctionAnalysis (ScopedFunction (sf_pc))
import Horus.FunctionAnalysis (ScopedFunction (sf_pc, sf_scopedName))
import Horus.Instruction
( ApUpdate (..)
, Instruction (..)
Expand All @@ -53,18 +53,20 @@ import Horus.Instruction
, getNextPcInlinedWithFallback
, isCall
, isRet
, toSemiAsmUnsafe
, uncheckedCallDestination
)
import Horus.Label (Label (..), tShowLabel)
import Horus.Module (Module (..), apEqualsFp, isPreChecking)
import Horus.Program (ApTracking (..))
import Horus.SMTHygiene (AssertionMisc, commentAbove, commentBelow, commentRight, emptyMisc, magicHygieneConstant)
import Horus.SW.Builtin (Builtin, BuiltinOffsets (..))
import Horus.SW.Builtin qualified as Builtin (name)
import Horus.SW.FuncSpec (FuncSpec (..), FuncSpec', toFuncSpec)
import Horus.SW.ScopedName (ScopedName)
import Horus.SW.ScopedName qualified as ScopedName (fromText)
import Horus.SW.Storage (Storage)
import Horus.SW.Storage qualified as Storage (equivalenceExpr)
import Horus.SW.Storage qualified as Storage
import Horus.Util (enumerate, safeLast, tShow, whenJust, whenJustM)

data MemoryVariable = MemoryVariable
Expand All @@ -78,17 +80,19 @@ data AssertionType
= PreAssertion
| PostAssertion
| InstructionSemanticsAssertion
| ApConstraintAssertion
deriving (Eq, Show)

data CairoSemanticsF a
= Assert' (Expr TBool) AssertionType a
= Assert' (Expr TBool) AssertionType AssertionMisc a
| Expect' (Expr TBool) AssertionType a
| DeclareMem (Expr TFelt) (Expr TFelt -> a)
| DeclareLocalMem (Expr TFelt) (MemoryVariable -> a)
| GetApTracking Label (ApTracking -> a)
| GetBuiltinOffsets Label Builtin (Maybe BuiltinOffsets -> a)
| GetCallee LabeledInst (ScopedFunction -> a)
| GetFuncSpec ScopedFunction (FuncSpec' -> a)
| GetCurrentF (ScopedFunction -> a)
| GetFunPc Label (Label -> a)
| GetInlinable (Set.Set ScopedFunction -> a)
| GetMemVars ([MemoryVariable] -> a)
Expand All @@ -109,8 +113,8 @@ deriving instance Functor CairoSemanticsF

type CairoSemanticsL = F CairoSemanticsF

assert' :: AssertionType -> Expr TBool -> CairoSemanticsL ()
assert' assType a = liftF (Assert' a assType ())
assert' :: AssertionType -> AssertionMisc -> Expr TBool -> CairoSemanticsL ()
assert' assType misc a = liftF (Assert' a assType misc ())

expect' :: Expr TBool -> CairoSemanticsL ()
expect' a = liftF (Expect' a PostAssertion ())
Expand All @@ -127,6 +131,9 @@ getBuiltinOffsets l b = liftF (GetBuiltinOffsets l b id)
getCallee :: LabeledInst -> CairoSemanticsL ScopedFunction
getCallee call = liftF (GetCallee call id)

getCurrentF :: CairoSemanticsL ScopedFunction
getCurrentF = liftF (GetCurrentF id)

getFuncSpec :: ScopedFunction -> CairoSemanticsL FuncSpec'
getFuncSpec name = liftF (GetFuncSpec name id)

Expand Down Expand Up @@ -162,10 +169,27 @@ getStorage :: CairoSemanticsL Storage
getStorage = liftF (GetStorage id)

assert :: Expr TBool -> CairoSemanticsL ()
assert a = assert' InstructionSemanticsAssertion =<< memoryRemoval a
assert a = assert' InstructionSemanticsAssertion emptyMisc =<< memoryRemoval a

_assertWithComment :: Expr TBool -> Text -> CairoSemanticsL ()
_assertWithComment a text =
assert' InstructionSemanticsAssertion (commentRight text emptyMisc) =<< memoryRemoval a

assertWithComment' :: Expr TBool -> AssertionMisc -> CairoSemanticsL ()
assertWithComment' a misc =
assert' InstructionSemanticsAssertion misc =<< memoryRemoval a

_assertPre :: Expr TBool -> CairoSemanticsL ()
_assertPre = assertPreWithComment' emptyMisc

assertPre :: Expr TBool -> CairoSemanticsL ()
assertPre a = assert' PreAssertion =<< memoryRemoval a
commentHere' :: Text -> CairoSemanticsL ()
commentHere' = assertWithComment' magicHygieneConstant . flip commentRight emptyMisc

assertApConstraint :: Expr TBool -> CairoSemanticsL ()
assertApConstraint a = assert' ApConstraintAssertion emptyMisc =<< memoryRemoval a

assertPreWithComment' :: AssertionMisc -> Expr TBool -> CairoSemanticsL ()
assertPreWithComment' misc a = assert' PreAssertion misc =<< memoryRemoval a

expect :: Expr TBool -> CairoSemanticsL ()
expect a = expect' =<< memoryRemoval a
Expand Down Expand Up @@ -289,15 +313,17 @@ moduleEndAp mdl =
contain a storage update.
-}
encodeModule :: Module -> CairoSemanticsL ()
encodeModule mdl@(Module (FuncSpec pre post storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do
encodeModule mdl@(Module (FuncSpec (pre, preDesc) (post, _) storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do
enableStorage
fp <- getFp
apEnd <- moduleEndAp mdl
preparedStorage <- traverseStorage (prepare' apEnd fp) storage

apStart <- moduleStartAp mdl
assert (fp .<= apStart)
assertPre =<< prepare' apStart fp (pre .&& apEqualsFp)
assertPreWithComment' (commentAbove ("Module pre: " <> preDesc) emptyMisc)
=<< prepare' apStart fp (pre .&& apEqualsFp)

-- The last FP might be influenced in the optimizing case, we need to grab it as propagated
-- by the encoding of the semantics of the call.
lastFp <-
Expand Down Expand Up @@ -406,19 +432,23 @@ mkInstructionConstraints ::
Map (NonEmpty Label, Label) Bool ->
(Int, LabeledInst) ->
CairoSemanticsL (Maybe (Expr TFelt))
mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, Instruction{..})) = do
mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, inst@Instruction{..})) = do
fp <- getFp
dst <- prepare pc fp (memory (regToVar i_dstRegister + fromInteger i_dstOffset))
case i_opCode of
Call -> mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack =<< getCallee lInst
AssertEqual -> getRes fp lInst >>= \res -> assert (res .== dst) $> Nothing
AssertEqual ->
getRes fp lInst >>= \res ->
assertWithComment' (res .== dst) (commentRight (toSemiAsmUnsafe inst) emptyMisc) $> Nothing
Nop -> do
stackTrace <- getOracle
case jnzOracle Map.!? (stackTrace, pc) of
Just False -> assert (dst .== 0) $> Nothing
Just True -> assert (dst ./= 0) $> Nothing
Nothing -> pure Nothing
Ret -> pop $> Nothing
Ret -> do
currentF <- getCurrentF
commentHere' ("Inlining ended: " <> tShow (sf_scopedName currentF)) >> pop $> Nothing
where
nextPc = getNextPcInlinedWithFallback instrs idx

Expand All @@ -435,7 +465,10 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
nextAp <- prepare pc calleeFp (Vars.fp .== Vars.ap + 2)
saveOldFp <- prepare pc fp (memory Vars.ap .== Vars.fp)
setNextPc <- prepare pc fp (memory (Vars.ap + 1) .== fromIntegral (unLabel nextPc))
assert (Expr.and [nextAp, saveOldFp, setNextPc])
canInline <- isInlinable f
assertWithComment'
(Expr.and [nextAp, saveOldFp, setNextPc])
(commentAbove (describeCall canInline <> tShow (sf_scopedName f)) emptyMisc)
push stackFrame
-- Considering we have already pushed the stackFrame by now, we need to make sure that either
-- the function is inlinable and we'll encounter a 'ret', or we need to pop right away
Expand All @@ -448,7 +481,7 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
-- of the function being invoked must not be considered.
guardWith (isInlinable f) (pure Nothing) $ do
-- An inlined function will have a 'ret' at some point, do not pop here.
(FuncSpec pre post storage) <- getFuncSpec f <&> toFuncSpec
(FuncSpec (pre, preDesc) (post, postDesc) storage) <- getFuncSpec f <&> toFuncSpec
let pre' = suffixLogicalVariables lvarSuffix pre
post' = suffixLogicalVariables lvarSuffix post
preparedPre <- prepare nextPc calleeFp =<< storageRemoval pre'
Expand All @@ -462,13 +495,22 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
-- However, pre will be checked in a separate 'optimising' module and we can therefore simply
-- assert it holds here. If it does not, the corresponding pre-checking module will fail,
-- thus failing the judgement for the entire function.
assert preparedPre
assert preparedPost
assertWithComment'
preparedPre
(commentAbove ("Pre of: " <> functionName <> " | " <> preDesc) emptyMisc)
assertWithComment'
preparedPost
( commentBelow
("Finished abstracting function: " <> functionName)
(commentAbove ("Post of: " <> functionName <> " | " <> postDesc) emptyMisc)
)
pure Nothing
where
lvarSuffix = "+" <> tShowLabel pc
calleePc = sf_pc f
stackFrame = (pc, calleePc)
functionName = tShow (sf_scopedName f)
describeCall inl = if inl then "Inlining function: " else "Abstracting function: "
-- Determine whether the current function matches the function being optimised exactly -
-- this necessitates comparing execution traces.
isModuleCheckingPre = do
Expand Down Expand Up @@ -498,15 +540,15 @@ mkApConstraints apEnd insts = do
ap2 <- getAp pcNext
fp <- getFp
getApIncrement fp lInst >>= \case
Just apIncrement -> assert (ap1 + apIncrement .== ap2)
Nothing | not isNewStackframe -> assert (ap1 .< ap2)
Just apIncrement -> assertApConstraint (ap1 + apIncrement .== ap2)
Nothing | not isNewStackframe -> assertApConstraint (ap1 .< ap2)
Nothing -> pure ()
lastAp <- getAp lastPc
when (isRet lastInst) pop
fp <- getFp
getApIncrement fp lastLInst >>= \case
Just lastApIncrement -> assert (lastAp + lastApIncrement .== apEnd)
Nothing -> assert (lastAp .< apEnd)
Just lastApIncrement -> assertApConstraint (lastAp + lastApIncrement .== apEnd)
Nothing -> assertApConstraint (lastAp .< apEnd)
where
lastLInst@(lastPc, lastInst) = NonEmpty.last insts

Expand Down
55 changes: 38 additions & 17 deletions src/Horus/CairoSemantics/Runner.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@ import Control.Monad.Reader
, asks
, runReaderT
)
import Control.Monad.State (MonadState (get), State, runState)
import Control.Monad.State (MonadState (get), State, gets, runState)
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Functor (($>))
import Data.List (partition)
import Data.List qualified as List (find, tails)
import Data.Map qualified as Map (map, null, unionWith)
import Data.Map qualified as Map
import Data.Maybe (mapMaybe)
import Data.Singletons (sing)
import Data.Some (foldSome)
import Data.Text (Text)
import Data.Text qualified as Text (intercalate)
import Lens.Micro (Lens', (%~), (<&>), (^.))
import Lens.Micro (Lens', (%~), (<&>), (^.), _1, _2)
import Lens.Micro.GHC ()
import Lens.Micro.Mtl (use, (%=), (.=), (<%=))

import Horus.CairoSemantics (AssertionType (PreAssertion), CairoSemanticsF (..), CairoSemanticsL, MemoryVariable (..))
import Horus.CairoSemantics (AssertionType (ApConstraintAssertion, PreAssertion), CairoSemanticsF (..), CairoSemanticsL, MemoryVariable (..))
import Horus.CallStack (CallStack, digestOfCallStack, pop, push, reset, stackTrace, top)
import Horus.Command.SMT qualified as Command
import Horus.ContractInfo (ContractInfo (..))
Expand All @@ -43,6 +44,7 @@ import Horus.Expr.Type (STy (..))
import Horus.Expr.Util (gatherNonStdFunctions)
import Horus.Expr.Vars (prime, rcBound)
import Horus.FunctionAnalysis (ScopedFunction (sf_scopedName))
import Horus.SMTHygiene (AssertionMisc, commentBelow, encodeRestriction, withEmptyMisc)
import Horus.SW.Builtin qualified as Builtin (rcBound)
import Horus.SW.Storage (Storage)
import Horus.SW.Storage qualified as Storage (read)
Expand All @@ -61,7 +63,7 @@ builderToAss mv (ExistentialAss f) = f mv
-- inlined.
data ConstraintsState = ConstraintsState
{ cs_memoryVariables :: [MemoryVariable]
, cs_asserts :: [(AssertionBuilder, AssertionType)]
, cs_asserts :: [(AssertionBuilder, AssertionType, AssertionMisc)]
, cs_expects :: [(Expr TBool, AssertionType)]
, cs_nameCounter :: Int
, cs_callStack :: CallStack
Expand All @@ -70,7 +72,7 @@ data ConstraintsState = ConstraintsState
csMemoryVariables :: Lens' ConstraintsState [MemoryVariable]
csMemoryVariables lMod g = fmap (\x -> g{cs_memoryVariables = x}) (lMod (cs_memoryVariables g))

csAsserts :: Lens' ConstraintsState [(AssertionBuilder, AssertionType)]
csAsserts :: Lens' ConstraintsState [(AssertionBuilder, AssertionType, AssertionMisc)]
csAsserts lMod g = fmap (\x -> g{cs_asserts = x}) (lMod (cs_asserts g))

csExpects :: Lens' ConstraintsState [(Expr TBool, AssertionType)]
Expand Down Expand Up @@ -121,7 +123,7 @@ interpret :: forall a. CairoSemanticsL a -> Impl a
interpret = iterM exec
where
exec :: CairoSemanticsF (Impl a) -> Impl a
exec (Assert' a assType cont) = eConstraints . csAsserts %= ((QFAss a, assType) :) >> cont
exec (Assert' a assType misc cont) = eConstraints . csAsserts %= ((QFAss a, assType, misc) :) >> cont
exec (Expect' a assType cont) = eConstraints . csExpects %= ((a, assType) :) >> cont
exec (DeclareMem address cont) = do
memVars <- use (eConstraints . csMemoryVariables)
Expand Down Expand Up @@ -151,6 +153,10 @@ interpret = iterM exec
exec (GetCallee inst cont) = do
ci <- ask
ci_getCallee ci inst >>= cont
exec (GetCurrentF cont) = do
(_, calledF) <- gets $ top . (^. csCallStack) . e_constraints
ci <- ask
cont (ci_functions ci Map.! calledF)
exec (GetFuncSpec name cont) = do
ci <- ask
ci_getFuncSpec ci name & cont
Expand Down Expand Up @@ -200,9 +206,9 @@ debugFriendlyModel ConstraintsState{..} =
[ ["# Memory"]
, memoryPairs
, ["# Assert"]
, map (pprExpr . builderToAss cs_memoryVariables . fst) cs_asserts
, map (pprExpr . builderToAss cs_memoryVariables . (^. _1)) cs_asserts
, ["# Expect"]
, map (pprExpr . fst) cs_expects
, map (pprExpr . (^. _1)) cs_expects
]
where
memoryPairs =
Expand All @@ -222,28 +228,43 @@ restrictMemTail (mv0 : rest) =

makeModel :: Bool -> ConstraintsState -> Integer -> Text
makeModel checkPreOnly ConstraintsState{..} fPrime =
Text.intercalate "\n" (decls <> map (Command.assert fPrime) restrictions)
Text.intercalate "\n" (decls <> map (encodeRestriction fPrime) restrictions)
where
functions =
toList (foldMap gatherNonStdFunctions generalRestrictions <> gatherNonStdFunctions prime)
toList (foldMap (gatherNonStdFunctions . fst) generalRestrictions <> gatherNonStdFunctions prime)
decls = map (foldSome Command.declare) functions
rangeRestrictions = mapMaybe (foldSome restrictRange) functions
memRestrictions = concatMap restrictMemTail (List.tails cs_memoryVariables)
addrRestrictions =
[Expr.const mv_addrName .== mv_addrExpr | MemoryVariable{..} <- cs_memoryVariables]

-- If checking @pre only, only take `PreAssertion`s, no postconditions.
allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . snd) cs_asserts else cs_asserts
allowedAsserts = if checkPreOnly then filter ((== PreAssertion) . (^. _2)) cs_asserts else cs_asserts
allowedExpects = if checkPreOnly then [] else cs_expects

allowedAssertsWithApRegion = delineateApAssertions allowedAsserts

generalRestrictions =
concat
[ memRestrictions
, addrRestrictions
, map (builderToAss cs_memoryVariables . fst) allowedAsserts
, [Expr.not (Expr.and (map fst allowedExpects)) | not (null allowedExpects)]
[ sansMiscInfo memRestrictions
, sansMiscInfo addrRestrictions
, [(builderToAss cs_memoryVariables builder, misc) | (builder, _, misc) <- allowedAssertsWithApRegion]
, sansMiscInfo [Expr.not (Expr.and (map (^. _1) allowedExpects)) | not (null allowedExpects)]
]
restrictions = rangeRestrictions <> generalRestrictions
restrictions = sansMiscInfo rangeRestrictions <> generalRestrictions

sansMiscInfo = map withEmptyMisc

delineateApAssertions asserts =
let (anteAp, postAp) = partition ((/= ApConstraintAssertion) . (^. _2)) asserts
in commentRegion anteAp "Begin AP constraints." ++ commentRegion postAp "End AP constraints."
where
commentRegion region msg =
if null region
then region
else
let (ass, assType, assMisc) = last region
in init region ++ [(ass, assType, commentBelow msg assMisc)]

restrictRange :: forall ty. Function ty -> Maybe (Expr TBool)
restrictRange (Function name) = case sing @ty of
Expand Down
Loading

0 comments on commit f41b3c6

Please sign in to comment.