Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feat/compile_prim
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Oct 15, 2024
2 parents 2b2335e + 7e922da commit 9f9f54d
Show file tree
Hide file tree
Showing 52 changed files with 308 additions and 78 deletions.
80 changes: 74 additions & 6 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ checkInputs tm@(WC fc _) (o:overs) (u:unders) = localFC fc $ do
addRowContext _ as bs (Err fc (TypeMismatch tm _ _))
= Err fc $ TypeMismatch tm (showRow as) (showRow bs)
addRowContext _ _ _ e = e
checkInputs tm [] unders = typeErr $ "No overs but unders: " ++ show unders ++ " for " ++ show tm
checkInputs tm [] unders = typeErr $ "No overs but unders: " ++ showRow unders ++ " for " ++ show tm

checkOutputs :: (CheckConstraints m k, ?my :: Modey m)
=> WC (Term Syn k)
Expand All @@ -160,7 +160,7 @@ checkOutputs tm@(WC fc _) (u:unders) (o:overs) = localFC fc $ do
addRowContext _ as bs (Err fc (TypeMismatch tm _ _))
= Err fc $ TypeMismatch tm (showRow as) (showRow bs)
addRowContext _ _ _ e = e
checkOutputs tm [] overs = typeErr $ "No unders but overs: " ++ show overs ++ " for " ++ show tm
checkOutputs tm [] overs = typeErr $ "No unders but overs: " ++ showRow overs ++ " for " ++ show tm

checkThunk :: (CheckConstraints m UVerb, EvMode m)
=> Modey m
Expand All @@ -171,9 +171,13 @@ checkThunk :: (CheckConstraints m UVerb, EvMode m)
checkThunk m name cty tm = do
((dangling, _), ()) <- let ?my = m in makeBox name cty $
\(thOvers, thUnders) -> do
(((), ()), (emptyOvers, emptyUnders)) <- check tm (thOvers, thUnders)
ensureEmpty "thunk leftovers" emptyOvers
ensureEmpty "thunk leftunders" emptyUnders
(((), ()), leftovers) <- check tm (thOvers, thUnders)
case leftovers of
([], []) -> pure ()
([], unders) -> err (ThunkLeftUnders (showRow unders))
-- If there are leftovers and leftunders, complain about the leftovers
-- Until we can report multiple errors!
(overs, _) -> err (ThunkLeftOvers (showRow overs))
pure dangling

