Skip to content

WIP: add certificate path to CompiledCode #7103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions plutus-benchmark/plutus-benchmark.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,11 @@ library script-contexts-internal
PlutusBenchmark.V3.ScriptContexts

build-depends:
, base >=4.9 && <5
, plutus-ledger-api ^>=1.46
, plutus-tx ^>=1.46
, plutus-tx-plugin ^>=1.46
, base >=4.9 && <5
, plutus-ledger-api ^>=1.46
, plutus-tx ^>=1.46
, plutus-tx-plugin ^>=1.46
, plutus-tx-test-util

test-suite plutus-benchmark-script-contexts-tests
import: lang, ghc-version-support, os-support
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fplugin-opt PlutusTx.Plugin:certify=ScriptContextCert #-}

module PlutusBenchmark.V2.Data.ScriptContexts where

Expand All @@ -20,6 +21,8 @@ import PlutusTx.Data.List qualified as DataList
import PlutusTx.Plugin ()
import PlutusTx.Prelude qualified as PlutusTx

import PlutusTx.Test.Util.Compiled (compiledCodeToCertPath)

-- | A very crude deterministic generator for 'ScriptContext's with size
-- approximately proportional to the input integer.
mkScriptContext :: Integer -> ScriptContext
Expand Down Expand Up @@ -276,10 +279,14 @@ forwardWithStakeTrickManual r_stake_cred r_ctx =
mkForwardWithStakeTrickManualCode
:: StakingCredential
-> ScriptContext
-> PlutusTx.CompiledCode ()
-> (Maybe FilePath, PlutusTx.CompiledCode ())
mkForwardWithStakeTrickManualCode cred ctx =
let c = PlutusTx.toBuiltinData cred
sc = PlutusTx.toBuiltinData ctx
in $$(PlutusTx.compile [|| forwardWithStakeTrickManual ||])
code = $$(PlutusTx.compile [|| forwardWithStakeTrickManual ||])
in
( compiledCodeToCertPath code
, code
`PlutusTx.unsafeApplyCode` PlutusTx.liftCodeDef c
`PlutusTx.unsafeApplyCode` PlutusTx.liftCodeDef sc
)
4 changes: 3 additions & 1 deletion plutus-benchmark/script-contexts/test/V2/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -193,20 +193,22 @@ testDataFwdStakeTrick =

testDataFwdStakeTrickManual :: TestTree
testDataFwdStakeTrickManual =
-- testGroup "testing" $
runTestGhcSOP
[ Tx.goldenPirReadable "dataFwdStakeTrickManual" testAbsCode
, Tx.goldenUPlcReadable "dataFwdStakeTrickManual" testAbsCode
, Tx.goldenBudget "dataFwdStakeTrickManual" testCode
, Tx.goldenEvalCekCatch "dataFwdStakeTrickManual" [testCode]
]
-- : [testCase "testCert" $ ]
where
testCredential =
Data.SC.mkStakingCredential "someCredential"
testScriptContext =
Data.SC.mkScriptContextWithStake 20 20 (Just (testCredential, 1))
testAbsCode =
$$(PlutusTx.compile [|| Data.SC.forwardWithStakeTrickManual ||])
testCode =
(mcert, testCode) =
Data.SC.mkForwardWithStakeTrickManualCode testCredential testScriptContext

allTests :: TestTree
Expand Down
1 change: 1 addition & 0 deletions plutus-metatheory/src/Certifier.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Certifier (
, prettyCertifierError
, prettyCertifierSuccess
, CertifierError (..)
, CertifierSuccess (..)
) where

import Control.Monad ((>=>))
Expand Down
63 changes: 62 additions & 1 deletion plutus-tx-plugin/src/PlutusTx/Compiler/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer

import Flat (Flat (..))

import Data.Map qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set
Expand Down Expand Up @@ -191,14 +193,73 @@ stableModuleCmp m1 m2 =
-- See Note [Stable name comparisons]
(GHC.moduleUnit m1 `GHC.stableUnitCmp` GHC.moduleUnit m2)

newtype CertificatePath = CertificatePath
{ getCertPath :: Maybe FilePath
}

instance Flat CertificatePath where
encode (CertificatePath mp) = encode mp
decode = CertificatePath <$> decode
size (CertificatePath mp) = size mp

