Skip to content

Commit

Permalink
refactor: Move Nat building code to Helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Dec 3, 2024
1 parent 991e0dc commit f82f25c
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 99 deletions.
93 changes: 93 additions & 0 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,96 @@ buildConst :: SimpleTerm -> Val Z -> Checking Src
buildConst tm ty = do
(_, _, [(out,_)], _) <- next "const" (Const tm) (S0, Some (Zy :* S0)) R0 (RPr ("value", ty) R0)
pure out

buildNum :: Integer -> Checking Src
buildNum n = buildConst (Num (fromIntegral n)) TNat

-- Generate wiring to produce a dynamic instance of the numval argument
-- N.B. In these functions, we wire using Req, rather than the `wire` function
-- because we don't want it to do any extra evaluation.
buildNatVal :: NumVal (VVar Z) -> Checking Src
buildNatVal nv@(NumValue n gro) = case n of
0 -> buildGro gro
n -> do
nDangling <- buildNum n
((lhs,rhs),out) <- buildArithOp Add
src <- buildGro gro
req $ Wire (end nDangling, TNat, end lhs)
req $ Wire (end src, TNat, end rhs)
defineSrc out (VNum (nPlus n (nVar (VPar (toEnd src)))))
pure out
where
buildGro :: Fun00 (VVar Z) -> Checking Src
buildGro Constant0 = buildNum 0
buildGro (StrictMonoFun sm) = buildSM sm

buildSM :: StrictMono (VVar Z) -> Checking Src
buildSM (StrictMono k mono) = do
-- Calculate 2^k as `factor`
two <- buildNum 2
kDangling <- buildNum k
((lhs,rhs),factor) <- buildArithOp Pow
req $ Wire (end two, TNat, end lhs)
req $ Wire (end kDangling, TNat, end rhs)
-- Multiply mono by 2^k
((lhs,rhs),out) <- buildArithOp Mul
monoDangling <- buildMono mono
req $ Wire (end factor, TNat, end lhs)
req $ Wire (end monoDangling, TNat, end rhs)
defineSrc out (VNum (n2PowTimes k (nVar (VPar (toEnd monoDangling)))))
pure out

buildMono :: Monotone (VVar Z) -> Checking Src
buildMono (Linear (VPar (ExEnd e))) = pure $ NamedPort e "numval"
buildMono (Full sm) = do
-- Calculate 2^n as `outPlus1`
two <- buildNum 2
dangling <- buildSM sm
((lhs,rhs),outPlus1) <- buildArithOp Pow
req $ Wire (end two, TNat, end lhs)
req $ Wire (end dangling, TNat, end rhs)
-- Then subtract 1
one <- buildNum 1
((lhs,rhs),out) <- buildArithOp Sub
req $ Wire (end outPlus1, TNat, end lhs)
req $ Wire (end one, TNat, end rhs)
defineSrc out (VNum (nFull (nVar (VPar (toEnd dangling)))))
pure out
buildMono _ = err . InternalError $ "Trying to build a non-closed nat value: " ++ show nv

invertNatVal :: NumVal (VVar Z) -> Checking Tgt
invertNatVal (NumValue up gro) = case up of
0 -> invertGro gro
_ -> do
((lhs,rhs),out) <- buildArithOp Sub
upSrc <- buildNum up
req $ Wire (end upSrc, TNat, end rhs)
tgt <- invertGro gro
req $ Wire (end out, TNat, end tgt)
defineTgt tgt (VNum (nVar (VPar (toEnd out))))
defineTgt lhs (VNum (nPlus up (nVar (VPar (toEnd tgt)))))
pure lhs
where
invertGro Constant0 = error "Invariant violated: the numval arg to invertNatVal should contain a variable"
invertGro (StrictMonoFun sm) = invertSM sm

invertSM (StrictMono k mono) = case k of
0 -> invertMono mono
_ -> do
divisor <- buildNum (2 ^ k)
((lhs,rhs),out) <- buildArithOp Div
tgt <- invertMono mono
req $ Wire (end out, TNat, end tgt)
req $ Wire (end divisor, TNat, end rhs)
defineTgt tgt (VNum (nVar (VPar (toEnd out))))
defineTgt lhs (VNum (n2PowTimes k (nVar (VPar (toEnd tgt)))))
pure lhs

invertMono (Linear (VPar (InEnd e))) = pure (NamedPort e "numval")
invertMono (Full sm) = do
(_, [(llufTgt,_)], [(llufSrc,_)], _) <- next "luff" (Prim ("BRAT","lluf")) (S0, Some (Zy :* S0)) (REx ("n", Nat) R0) (REx ("n", Nat) R0)
tgt <- invertSM sm
req $ Wire (end llufSrc, TNat, end tgt)
defineTgt tgt (VNum (nVar (VPar (toEnd llufSrc))))
defineTgt llufTgt (VNum (nFull (nVar (VPar (toEnd tgt)))))
pure llufTgt
99 changes: 1 addition & 98 deletions brat/Brat/Checker/SolveHoles.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ module Brat.Checker.SolveHoles (typeEq, buildNatVal, buildNum, invertNatVal) whe