check :: (CheckConstraints m k
Expand Down Expand Up @@ -252,7 +256,7 @@ check' (Lambda c@(WC abstFC abst, body) cs) (overs, unders) = do
solve ?my >>=
(solToEnv . snd)
(((), synthOuts), ((), ())) <- localEnv env $ check body ((), ())
pure synthOuts
pure synthOuts

sig <- mkSig usedOvers synthOuts
patOuts <- checkClauses sig usedOvers ((fst c, WC (fcOf body) (Emb body)) :| cs)
Expand Down Expand Up @@ -485,6 +489,70 @@ check' (Simple tm) ((), ((hungry, ty):unders)) = do
R0 (RPr ("value", vty) R0)
wire (dangling, vty, hungry)
pure (((), ()), ((), unders))
check' FanOut ((p, ty):overs, ()) = do
ty <- eval S0 (binderToValue ?my ty)
case ty of
TVec elTy n
| VNum n <- n
, Just n <- numValIsConstant n ->
if n < 0
then err (InternalError $ "Vector of negative length (" ++ show n ++ ")")
else do
wires <- fanoutNodes ?my n (p, valueToBinder ?my ty) elTy
pure (((), wires), (overs, ()))
| otherwise -> typeErr $ "Can't fanout a Vec with non-constant length: " ++ show n
_ -> typeErr "Fanout ([/\\]) only applies to Vec"
where
fanoutNodes :: Modey m -> Integer -> (Src, BinderType m) -> Val Z -> Checking [(Src, BinderType m)]
fanoutNodes _ 0 _ _ = pure []
fanoutNodes my n (dangling, ty) elTy = do
(_, [(hungry, _)], [danglingHead, danglingTail], _) <- anext "fanoutNodes" (Selector (plain "cons")) (S0, Some (Zy :* S0))
(RPr ("value", binderToValue my ty) R0)
((RPr ("head", elTy) (RPr ("tail", TVec elTy (VNum (nConstant (n - 1)))) R0)) :: Ro m Z Z)
-- Wire the input into the selector node
wire (dangling, binderToValue my ty, hungry)
(danglingHead:) <$> fanoutNodes my (n - 1) danglingTail elTy

check' FanIn (overs, ((tgt, ty):unders)) = do
ty <- eval S0 (binderToValue ?my ty)
case ty of
TVec elTy n
| VNum n <- n
, Just n <- numValIsConstant n ->
if n < 0
then err (InternalError $ "Vector of negative length (" ++ show n ++ ")")
else faninNodes ?my n (tgt, valueToBinder ?my ty) elTy overs >>= \case
Just overs -> pure (((), ()), (overs, unders))
Nothing -> typeErr ("Not enough inputs to make a vector of size " ++ show n)
| otherwise -> typeErr $ "Can't fanout a Vec with non-constant length: " ++ show n
_ -> typeErr "Fanin ([\\/]) only applies to Vec"
where
faninNodes :: Modey m
-> Integer -- The number of things left to pack up
-> (Tgt, BinderType m) -- The place to wire the resulting vector to
-> Val Z -- Element type
-> [(Src, BinderType m)] -- Overs
-> Checking (Maybe [(Src, BinderType m)]) -- Leftovers
faninNodes my 0 (tgt, ty) elTy overs = do
(_, _, [(dangling, _)], _) <- anext "nil" (Constructor (plain "nil")) (S0, Some (Zy :* S0))
(R0 :: Ro m Z Z)
(RPr ("value", TVec elTy (VNum nZero)) R0)
wire (dangling, binderToValue my ty, tgt)
pure (Just overs)
faninNodes _ _ _ _ [] = pure Nothing
faninNodes my n (hungry, ty) elTy ((over, overTy):overs) = do
let k = case my of
Kerny -> Dollar []
Braty -> Star []
typeEq (show FanIn) k elTy (binderToValue my overTy)
let tailTy = TVec elTy (VNum (nConstant (n - 1)))
(_, [(hungryHead, _), (hungryTail, tailTy)], [(danglingResult, _)], _) <- anext "faninNodes" (Constructor (plain "cons")) (S0, Some (Zy :* S0))
((RPr ("head", elTy) (RPr ("tail", tailTy) R0)) :: Ro m Z Z)
(RPr ("value", binderToValue my ty) R0)
wire (over, elTy, hungryHead)
wire (danglingResult, binderToValue ?my ty, hungry)
faninNodes my (n - 1) (hungryTail, tailTy) elTy overs
check' Identity ((this:leftovers), ()) = pure (((), [this]), (leftovers, ()))
check' tm _ = error $ "check' " ++ show tm


Expand Down
3 changes: 3 additions & 0 deletions brat/Brat/Elaborator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ elaborate' (FAnnotation a ts) = do
elaborate' (FInto a b) = elaborate' (FApp b a)
elaborate' (FFn cty) = pure $ SomeRaw' (RFn cty)
elaborate' (FKernel sty) = pure $ SomeRaw' (RKernel sty)
elaborate' FIdentity = pure $ SomeRaw' RIdentity
-- We catch underscores in the top-level elaborate so this case
-- should never be triggered
elaborate' FUnderscore = Left (dumbErr (InternalError "Unexpected '_'"))
elaborate' FFanOut = pure $ SomeRaw' RFanOut
elaborate' FFanIn = pure $ SomeRaw' RFanIn

elabBody :: FBody -> FC -> Either Error (FunBody Raw Noun)
elabBody (FClauses cs) fc = ThunkOf . WC fc . Clauses <$> traverse elab1Clause cs
Expand Down
7 changes: 7 additions & 0 deletions brat/Brat/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ data ErrorMsg
| WrongModeForType String
-- TODO: Add file context here
| CompilingHoles [String]
-- For thunks which don't address enough inputs, or produce enough outputs.
-- The argument is the row of unused connectors
| ThunkLeftOvers String
| ThunkLeftUnders String

instance Show ErrorMsg where
show (TypeErr x) = "Type error: " ++ x
Expand Down Expand Up @@ -164,6 +168,9 @@ instance Show ErrorMsg where
show (CompilingHoles hs) = unlines ("Can't compile file with remaining holes": indent hs)
where
indent = fmap (" " ++)

Check warning on line 170 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’

Check warning on line 170 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’
show (ThunkLeftOvers overs) = "Expected function to address all inputs, but " ++ overs ++ " wasn't used"
show (ThunkLeftUnders unders) = "Expected function to return additional values of type: " ++ unders


data Error = Err { fc :: Maybe FC
, msg :: ErrorMsg
Expand Down
5 changes: 4 additions & 1 deletion brat/Brat/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ eqWorker tm lvkz (TypeFor _ []) (SSum m0 stk0 rs0) (SSum m1 stk1 rs1)
Just rs -> traverse eqVariant rs <&> sequence_
where
eqVariant (Some r0, Some r1) = eqRowTest m0 tm lvkz (stk0,r0) (stk1,r1) <&> dropRight
eqWorker tm _ _ v0 v1 = pure . Left $ TypeMismatch tm (show v0) (show v1)
eqWorker tm _ _ s0 s1 = do
v0 <- quote Zy s0
v1 <- quote Zy s1
pure . Left $ TypeMismatch tm (show v0) (show v1)

-- Type rows have bot0,bot1 dangling de Bruijn indices, which we instantiate with
-- de Bruijn levels. As we go under binders in these rows, we add to the scope's
Expand Down
4 changes: 2 additions & 2 deletions brat/Brat/Lexer/Flat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ space = (many $ (satisfy isSpace >> return ()) <|> comment) >> return ()
comment = string "--" *> ((printChar `manyTill` lookAhead (void newline <|> void eof)) >> return ())

tok :: Lexer Tok
tok = ( try (char '(' $> LParen)
tok = try (char '(' $> LParen)
<|> try (char ')' $> RParen)
<|> try (char '{' $> LBrace)
<|> try (char '}' $> RBrace)
Expand All @@ -62,6 +62,7 @@ tok = ( try (char '(' $> LParen)
<|> try (Number <$> number)
<|> try (string "+" $> Plus)
<|> try (string "/" $> Slash)
<|> try (string "\\" $> Backslash)
<|> try (string "^" $> Caret)
<|> try (string "->") $> Arrow
<|> try (string "=>") $> FatArrow
Expand Down Expand Up @@ -89,7 +90,6 @@ tok = ( try (char '(' $> LParen)
<|> try (K <$> try keyword)
<|> try qualified
<|> Ident <$> ident
)
where
float :: Lexer Double
float = label "float literal" $ do
Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Lexer/Token.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ data Tok
| Plus
| Minus
| Asterisk
| Backslash
| Slash
| Caret
| Hash
Expand Down Expand Up @@ -80,6 +81,7 @@ instance Show Tok where
show Plus = "+"
show Minus = "-"
show Asterisk = "*"
show Backslash = "\\"
show Slash = "/"
show Caret = "^"
show Hash = "#"
Expand Down
6 changes: 6 additions & 0 deletions brat/Brat/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,23 @@ expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr]
Nothing -> unWC expr
Just rest -> FJuxt expr rest

fanout = square (FFanOut <$ match Slash <* match Backslash)
fanin = square (FFanIn <$ match Backslash <* match Slash)

-- Expressions which don't contain juxtaposition or operators
atomExpr :: Parser Flat
atomExpr = simpleExpr <|> round expr
where
simpleExpr = FHole <$> hole
<|> try (FSimple <$> simpleTerm)
<|> try fanout
<|> try fanin
<|> vec
<|> cthunk
<|> try (match DotDot $> FPass)
<|> var
<|> match Underscore $> FUnderscore
<|> match Pipe $> FIdentity


cnoun :: Parser Flat -> Parser (WC (Raw 'Chk 'Noun))
Expand Down
3 changes: 3 additions & 0 deletions brat/Brat/Syntax/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,7 @@ data Flat
| FKernel RawKType
| FUnderscore
| FPass
| FFanOut
| FFanIn
| FIdentity
deriving Show
7 changes: 7 additions & 0 deletions brat/Brat/Syntax/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ data Term :: Dir -> Kind -> Type where
Forget :: WC (Term d KVerb) -> Term d UVerb
Pull :: [PortName] -> WC (Term Chk k) -> Term Chk k
Var :: UserName -> Term Syn Noun -- Look up in noun (value) env
Identity :: Term Syn UVerb
Arith :: ArithOp -> WC (Term Chk Noun) -> WC (Term Chk Noun) -> Term Chk Noun
-- Type annotations (annotating a term with its outputs)
(:::) :: WC (Term Chk Noun) -> [Output] -> Term Syn Noun
Expand All @@ -66,6 +67,8 @@ data Term :: Dir -> Kind -> Type where
C :: CType' (PortName, KindOr (Term Chk Noun)) -> Term Chk Noun
-- Kernel types
K :: CType' (PortName, Term Chk Noun) -> Term Chk Noun
FanOut :: Term Syn UVerb
FanIn :: Term Chk UVerb

deriving instance Eq (Term d k)

Expand Down Expand Up @@ -106,6 +109,7 @@ instance Show (Term d k) where
showList ps = concatMap (++":") ps

show (Var x) = show x
show Identity = "|"
-- Nested applications should be bracketed too, hence 4 instead of 3
show (fun :$: arg) = bracket PApp fun ++ ('(' : show arg ++ ")")
show (tm ::: ty) = bracket PAnn tm ++ " :: " ++ show ty
Expand All @@ -126,6 +130,9 @@ instance Show (Term d k) where

show (C f) = "{" ++ show f ++ "}"
show (K (ss :-> ts)) = "{" ++ showSig ss ++ " -o " ++ showSig ts ++ "}"
show FanOut = "[/\\]"
show FanIn = "[\\/]"


-- Wrap a term in brackets if its `precedence` is looser than `n`
bracket :: Precedence -> WC (Term d k) -> String
Expand Down
9 changes: 9 additions & 0 deletions brat/Brat/Syntax/Raw.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ data Raw :: Dir -> Kind -> Type where
RForget :: WC (Raw d KVerb) -> Raw d UVerb
RPull :: [PortName] -> WC (Raw Chk k) -> Raw Chk k
RVar :: UserName -> Raw Syn Noun
RIdentity :: Raw Syn UVerb
RArith :: ArithOp -> WC (Raw Chk Noun) -> WC (Raw Chk Noun) -> Raw Chk Noun
(:::::) :: WC (Raw Chk Noun) -> [RawIO] -> Raw Syn Noun
(::-::) :: WC (Raw Syn k) -> WC (Raw d UVerb) -> Raw d k -- vertical juxtaposition (diagrammatic composition)
Expand All @@ -80,6 +81,8 @@ data Raw :: Dir -> Kind -> Type where
RFn :: RawCType -> Raw Chk Noun
-- Kernel types
RKernel :: RawKType -> Raw Chk Noun
RFanOut :: Raw Syn UVerb
RFanIn :: Raw Chk UVerb

class Dirable d where
dir :: Raw d k -> Diry d
Expand Down Expand Up @@ -110,6 +113,7 @@ instance Show (Raw d k) where
show (RPull [] x) = "[]:" ++ show x
show (RPull ps x) = concat ((++":") <$> ps) ++ show x
show (RVar x) = show x
show RIdentity = "|"
show (RArith op a b) = "(" ++ show op ++ " " ++ show a ++ " " ++ show b ++ ")"
show (fun ::$:: arg) = show fun ++ ('(' : show arg ++ ")")
show (tm ::::: ty) = show tm ++ " :: " ++ show ty
Expand All @@ -121,6 +125,8 @@ instance Show (Raw d k) where
show (RCon c xs) = "Con(" ++ show c ++ "(" ++ show xs ++ "))"
show (RFn cty) = show cty
show (RKernel cty) = show cty
show RFanOut = "[/\\]"
show RFanIn = "[\\/]"

type Desugar = StateT Namespace (ReaderT (RawEnv, Bwd UserName) (Except Error))

Expand Down Expand Up @@ -223,6 +229,7 @@ instance (Kindable k) => Desugarable (Raw d k) where
desugar' (RForget kv) = Forget <$> desugar kv
desugar' (RPull ps raw) = Pull ps <$> desugar raw
desugar' (RVar name) = pure $ Var name
desugar' RIdentity = pure Identity
desugar' (RArith op a b) = Arith op <$> desugar a <*> desugar b
desugar' (fun ::$:: arg) = (:$:) <$> desugar fun <*> desugar arg
desugar' (tm ::::: outputs) = do
Expand All @@ -235,6 +242,8 @@ instance (Kindable k) => Desugarable (Raw d k) where
desugar' (RCon c arg) = Con c <$> desugar arg
desugar' (RFn cty) = C <$> desugar' cty
desugar' (RKernel cty) = K <$> desugar' cty
desugar' RFanOut = pure FanOut
desugar' RFanIn = pure FanIn

instance Desugarable ty => Desugarable (PortName, ty) where
type Desugared (PortName, ty) = (PortName, Desugared ty)
Expand Down
8 changes: 8 additions & 0 deletions brat/Brat/Syntax/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,11 @@ copyable (VApp _ _) = Just False
copyable (TVec elem _) = copyable elem
copyable TBit = Just True
copyable _ = Nothing

stkLen :: Stack Z t tot -> Ny tot
stkLen S0 = Zy
stkLen (zx :<< _) = Sy (stkLen zx)

numValIsConstant :: NumVal (VVar Z) -> Maybe Integer
numValIsConstant (NumValue up Constant0) = pure up
numValIsConstant _ = Nothing
6 changes: 6 additions & 0 deletions brat/Brat/Unelaborator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ unelab dy _ (Lambda (abs,rhs) cs) = FLambda ((abs, unelab dy Nouny <$> rhs) :| (
unelab _ _ (Con c args) = FCon c (unelab Chky Nouny <$> args)
unelab _ _ (C (ss :-> ts)) = FFn (toRawRo ss :-> toRawRo ts)
unelab _ _ (K cty) = FKernel $ fmap (\(p, ty) -> Named p (toRaw ty)) cty
unelab _ _ Identity = FIdentity
unelab _ _ FanIn = FFanIn
unelab _ _ FanOut = FFanOut

-- This is needed for concrete terms which embed a type as a list of `Raw` things
toRaw :: Term d k -> Raw d k
Expand All @@ -61,6 +64,9 @@ toRaw (Lambda (abs,rhs) cs) = RLambda (abs, toRaw <$> rhs) (second (fmap toRaw)
toRaw (Con c args) = RCon c (toRaw <$> args)
toRaw (C (ss :-> ts)) = RFn (toRawRo ss :-> toRawRo ts)
toRaw (K cty) = RKernel $ (\(p, ty) -> Named p (toRaw ty)) <$> cty
toRaw Identity = RIdentity
toRaw FanIn = RFanIn
toRaw FanOut = RFanOut

toRawRo :: [(PortName, KindOr (Term Chk Noun))] -> [TypeRowElem (KindOr RawVType)]
toRawRo = fmap (\(p, bty) -> Named p (second toRaw bty))
9 changes: 3 additions & 6 deletions brat/examples/adder.brat
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ and(Bool, Bool) -> Bool
and(true, b) = b
and(false, b) = false

id(Bool) -> Bool
id(b) = b

halfAdder(Bool, Bool) -> twos :: Bool, ones :: Bool
halfAdder(a, b) = and(a,b), xor(a,b)

fullAdder(Bool, Bool, Bool) -> twos :: Bool, ones :: Bool
fullAdder = halfAdder, id;
id, halfAdder;
xor, id
fullAdder = halfAdder, |;
|, halfAdder;
xor, |

adder(n :: #, Vec(Bool, n), Vec(Bool, n), carryIn :: Bool) -> carryOut :: Bool, Vec(Bool, n)
adder(0, [], [], b) = b, []
Expand Down
4 changes: 2 additions & 2 deletions brat/examples/bell.brat
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ idq = { q => q }

phasedBell :: { (th :: Float) -> { c :: Qubit, d :: Qubit -o c :: Bit, d :: Bit } }
phasedBell(th) = {
H,qid;
Rz(th),qid;
H,|;
Rz(th),|;
M,M
}
Loading

0 comments on commit 9f9f54d

Please sign in to comment.