instance Semigroup CertificatePath where
CertificatePath p1 <> CertificatePath p2 =
case (p1, p2) of
(Nothing, Nothing) -> CertificatePath Nothing
(Nothing, Just p) -> CertificatePath (Just p)
(Just p, Nothing) -> CertificatePath (Just p)
-- Overwrite the old path with the new path
(Just _, Just p) -> CertificatePath (Just p)

instance Monoid CertificatePath where
mempty = CertificatePath Nothing

data CompileOutput = CompileOutput
{ coCoverageIndex :: CoverageIndex
, coCertPath :: CertificatePath
}

instance Semigroup CompileOutput where
CompileOutput i1 c1 <> CompileOutput i2 c2 =
CompileOutput (i1 <> i2) (c1 <> c2)

instance Monoid CompileOutput where
mempty = CompileOutput mempty mempty

instance Flat CompileOutput where
encode (CompileOutput i c) = encode i <> encode c
decode = CompileOutput <$> decode <*> decode
size (CompileOutput i c) x = size i x + size c x

-- | Include a location coverage annotation in the index
addLocationToCoverageIndex :: MonadWriter CompileOutput m => CovLoc -> m CoverageAnnotation
addLocationToCoverageIndex src = do
let ann = CoverLocation src
tell $ CompileOutput (CoverageIndex $ Map.singleton ann mempty) mempty
pure ann

-- | Include a boolean coverage annotation in the index
addBoolCaseToCoverageIndex :: MonadWriter CompileOutput m
=> CovLoc -> Bool -> CoverageMetadata -> m CoverageAnnotation
addBoolCaseToCoverageIndex src b meta = do
let ann = CoverBool src b
tell $ CompileOutput (CoverageIndex (Map.singleton ann meta)) mempty
pure ann

addCertificatePath :: MonadWriter CompileOutput m => FilePath -> m ()
addCertificatePath path = do
let certPath = CertificatePath (Just path)
tell $ CompileOutput mempty certPath

-- See Note [Scopes]
type Compiling uni fun m ann =
( MonadError (CompileError uni fun ann) m
, MonadQuote m
, MonadReader (CompileContext uni fun) m
, MonadState CompileState m
, MonadDefs LexName uni fun Ann m
, MonadWriter CoverageIndex m
-- TODO: fix
, MonadWriter CompileOutput m
)

-- Packing up equality constraints gives us a nice way of writing type signatures as this way
Expand Down
36 changes: 20 additions & 16 deletions plutus-tx-plugin/src/PlutusTx/Plugin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import PlutusTx.Compiler.Error
import PlutusTx.Compiler.Expr
import PlutusTx.Compiler.Trace
import PlutusTx.Compiler.Types
import PlutusTx.Coverage
import PlutusTx.Function qualified
import PlutusTx.Optimize.Inline qualified
import PlutusTx.PIRTypes
Expand Down Expand Up @@ -419,14 +418,15 @@ compileMarkedExpr locStr codeTy origE = do
-- See Note [Occurrence analysis]
let origE' = GHC.occurAnalyseExpr origE

((pirP,uplcP), covIdx) <- runWriterT . runQuoteT . flip runReaderT ctx . flip evalStateT st $
((pirP,uplcP), compOut) <- runWriterT . runQuoteT . flip runReaderT ctx . flip evalStateT st $
traceCompilation 1 ("Compiling expr at" GHC.<+> GHC.text locStr) $
runCompiler moduleNameStr opts origE'

-- serialize the PIR, PLC, and coverageindex outputs into a bytestring.
bsPir <- makeByteStringLiteral $ flat pirP
bsPlc <- makeByteStringLiteral $ flat (UPLC.UnrestrictedProgram uplcP)
covIdxFlat <- makeByteStringLiteral $ flat covIdx
compOutFlat <- makeByteStringLiteral $ flat compOut


builder <- lift . lift . GHC.lookupId =<< thNameToGhcNameOrFail 'mkCompiledCode

Expand All @@ -436,7 +436,7 @@ compileMarkedExpr locStr codeTy origE = do
`GHC.App` GHC.Type codeTy
`GHC.App` bsPlc
`GHC.App` bsPir
`GHC.App` covIdxFlat
`GHC.App` compOutFlat

