Skip to content

Commit

Permalink
implement comment injections
Browse files Browse the repository at this point in the history
  • Loading branch information
Ferinko committed Mar 27, 2023
1 parent 1f8d465 commit 8e8119e
Show file tree
Hide file tree
Showing 17 changed files with 273 additions and 166 deletions.
1 change: 1 addition & 0 deletions horus-check.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ library
Horus.Preprocessor
Horus.Preprocessor.Runner
Horus.Preprocessor.Solvers
Horus.SMTHygiene
Horus.SW.Builtin
Horus.SW.CairoType
Horus.SW.CairoType.JSON
Expand Down
62 changes: 12 additions & 50 deletions src/Horus/CFGBuild.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
{-# LANGUAGE InstanceSigs #-}

module Horus.CFGBuild
( CFGBuildL (..)
, ArcCondition (..)
Expand All @@ -16,26 +14,19 @@ module Horus.CFGBuild
, Vertex (..)
, getVerts
, isOptimising
, AnnotationType (..)
, mkPre
, mkPost
, mkInv
, Vertex (..)
, getVerts
, isPreCheckingVertex
)
where

import Control.Arrow (Arrow (second))
import Control.Arrow (Arrow (second))
import Control.Arrow (Arrow (second))
import Control.Arrow
( Arrow (second)
)
import Control.Monad (when)
import Control.Monad.Except (MonadError (..))
import Control.Monad.Free.Church (F, liftF)
import Data.Coerce (coerce)
import Data.Foldable (forM_, for_, toList)
import Data.Functor ((<&>))
import Data.Functor ((<&>))
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NonEmpty (last, reverse, (<|))
import Data.Map qualified as Map (lookup, toList)
Expand All @@ -44,24 +35,22 @@ import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text (Text)
import Data.Traversable (for)
import Data.Maybe (isJust)
import Data.Traversable (for)
import Lens.Micro.GHC ()

import Horus.Expr (Expr (), Ty (..))
import Horus.Expr qualified as Expr
import Horus.Expr.Util (gatherLogicalVariables)
import Horus.Expr.Util (gatherLogicalVariables)
import Horus.Expr.Util
( gatherLogicalVariables
)
import Horus.FunctionAnalysis
( FInfo
, FuncOp (ArcCall, ArcRet)
, ScopedFunction (ScopedFunction, sf_pc, sf_scopedName, sf_scopedName)
, ScopedFunction (ScopedFunction, sf_pc, sf_scopedName)
, callersOf
, pcToFunOfProg
, programLabels
, sizeOfCall
, uncheckedScopedFOfPc
, uncheckedScopedFOfPc
)
import Horus.Instruction
( Instruction (..)
Expand All @@ -76,8 +65,7 @@ import Horus.Program (Identifiers, Program (..))
import Horus.SW.FuncSpec (FuncSpec' (fs'_post, fs'_pre))
import Horus.SW.Identifier (Function (fu_pc), Identifier (IFunction, ILabel))
import Horus.SW.ScopedName (ScopedName)
import Horus.Util (appendList, tShow, whenJustM, tShow)
import Horus.Expr.Util (gatherLogicalVariables)
import Horus.Util (appendList, tShow, whenJustM)

data AnnotationType = APre | APost | AInv
deriving stock (Show)
Expand All @@ -94,38 +82,12 @@ mkInv = (AInv,)
data Vertex = Vertex
{ v_name :: Text
, v_label :: Label
, v_optimisesF :: Maybe ScopedFunction
, v_preCheckedF :: Maybe ScopedFunction
}
deriving (Show)

instance Eq Vertex where
(==) lhs rhs = v_name lhs == v_name rhs

instance Ord Vertex where
compare :: Vertex -> Vertex -> Ordering
compare lhs rhs = v_name lhs `compare` v_name rhs

isOptimising :: Vertex -> Bool
isOptimising = isJust . v_optimisesF

data AnnotationType = APre | APost | AInv
deriving stock (Show)

mkPre :: Expr TBool -> (AnnotationType, Expr TBool)
mkPre = (APre,)

mkPost :: Expr TBool -> (AnnotationType, Expr TBool)
mkPost = (APost,)

mkInv :: Expr TBool -> (AnnotationType, Expr TBool)
mkInv = (AInv,)

data Vertex = Vertex
{ v_name :: Text
, v_label :: Label
, v_preCheckedF :: Maybe ScopedFunction
}
deriving (Show)
isOptimising = isJust . v_preCheckedF

instance Eq Vertex where
(==) lhs rhs = v_name lhs == v_name rhs
Expand Down Expand Up @@ -319,7 +281,7 @@ addArcsFrom inlinables prog rows seg@(Segment s) vFrom
-- in which Pre(G) is assumed to hold at the PC of the G call site, as it will have
-- been checked by the module induced by the ghost vertex.
ghostV <- addOptimisingVertex (nextSegmentLabel seg) callee
pre <- maybe (mkPre Expr.True) mkPre . fs'_pre <$> getFuncSpec callee
pre <- maybe (mkPre Expr.True) (mkPre . fst) . fs'_pre <$> getFuncSpec callee

-- Important note on the way we deal with logical variables. These are @declare-d and
-- their values can be bound in preconditions. They generate existentials which only occur
Expand Down Expand Up @@ -379,7 +341,7 @@ addAssertions inlinables identifiers = do
(Nothing, Nothing) ->
when (fu_pc f `Set.notMember` Set.map sf_pc inlinables) $
for_ retVs (`addAssertion` mkPost Expr.True)
_ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) mkPost post)
_ -> for_ retVs (`addAssertion` maybe (mkPost Expr.True) (mkPost . fst) post)
ILabel pc ->
whenJustM (getInvariant idName) $ \inv ->
getSalientVertex pc >>= (`addAssertion` mkInv inv)
Expand Down
2 changes: 1 addition & 1 deletion src/Horus/CFGBuild/Runner.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import Lens.Micro (Lens', at, (&), (^.), _Just)
import Lens.Micro.GHC ()
import Lens.Micro.Mtl ((%=), (<%=))

import Horus.CFGBuild (AnnotationType, AnnotationType, ArcCondition (..), CFGBuildF (..), CFGBuildL (..), Label, LabeledInst, Vertex (..), isPreCheckingVertex, Vertex (..), isOptimising)
import Horus.CFGBuild (AnnotationType, ArcCondition (..), CFGBuildF (..), CFGBuildL (..), Label, LabeledInst, Vertex (..), isPreCheckingVertex)
import Horus.ContractInfo (ContractInfo (..))
import Horus.Expr (Expr, Ty (..))
import Horus.FunctionAnalysis (FInfo)
Expand Down
97 changes: 64 additions & 33 deletions src/Horus/CairoSemantics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import Data.Maybe (fromMaybe, isJust)
import Data.Set qualified as Set (Set, member)
import Data.Text (Text)
import Data.Traversable (for)
import Lens.Micro ((^.), _3)
import Data.Traversable (for)
import Lens.Micro ((^.), _1)

import Horus.CallStack as CS (CallEntry, CallStack)
Expand All @@ -43,7 +41,7 @@ import Horus.Expr.Vars
, pattern StorageVar
)
import Horus.Expr.Vars qualified as Vars
import Horus.FunctionAnalysis (ScopedFunction (sf_pc))
import Horus.FunctionAnalysis (ScopedFunction (sf_pc, sf_scopedName))
import Horus.Instruction
( ApUpdate (..)
, Instruction (..)
Expand All @@ -55,18 +53,20 @@ import Horus.Instruction
, getNextPcInlinedWithFallback
, isCall
, isRet
, toSemiAsmUnsafe
, uncheckedCallDestination
)
import Horus.Label (Label (..), tShowLabel)
import Horus.Module (Module (..), apEqualsFp, isPreChecking, isOptimising)
import Horus.Module (Module (..), apEqualsFp, isPreChecking)
import Horus.Program (ApTracking (..))
import Horus.SMTHygiene (AssertionMisc, commentAbove, commentBelow, commentRight, emptyMisc, magicHygieneConstant)
import Horus.SW.Builtin (Builtin, BuiltinOffsets (..))
import Horus.SW.Builtin qualified as Builtin (name)
import Horus.SW.FuncSpec (FuncSpec (..), FuncSpec', toFuncSpec)
import Horus.SW.ScopedName (ScopedName)
import Horus.SW.ScopedName qualified as ScopedName (fromText)
import Horus.SW.Storage (Storage)
import Horus.SW.Storage qualified as Storage (equivalenceExpr)
import Horus.SW.Storage qualified as Storage
import Horus.Util (enumerate, safeLast, tShow, whenJust, whenJustM)

data MemoryVariable = MemoryVariable
Expand All @@ -80,17 +80,19 @@ data AssertionType
= PreAssertion
| PostAssertion
| InstructionSemanticsAssertion
| ApConstraintAssertion
deriving (Eq, Show)

data CairoSemanticsF a
= Assert' (Expr TBool) AssertionType a
= Assert' (Expr TBool) AssertionType AssertionMisc a
| Expect' (Expr TBool) AssertionType a
| DeclareMem (Expr TFelt) (Expr TFelt -> a)
| DeclareLocalMem (Expr TFelt) (MemoryVariable -> a)
| GetApTracking Label (ApTracking -> a)
| GetBuiltinOffsets Label Builtin (Maybe BuiltinOffsets -> a)
| GetCallee LabeledInst (ScopedFunction -> a)
| GetFuncSpec ScopedFunction (FuncSpec' -> a)
| GetCurrentF (ScopedFunction -> a)
| GetFunPc Label (Label -> a)
| GetInlinable (Set.Set ScopedFunction -> a)
| GetMemVars ([MemoryVariable] -> a)
Expand All @@ -111,8 +113,8 @@ deriving instance Functor CairoSemanticsF

type CairoSemanticsL = F CairoSemanticsF

assert' :: AssertionType -> Expr TBool -> CairoSemanticsL ()
assert' assType a = liftF (Assert' a assType ())
assert' :: AssertionType -> AssertionMisc -> Expr TBool -> CairoSemanticsL ()
assert' assType misc a = liftF (Assert' a assType misc ())

expect' :: Expr TBool -> CairoSemanticsL ()
expect' a = liftF (Expect' a PostAssertion ())
Expand All @@ -129,6 +131,9 @@ getBuiltinOffsets l b = liftF (GetBuiltinOffsets l b id)
getCallee :: LabeledInst -> CairoSemanticsL ScopedFunction
getCallee call = liftF (GetCallee call id)