import Brat.Checker.Monad
import Brat.Checker.Types (kindForMode)
import Brat.Checker.Helpers (buildArithOp, buildConst, defineSrc, defineTgt, next)
import Brat.Checker.Helpers (buildConst, buildNatVal, buildNum, invertNatVal)
import Brat.Error (ErrorMsg(..))
import Brat.Eval
import Brat.Graph (NodeType(..))
import Brat.Syntax.Common
import Brat.Syntax.Port (ToEnd(..))
import Brat.Syntax.Simple (SimpleTerm(..))
import Brat.Syntax.Value
import Control.Monad.Freer
Expand Down Expand Up @@ -154,98 +152,3 @@ typeEqRigid tm lvkz (TypeFor _ []) (VSum m0 rs0) (VSum m1 rs1)
where
eqVariant (Some r0, Some r1) = throwLeft (snd <$> typeEqRow m0 tm lvkz r0 r1)
typeEqRigid tm _ _ v0 v1 = err $ TypeMismatch tm (show v0) (show v1)

wire :: (Src, Val Z, Tgt) -> Checking ()
wire (src, ty, tgt) = req $ Wire (end src, ty, end tgt)

buildNum :: Integer -> Checking Src
buildNum n = buildConst (Num (fromIntegral n)) TNat


-- Generate wiring to produce a dynamic instance of the numval argument
buildNatVal :: NumVal (VVar Z) -> Checking Src
buildNatVal nv@(NumValue n gro) = case n of
0 -> buildGro gro
n -> do
nDangling <- buildNum n
((lhs,rhs),out) <- buildArithOp Add
src <- buildGro gro
wire (nDangling, TNat, lhs)
wire (src, TNat, rhs)
defineSrc out (VNum (nPlus n (nVar (VPar (toEnd src)))))
pure out
where
buildGro :: Fun00 (VVar Z) -> Checking Src
buildGro Constant0 = buildNum 0
buildGro (StrictMonoFun sm) = buildSM sm

buildSM :: StrictMono (VVar Z) -> Checking Src
buildSM (StrictMono k mono) = do
-- Calculate 2^k as `factor`
two <- buildNum 2
kDangling <- buildNum k
((lhs,rhs),factor) <- buildArithOp Pow
wire (two, TNat, lhs)
wire (kDangling, TNat, rhs)
-- Multiply mono by 2^k
((lhs,rhs),out) <- buildArithOp Mul
monoDangling <- buildMono mono
wire (factor, TNat, lhs)
wire (monoDangling, TNat, rhs)
defineSrc out (VNum (n2PowTimes k (nVar (VPar (toEnd monoDangling)))))
pure out

buildMono :: Monotone (VVar Z) -> Checking Src
buildMono (Linear (VPar (ExEnd e))) = pure $ NamedPort e "numval"
buildMono (Full sm) = do
-- Calculate 2^n as `outPlus1`
two <- buildNum 2
dangling <- buildSM sm
((lhs,rhs),outPlus1) <- buildArithOp Pow
wire (two, TNat, lhs)
wire (dangling, TNat, rhs)
-- Then subtract 1
one <- buildNum 1
((lhs,rhs),out) <- buildArithOp Sub
wire (outPlus1, TNat, lhs)
wire (one, TNat, rhs)
defineSrc out (VNum (nFull (nVar (VPar (toEnd dangling)))))
pure out
buildMono _ = err . InternalError $ "Trying to build a non-closed nat value: " ++ show nv

invertNatVal :: NumVal (VVar Z) -> Checking Tgt
invertNatVal (NumValue up gro) = case up of
0 -> invertGro gro
_ -> do
((lhs,rhs),out) <- buildArithOp Sub
upSrc <- buildNum up
wire (upSrc, TNat, rhs)
tgt <- invertGro gro
wire (out, TNat, tgt)
defineTgt tgt (VNum (nVar (VPar (toEnd out))))
defineTgt lhs (VNum (nPlus up (nVar (VPar (toEnd tgt)))))
pure lhs
where
invertGro Constant0 = error "Invariant violated: the numval arg to invertNatVal should contain a variable"
invertGro (StrictMonoFun sm) = invertSM sm

invertSM (StrictMono k mono) = case k of
0 -> invertMono mono
_ -> do
divisor <- buildNum (2 ^ k)
((lhs,rhs),out) <- buildArithOp Div
tgt <- invertMono mono
wire (out, TNat, tgt)
wire (divisor, TNat, rhs)
defineTgt tgt (VNum (nVar (VPar (toEnd out))))
defineTgt lhs (VNum (n2PowTimes k (nVar (VPar (toEnd tgt)))))
pure lhs

invertMono (Linear (VPar (InEnd e))) = pure (NamedPort e "numval")
invertMono (Full sm) = do
(_, [(llufTgt,_)], [(llufSrc,_)], _) <- next "luff" (Prim ("BRAT","lluf")) (S0, Some (Zy :* S0)) (REx ("n", Nat) R0) (REx ("n", Nat) R0)
tgt <- invertSM sm
wire (llufSrc, TNat, tgt)
defineTgt tgt (VNum (nVar (VPar (toEnd llufSrc))))
defineTgt llufTgt (VNum (nFull (nVar (VPar (toEnd tgt)))))
pure llufTgt
1 change: 0 additions & 1 deletion brat/Brat/Checker/SolvePatterns.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Brat.Checker.SolvePatterns (argProblems, argProblemsWithLeftovers, solve)

import Brat.Checker.Monad
import Brat.Checker.Helpers
import Brat.Checker.SolveHoles (buildNatVal, invertNatVal)
import Brat.Checker.Types (EndType(..))
import Brat.Constructors
import Brat.Constructors.Patterns
Expand Down

0 comments on commit f82f25c

Please sign in to comment.