-- | The GHC.Core to PIR to PLC compiler pipeline. Returns both the PIR and PLC output.
-- It invokes the whole compiler chain: Core expr -> PIR expr -> PLC expr -> UPLC expr.
Expand All @@ -446,7 +446,7 @@ runCompiler ::
, fun ~ PLC.DefaultFun
, MonadReader (CompileContext uni fun) m
, MonadState CompileState m
, MonadWriter CoverageIndex m
, MonadWriter CompileOutput m
, MonadQuote m
, MonadError (CompileError uni fun Ann) m
, MonadIO m
Expand Down Expand Up @@ -560,15 +560,19 @@ runCompiler moduleName opts expr = do

let optCertify = opts ^. posCertify
(uplcP, simplTrace) <- flip runReaderT plcOpts $ PLC.compileProgramWithTrace plcP
liftIO $ case optCertify of
Just certName -> do
result <- runCertifier $ mkCertifier simplTrace certName
case result of
Right certSuccess ->
hPutStrLn stderr $ prettyCertifierSuccess certSuccess
Left err ->
hPutStrLn stderr $ prettyCertifierError err
Nothing -> pure ()
certP <-
liftIO $ case optCertify of
Just certName -> do
result <- runCertifier $ mkCertifier simplTrace certName
case result of
Right certSuccess -> do
hPutStrLn stderr $ prettyCertifierSuccess certSuccess
pure $ Just (certDir certSuccess)
Left err -> do
hPutStrLn stderr $ prettyCertifierError err
pure Nothing
Nothing -> pure Nothing
maybe (pure ()) addCertificatePath certP
dbP <- liftExcept $ traverseOf UPLC.progTerm UPLC.deBruijnTerm uplcP
when (opts ^. posDumpUPlc) . liftIO $
dumpFlat
Expand Down Expand Up @@ -643,8 +647,8 @@ stripTicks = \case
e -> e

-- | Helper to avoid doing too much construction of Core ourselves
mkCompiledCode :: forall a . BS.ByteString -> BS.ByteString -> BS.ByteString -> CompiledCode a
mkCompiledCode plcBS pirBS ci = SerializedCode plcBS (Just pirBS) (fold . unflat $ ci)
mkCompiledCode :: forall a . BS.ByteString -> BS.ByteString -> BS.ByteString -> Maybe CertPath -> CompiledCode a
mkCompiledCode plcBS pirBS ci mcp = SerializedCode plcBS (Just pirBS) (fold . unflat $ ci) mcp

-- | Make a 'NameInfo' mapping the given set of TH names to their
-- 'GHC.TyThing's for later reference.
Expand Down
6 changes: 6 additions & 0 deletions plutus-tx-test-util/testlib/PlutusTx/Test/Util/Compiled.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module PlutusTx.Test.Util.Compiled
, toAnonDeBruijnProg
, toNamedDeBruijnTerm
, compiledCodeToTerm
, compiledCodeToCertPath
, haskellValueToTerm
, unsafeRunTermCek
, runTermCek
Expand Down Expand Up @@ -56,6 +57,11 @@ compiledCodeToTerm
:: Tx.CompiledCodeIn DefaultUni DefaultFun a -> Term
compiledCodeToTerm (Tx.getPlcNoAnn -> UPLC.Program _ _ body) = body

{- | Extract the path to the generated certificate, if one exists. -}
compiledCodeToCertPath
:: Tx.CompiledCodeIn DefaultUni DefaultFun a -> Maybe FilePath
compiledCodeToCertPath (Tx.getCertPath -> mpath) = mpath

{- | Lift a Haskell value to a PLC term. The constraints get a bit out of control
if we try to do this over an arbitrary universe.-}
haskellValueToTerm
Expand Down
9 changes: 6 additions & 3 deletions plutus-tx/src/PlutusTx.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ module PlutusTx (
makeLift,
safeLiftCode,
liftCode,
liftCodeDef) where
liftCodeDef,
getCovIdx,
getCertPath,
) where