getCurrentF :: CairoSemanticsL ScopedFunction
getCurrentF = liftF (GetCurrentF id)

getFuncSpec :: ScopedFunction -> CairoSemanticsL FuncSpec'
getFuncSpec name = liftF (GetFuncSpec name id)

Expand All @@ -138,9 +143,6 @@ getFunPc l = liftF (GetFunPc l id)
getMemVars :: CairoSemanticsL [MemoryVariable]
getMemVars = liftF (GetMemVars id)

getMemVars :: CairoSemanticsL [MemoryVariable]
getMemVars = liftF (GetMemVars id)

declareLocalMem :: Expr TFelt -> CairoSemanticsL MemoryVariable
declareLocalMem address = liftF (DeclareLocalMem address id)

Expand All @@ -167,10 +169,27 @@ getStorage :: CairoSemanticsL Storage
getStorage = liftF (GetStorage id)

assert :: Expr TBool -> CairoSemanticsL ()
assert a = assert' InstructionSemanticsAssertion =<< memoryRemoval a
assert a = assert' InstructionSemanticsAssertion emptyMisc =<< memoryRemoval a

_assertWithComment :: Expr TBool -> Text -> CairoSemanticsL ()
_assertWithComment a text =
assert' InstructionSemanticsAssertion (commentRight text emptyMisc) =<< memoryRemoval a

