Skip to content

Commit

Permalink
[refactor] vectorise: take expected mode, make return type explicit (#58
Browse files Browse the repository at this point in the history
)

It's confusing where responsibility for error handling (expected
function, got wrong mode/type) belongs, hence "-- These shouldn't
happen" type comments - so, make that clear, removing those comments :)

This shortens quite a lot of stuff and makes `getThunks` much simpler,
 albeit at the cost of a nasty case/match in `mkMapFun` because
`next` loses the type of what we've passed in. (A problem in common with
other clients of `next`.)
---------

Co-authored-by: Craig Roy <[email protected]>
  • Loading branch information
acl-cqc and croyzor authored Nov 27, 2024
1 parent 7458577 commit cca5a1d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 39 deletions.
72 changes: 36 additions & 36 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Util (log2)

import Control.Monad.Freer (req)
import Data.Bifunctor
import Data.Foldable (foldrM)
import Data.Type.Equality (TestEquality(..), (:~:)(..))
import qualified Data.Map as M
import Prelude hiding (last)
Expand Down Expand Up @@ -247,22 +248,19 @@ getThunks :: Modey m
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Braty (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
-- These shouldn't happen
(_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
v -> typeErr $ "Force called on non-thunk: " ++ show v
getThunks Kerny row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Kerny (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
(_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v
getThunks Braty ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Braty (src, ty)
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Kerny ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Kerny (src,ty)
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Braty ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Some (_ :* stk) -> do
Expand All @@ -274,15 +272,15 @@ getThunks Braty ((src, Left (Star args)):rest) = do
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,Some (Modey :* Flip CTy Z) -- The function type at the end
)
vecLayers (TVec ty (VNum n)) = do
vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,CTy m Z -- The function type at the end
)
vecLayers my (TVec ty (VNum n)) = do
src <- mkStaticNum n
(layers, fun) <- vecLayers ty
pure ((src, n):layers, fun)
vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty))
vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty
first ((src, n):) <$> vecLayers my ty
vecLayers Braty (VFun Braty cty) = pure ([], cty)
vecLayers Kerny (VFun Kerny cty) = pure ([], cty)
vecLayers my ty = typeErr $ "Expected a " ++ showMode my ++ "function or vector of functions, got " ++ show ty

mkStaticNum :: NumVal (VVar Z) -> Checking Src
mkStaticNum n@(NumValue c gro) = do
Expand Down Expand Up @@ -330,27 +328,29 @@ mkStaticNum n@(NumValue c gro) = do
wire (oneSrc, TNat, rhs)
pure src

vectorise :: (Src, Val Z) -> Checking (Src, Val Z)
vectorise (src, ty) = do
(layers, Some (my :* Flip cty)) <- vecLayers ty
modily my $ mkMapFuns (src, VFun my cty) layers
vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z)
vectorise my (src, ty) = do
(layers, cty) <- vecLayers my ty
modily my $ foldrM mkMapFun (src, cty) layers
where
mkMapFuns :: (Src, Val Z) -- The input to the mapfun
-> [(Src, NumVal (VVar Z))] -- Remaining layers
-> Checking (Src, Val Z)
mkMapFuns over [] = pure over
mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do
(valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers
mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply
-> (Src, CTy m Z) -- The input to this level of mapfun
-> Checking (Src, CTy m Z)
mkMapFun (lenSrc, len) (valSrc, cty) = do
let weak1 = changeVar (Thinning (ThDrop ThNull))

Check warning on line 340 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 340 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 340 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’
vecFun <- vectorisedFun len my cty
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <-
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right (VFun my' cty))], _) <-
next "" MapFun (S0, Some (Zy :* S0))
(REx ("len", Nat) (RPr ("value", weak1 ty) R0))
(RPr ("vector", weak1 vecFun) R0)
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
pure (vectorSrc, vecTy)
let vecCTy = case (my,my',cty) of
(Braty,Braty,cty) -> cty
(Kerny,Kerny,cty) -> cty
_ -> error "next returned wrong mode of computation type to that passed in"
pure (vectorSrc, vecCTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
Expand Down
4 changes: 1 addition & 3 deletions brat/test/golden/kernel/kernel_application.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@ Error in test/golden/kernel/kernel_application.brat on line 16:
rotate = { q => maybeRotate(true) }
^^^^^^^^^^^

Expected function to be a (kernel) thunk, but found:
(thunk :: { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) })

Type error: Expected a (kernel) function or vector of functions, got { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }

0 comments on commit cca5a1d

Please sign in to comment.