import PlutusCore.Data (Data (..))
import PlutusTx.Blueprint.TH (makeIsDataSchemaIndexed)
import PlutusTx.Builtins (BuiltinData, builtinDataToData, dataToBuiltinData)
import PlutusTx.Code (CompiledCode, CompiledCodeIn, applyCode, getPir, getPirNoAnn, getPlc,
getPlcNoAnn, unsafeApplyCode)
import PlutusTx.Code (CompiledCode, CompiledCodeIn, applyCode, getCertPath, getCovIdx, getPir,
getPirNoAnn, getPlc, getPlcNoAnn, unsafeApplyCode)
import PlutusTx.IsData (FromData (..), ToData (..), UnsafeFromData (..), fromData,
makeIsDataIndexed, toData, unstableMakeIsData)
import PlutusTx.Lift (liftCode, liftCodeDef, makeLift, safeLiftCode)
Expand Down
27 changes: 19 additions & 8 deletions plutus-tx/src/PlutusTx/Code.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,20 @@ type role CompiledCodeIn representational representational nominal
-- if you want to put it on the chain you must normalize the types first.
data CompiledCodeIn uni fun a =
-- | Serialized UPLC code and possibly serialized PIR code with metadata used for program coverage.
SerializedCode BS.ByteString (Maybe BS.ByteString) CoverageIndex
SerializedCode BS.ByteString (Maybe BS.ByteString) CoverageIndex (Maybe CertPath)
-- | Deserialized UPLC program, and possibly deserialized PIR program with metadata used for program coverage.
| DeserializedCode
(UPLC.Program UPLC.NamedDeBruijn uni fun SrcSpans)
(Maybe (PIR.Program PLC.TyName PLC.Name uni fun SrcSpans))
CoverageIndex
(Maybe CertPath)

-- | 'CompiledCodeIn' instantiated with default built-in types and functions.
type CompiledCode = CompiledCodeIn PLC.DefaultUni PLC.DefaultFun

-- | Type alias for the path to the certified compilation certificate, if one exists.
type CertPath = FilePath

-- | Apply a compiled function to a compiled argument. Will fail if the versions don't match.
applyCode
:: (PLC.Closed uni
Expand Down Expand Up @@ -87,7 +91,9 @@ applyCode fun arg = do
<> display argPir
(Nothing, Nothing) -> Left "Missing PIR for both the function program and the argument."

pure $ DeserializedCode uplc pir (getCovIdx fun <> getCovIdx arg)
-- I don't think it makes sense to compose certificates, so we just
-- return Nothing here.
pure $ DeserializedCode uplc pir (getCovIdx fun <> getCovIdx arg) Nothing

-- | Apply a compiled function to a compiled argument. Will throw if the versions don't match,
-- should only be used in non-production code.
Expand Down Expand Up @@ -122,10 +128,10 @@ getPlc
:: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
=> CompiledCodeIn uni fun a -> UPLC.Program UPLC.NamedDeBruijn uni fun SrcSpans
getPlc wrapper = case wrapper of
SerializedCode plc _ _ -> case unflat (BSL.fromStrict plc) of
SerializedCode plc _ _ _ -> case unflat (BSL.fromStrict plc) of
Left e -> throw $ ImpossibleDeserialisationFailure e
Right (UPLC.UnrestrictedProgram p) -> p
DeserializedCode plc _ _ -> plc
DeserializedCode plc _ _ _ -> plc

getPlcNoAnn
:: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
Expand All @@ -137,12 +143,12 @@ getPir
:: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
=> CompiledCodeIn uni fun a -> Maybe (PIR.Program PIR.TyName PIR.Name uni fun SrcSpans)
getPir wrapper = case wrapper of
SerializedCode _ pir _ -> case pir of
SerializedCode _ pir _ _ -> case pir of
Just bs -> case unflat (BSL.fromStrict bs) of
Left e -> throw $ ImpossibleDeserialisationFailure e
Right p -> Just p
Nothing -> Nothing
DeserializedCode _ pir _ -> pir
DeserializedCode _ pir _ _ -> pir

getPirNoAnn
:: (PLC.Closed uni, uni `PLC.Everywhere` Flat, Flat fun)
Expand All @@ -151,5 +157,10 @@ getPirNoAnn = fmap void . getPir

getCovIdx :: CompiledCodeIn uni fun a -> CoverageIndex
getCovIdx wrapper = case wrapper of
SerializedCode _ _ idx -> idx
DeserializedCode _ _ idx -> idx
SerializedCode _ _ idx _ -> idx
DeserializedCode _ _ idx _ -> idx

getCertPath :: CompiledCodeIn uni fun a -> Maybe CertPath
getCertPath wrapper = case wrapper of
SerializedCode _ _ _ certPath -> certPath
DeserializedCode _ _ _ certPath -> certPath
Loading