assertWithComment' :: Expr TBool -> AssertionMisc -> CairoSemanticsL ()
assertWithComment' a misc =
assert' InstructionSemanticsAssertion misc =<< memoryRemoval a

_assertPre :: Expr TBool -> CairoSemanticsL ()
_assertPre = assertPreWithComment' emptyMisc

assertPre :: Expr TBool -> CairoSemanticsL ()
assertPre a = assert' PreAssertion =<< memoryRemoval a
commentHere' :: Text -> CairoSemanticsL ()
commentHere' = assertWithComment' magicHygieneConstant . flip commentRight emptyMisc

assertApConstraint :: Expr TBool -> CairoSemanticsL ()
assertApConstraint a = assert' ApConstraintAssertion emptyMisc =<< memoryRemoval a

assertPreWithComment' :: AssertionMisc -> Expr TBool -> CairoSemanticsL ()
assertPreWithComment' misc a = assert' PreAssertion misc =<< memoryRemoval a

expect :: Expr TBool -> CairoSemanticsL ()
expect a = expect' =<< memoryRemoval a
Expand Down Expand Up @@ -295,21 +314,17 @@ moduleEndAp mdl =
contain a storage update.
-}
encodeModule :: Module -> CairoSemanticsL ()
encodeModule mdl@(Module (FuncSpec pre post storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do
encodeModule mdl@(Module (FuncSpec (pre, preDesc) (post, _) storage) instrs oracle _ _ mbPreCheckedFuncWithCallStack) = do
enableStorage
fp <- getFp
apEnd <- moduleEndAp mdl
preparedStorage <- traverseStorage (prepare' apEnd fp) storage
encodePlainSpec mdl plainSpec
accumulatedStorage <- getStorage
unless (isOptimising mdl) $
expect (Storage.equivalenceExpr accumulatedStorage preparedStorage)
where
plainSpec = richToPlainSpec funcSpec

apStart <- moduleStartAp mdl
assert (fp .<= apStart)
assertPre =<< prepare' apStart fp (pre .&& apEqualsFp)
assertPreWithComment' (commentAbove ("Module pre: " <> preDesc) emptyMisc)
=<< prepare' apStart fp (pre .&& apEqualsFp)

-- The last FP might be influenced in the optimizing case, we need to grab it as propagated
-- by the encoding of the semantics of the call.
lastFp <-
Expand Down Expand Up @@ -418,19 +433,23 @@ mkInstructionConstraints ::
Map (NonEmpty Label, Label) Bool ->
(Int, LabeledInst) ->
CairoSemanticsL (Maybe (Expr TFelt))
mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, Instruction{..})) = do
mkInstructionConstraints instrs mbPreCheckedFuncWithCallStack jnzOracle (idx, lInst@(pc, inst@Instruction{..})) = do
fp <- getFp
dst <- prepare pc fp (memory (regToVar i_dstRegister + fromInteger i_dstOffset))
case i_opCode of
Call -> mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack =<< getCallee lInst
AssertEqual -> getRes fp lInst >>= \res -> assert (res .== dst) $> Nothing
AssertEqual ->
getRes fp lInst >>= \res ->
assertWithComment' (res .== dst) (commentRight (toSemiAsmUnsafe inst) emptyMisc) $> Nothing
Nop -> do
stackTrace <- getOracle
case jnzOracle Map.!? (stackTrace, pc) of
Just False -> assert (dst .== 0) $> Nothing
Just True -> assert (dst ./= 0) $> Nothing
Nothing -> pure Nothing
Ret -> pop $> Nothing
Ret -> do
currentF <- getCurrentF
commentHere' ("Inlining ended: " <> tShow (sf_scopedName currentF)) >> pop $> Nothing
where
nextPc = getNextPcInlinedWithFallback instrs idx

Expand All @@ -447,7 +466,10 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
nextAp <- prepare pc calleeFp (Vars.fp .== Vars.ap + 2)
saveOldFp <- prepare pc fp (memory Vars.ap .== Vars.fp)
setNextPc <- prepare pc fp (memory (Vars.ap + 1) .== fromIntegral (unLabel nextPc))
assert (Expr.and [nextAp, saveOldFp, setNextPc])
canInline <- isInlinable f
assertWithComment'
(Expr.and [nextAp, saveOldFp, setNextPc])
(commentAbove (describeCall canInline <> tShow (sf_scopedName f)) emptyMisc)
push stackFrame
-- Considering we have already pushed the stackFrame by now, we need to make sure that either
-- the function is inlinable and we'll encounter a 'ret', or we need to pop right away
Expand All @@ -460,7 +482,7 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
-- of the function being invoked must not be considered.
guardWith (isInlinable f) (pure Nothing) $ do
-- An inlined function will have a 'ret' at some point, do not pop here.
(FuncSpec pre post storage) <- getFuncSpec f <&> toFuncSpec
(FuncSpec (pre, preDesc) (post, postDesc) storage) <- getFuncSpec f <&> toFuncSpec
let pre' = suffixLogicalVariables lvarSuffix pre
post' = suffixLogicalVariables lvarSuffix post
preparedPre <- prepare nextPc calleeFp =<< storageRemoval pre'
Expand All @@ -474,13 +496,22 @@ mkCallConstraints pc nextPc fp mbPreCheckedFuncWithCallStack f = do
-- However, pre will be checked in a separate 'optimising' module and we can therefore simply
-- assert it holds here. If it does not, the corresponding pre-checking module will fail,
-- thus failing the judgement for the entire function.
assert preparedPre
assert preparedPost
assertWithComment'
preparedPre
(commentAbove ("Pre of: " <> functionName <> " | " <> preDesc) emptyMisc)
assertWithComment'
preparedPost
( commentBelow
("Finished abstracting function: " <> functionName)
(commentAbove ("Post of: " <> functionName <> " | " <> postDesc) emptyMisc)
)
pure Nothing
where
lvarSuffix = "+" <> tShowLabel pc
calleePc = sf_pc f
stackFrame = (pc, calleePc)
functionName = tShow (sf_scopedName f)
describeCall inl = if inl then "Inlining function: " else "Abstracting function: "
-- Determine whether the current function matches the function being optimised exactly -
-- this necessitates comparing execution traces.
isModuleCheckingPre = do
Expand Down Expand Up @@ -510,15 +541,15 @@ mkApConstraints apEnd insts = do
ap2 <- getAp pcNext
fp <- getFp
getApIncrement fp lInst >>= \case
Just apIncrement -> assert (ap1 + apIncrement .== ap2)
Nothing | not isNewStackframe -> assert (ap1 .< ap2)
Just apIncrement -> assertApConstraint (ap1 + apIncrement .== ap2)
Nothing | not isNewStackframe -> assertApConstraint (ap1 .< ap2)
Nothing -> pure ()
lastAp <- getAp lastPc
when (isRet lastInst) pop
fp <- getFp
getApIncrement fp lastLInst >>= \case
Just lastApIncrement -> assert (lastAp + lastApIncrement .== apEnd)
Nothing -> assert (lastAp .< apEnd)
Just lastApIncrement -> assertApConstraint (lastAp + lastApIncrement .== apEnd)
Nothing -> assertApConstraint (lastAp .< apEnd)
where
lastLInst@(lastPc, lastInst) = NonEmpty.last insts

Expand Down
Loading

0 comments on commit 8e8119e

Please sign in to comment.