Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Sep 8, 2023
1 parent 1487b76 commit f955da4
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 86 deletions.
63 changes: 22 additions & 41 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -172,32 +172,29 @@ aDef (CDef name params optRhs optGivens body) = do
resultTy' <- expr resultTy
return (expl, Just effs, Just resultTy')
implicitParams <- aOptGivens optGivens
let allParams = fmapNest asOptAnn implicitParams >>> explicitParams
let allParams = implicitParams >>> explicitParams
body' <- block body
return (name, ULamExpr allParams expl effs resultTy body')

asOptAnn :: UAnnBinder AnnRequired n l -> UAnnBinder AnnOptional n l
asOptAnn (UAnnBinder expl b (UAnn ann)) = UAnnBinder expl b (UAnn ann)

stripParens :: Group -> Group
stripParens (WithSrc _ (CParens [g])) = stripParens g
stripParens g = g

-- === combinators for different sorts of binder lists ===

aOptGivens :: Maybe GivenClause -> SyntaxM (Nest UReqAnnBinder VoidS VoidS)
aOptGivens :: Maybe GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS)
aOptGivens optGivens = fromMaybeM optGivens Empty aGivens

binderList
:: [Group] -> (Group -> SyntaxM (Nest (UAnnBinder req) VoidS VoidS))
-> SyntaxM (Nest (UAnnBinder req) VoidS VoidS)
:: [Group] -> (Group -> SyntaxM (Nest UAnnBinder VoidS VoidS))
-> SyntaxM (Nest UAnnBinder VoidS VoidS)
binderList gs cont = concatNests <$> forM gs \case
WithSrc _ (CGivens gs') -> aGivens gs'
g -> cont g

withTrailingConstraints
:: Group -> (Group -> SyntaxM ((UAnnBinder req) VoidS VoidS))
-> SyntaxM (Nest (UAnnBinder req) VoidS VoidS)
:: Group -> (Group -> SyntaxM (UAnnBinder VoidS VoidS))
-> SyntaxM (Nest UAnnBinder VoidS VoidS)
withTrailingConstraints g cont = case g of
Binary Pipe lhs c -> do
bs@(Nest (UAnnBinder _ b _) _) <- withTrailingConstraints lhs cont
Expand All @@ -207,57 +204,44 @@ withTrailingConstraints g cont = case g of
UBind _ _ _ -> error "Shouldn't have internal names until renaming pass"
c' <- expr c
let v = WithSrcE ctx $ UVar (SourceName ctx s)
return $ bs >>> UnaryNest (fromReqAnnBinder $ asConstraintBinder v c')
return $ bs >>> UnaryNest (asConstraintBinder v c')
_ -> UnaryNest <$> cont g
where
asConstraintBinder :: UExpr VoidS -> UConstraint VoidS -> UReqAnnBinder VoidS VoidS
asConstraintBinder :: UExpr VoidS -> UConstraint VoidS -> UAnnBinder VoidS VoidS
asConstraintBinder v c = do
let t = ns $ UApp c [v] []
UAnnBinder (Inferred Nothing (Synth Full)) UIgnore (UAnn t)

defaultAnn :: UExpr VoidS -> UOptAnnBinder VoidS VoidS -> UReqAnnBinder VoidS VoidS
defaultAnn dAnn (UAnnBinder expl b ann) = do
let ann' = case ann of
UNoAnn -> dAnn
UAnn t -> t
UAnnBinder expl b (UAnn ann')

aGivens :: GivenClause -> SyntaxM (Nest (UAnnBinder req) VoidS VoidS)
aGivens :: GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS)
aGivens (implicits, optConstraints) = do
implicits' <- concatNests <$> forM implicits \b -> withTrailingConstraints b implicitArgBinder
constraints <- fromMaybeM optConstraints Empty (\gs -> toNest <$> mapM synthBinder gs)
return $ fmapNest fromReqAnnBinder $ implicits' >>> constraints
return $ implicits' >>> constraints

synthBinder :: Group -> SyntaxM (UReqAnnBinder VoidS VoidS)
synthBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS)
synthBinder g = tyOptBinder (Inferred Nothing (Synth Full)) g

fromReqAnnBinder :: UReqAnnBinder n l -> UAnnBinder req n l
fromReqAnnBinder (UAnnBinder expl b ann) = do
let ann' = case ann of
UAnn t -> UAnn t
UAnnBinder expl b ann'

concatNests :: [Nest b VoidS VoidS] -> Nest b VoidS VoidS
concatNests [] = Empty
concatNests (b:bs) = b >>> concatNests bs

implicitArgBinder :: Group -> SyntaxM (UReqAnnBinder VoidS VoidS)
implicitArgBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS)
implicitArgBinder g = do
UAnnBinder _ b ann <- defaultAnn (ns tyKind) <$> binderOptTy (Inferred Nothing Unify) g
UAnnBinder _ b ann <- binderOptTy (Inferred Nothing Unify) g
s <- case b of
UBindSource _ s -> return $ Just s
_ -> return Nothing
return $ UAnnBinder (Inferred s Unify) b ann

aExplicitParams :: ExplicitParams -> SyntaxM (Nest UReqAnnBinder VoidS VoidS)
aExplicitParams :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS)
aExplicitParams bs = binderList bs \b -> withTrailingConstraints b \b' ->
defaultAnn (ns tyKind) <$> binderOptTy Explicit b'
binderOptTy Explicit b'

aPiBinders :: ExplicitParams -> SyntaxM (Nest UReqAnnBinder VoidS VoidS)
aPiBinders :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS)
aPiBinders bs = binderList bs \b ->
UnaryNest <$> tyOptBinder Explicit b

explicitBindersOptAnn :: ExplicitParams -> SyntaxM (Nest UOptAnnBinder VoidS VoidS)
explicitBindersOptAnn :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS)
explicitBindersOptAnn bs = binderList bs \b -> withTrailingConstraints b \b' ->
binderOptTy Explicit b'

Expand All @@ -276,7 +260,7 @@ uBinder (WithSrc src b) = addSrcContext src $ case b of
_ -> throw SyntaxErr "Binder must be an identifier or `_`"

-- Type annotation with an optional binder pattern
tyOptPat :: Group -> SyntaxM (UReqAnnBinder VoidS VoidS)
tyOptPat :: Group -> SyntaxM (UAnnBinder VoidS VoidS)
tyOptPat = \case
-- Named type
Binary Colon lhs typeAnn -> UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn)
Expand Down Expand Up @@ -320,7 +304,7 @@ pat = propagateSrcB pat' where
_ -> error "unexpected postfix group (should be ruled out at grouping stage)"
pat' _ = throw SyntaxErr "Illegal pattern"

tyOptBinder :: Explicitness -> Group -> SyntaxM (UAnnBinder req VoidS VoidS)
tyOptBinder :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS)
tyOptBinder expl = \case
Binary Pipe _ _ -> throw SyntaxErr "Unexpected constraint"
Binary Colon name ty -> do
Expand All @@ -331,7 +315,7 @@ tyOptBinder expl = \case
ty <- expr g
return $ UAnnBinder expl UIgnore (UAnn ty)

binderOptTy :: Explicitness -> Group -> SyntaxM (UOptAnnBinder VoidS VoidS)
binderOptTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS)
binderOptTy expl = \case
Binary Colon name ty -> do
b <- uBinder name
Expand All @@ -341,7 +325,7 @@ binderOptTy expl = \case
b <- uBinder g
return $ UAnnBinder expl b UNoAnn

binderReqTy :: Explicitness -> Group -> SyntaxM (UReqAnnBinder VoidS VoidS)
binderReqTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS)
binderReqTy expl (Binary Colon name ty) = do
b <- uBinder name
ann <- UAnn <$> expr ty
Expand Down Expand Up @@ -562,13 +546,10 @@ charExpr c = ULit $ Word8Lit $ fromIntegral $ fromEnum c
unitExpr :: UExpr' VoidS
unitExpr = UPrim (UCon $ P.ProdCon) []

tyKind :: UExpr' VoidS
tyKind = UPrim (UPrimTC P.TypeKind) []

-- === Builders ===

-- TODO Does this generalize? Swap list for Nest?
buildFor :: SrcPos -> Direction -> [UOptAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS
buildFor :: SrcPos -> Direction -> [UAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS
buildFor pos dir binders body = case binders of
[] -> error "should have nonempty list of binder"
[b] -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b body
Expand Down
51 changes: 38 additions & 13 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,7 @@ inferStructDef (UStructDef tyConName paramBs fields _) = do

inferDotMethod
:: TyConName o
-> Abs (Nest UReqAnnBinder) (Abs UAtomBinder ULamExpr) i
-> Abs (Nest UAnnBinder) (Abs UAtomBinder ULamExpr) i
-> InfererM i o (CoreLamExpr o)
inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do
TyConDef sn roleExpls paramBs _ <- lookupTyCon tc
Expand Down Expand Up @@ -1251,7 +1251,7 @@ dataConRepTy (Abs topBs UnitE) = case topBs of
depTy = DepPairTy $ DepPairType ExplicitDepPair b tailTy

inferClassDef
:: SourceName -> [SourceName] -> Nest UReqAnnBinder i i' -> [UType i']
:: SourceName -> [SourceName] -> Nest UAnnBinder i i' -> [UType i']
-> InfererM i o (ClassDef o)
inferClassDef className methodNames paramBs methodTys = do
withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do
Expand All @@ -1270,18 +1270,43 @@ inferClassDef className methodNames paramBs methodTys = do
let (roleExpls', paramBs''') = unzipAttrs paramBs''
return $ ClassDef className methodNames paramNames roleExpls' paramBs''' superclassBs methodTys'

withUBinder :: UReqAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a
withUBinder :: UAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a
withUBinder (UAnnBinder expl b (UAnn ty)) cont = do
ty' <- checkUType ty
withFreshBinderInf (getNameHint b) expl ty' \b' ->
extendSubst (b@>binderName b') $ cont (WithAttrB expl b')

withUBinders :: Nest UReqAnnBinder i i' -> InfererCPSB2 (Nest (WithExpl CBinder)) i i' o a
withUBinders Empty cont = withDistinct $ cont Empty
withUBinders (Nest b bs) cont = withUBinder b \b' ->
withUBinders bs \bs' -> cont $ Nest b' bs'
withUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithExpl CBinder)) i i' o a
withUBinders bs cont = do
Abs bs' UnitE <- inferUBinders bs \_ -> return UnitE
let (expls, bs'') = unzipAttrs bs'
withFreshBindersInf expls (EmptyAbs bs'') \bs'' -> do
extendSubst (bs@@> (atomVarName <$> bindersVars bs'')) $
cont $ zipAttrs expls bs''

inferUBinders
:: Zonkable e => Nest UAnnBinder i i'
-> (forall o'. DExt o o' => [CAtomName o'] -> InfererM i' o' (e o'))
-> InfererM i o (Abs (Nest (WithExpl CBinder)) e o)
inferUBinders Empty cont = withDistinct $ Abs Empty <$> cont []
inferUBinders (Nest (UAnnBinder expl b ann) bs) cont = do
-- TODO: factor out the common part of each case (requires an annotated
-- `where` clause because of the rank-2 type)
case ann of
UAnn ty -> do
ty' <- checkUType ty
withFreshBinderInf (getNameHint b) expl ty' \b' -> do
extendSubst (b@>binderName b') do
Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs)
return $ Abs (Nest (WithAttrB expl b') bs') e
UNoAnn -> withFreshInferenceName (AnnotationInfVar (getSourceName b)) TyKind \v -> do
let ty = TyVar v
withFreshBinderInf (getNameHint b) expl ty \b' -> do
extendSubst (b@>binderName b') do
Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs)
return $ Abs (Nest (WithAttrB expl b') bs') e

withRoleUBinders :: Nest UReqAnnBinder i i' -> InfererCPSB2 (Nest (WithRoleExpl CBinder)) i i' o a
withRoleUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithRoleExpl CBinder)) i i' o a
withRoleUBinders bs cont = do
withUBinders bs \(ZipB expls bs') -> do
let tys = getType <$> bindersVars bs'
Expand All @@ -1305,13 +1330,13 @@ withRoleUBinders bs cont = do
False -> return DataParam
{-# INLINE inferRole #-}

requireAnn :: UOptAnnBinder l l' -> InfererM i o (UReqAnnBinder l l')
requireAnn :: UAnnBinder l l' -> InfererM i o (UAnnBinder l l')
requireAnn (UAnnBinder expl b (UAnn ann)) = return $ UAnnBinder expl b (UAnn ann)
requireAnn (UAnnBinder _ b UNoAnn) = addSrcContext (srcPos b) $
throw TypeErr "Binder requires annotation"

checkLamBinders
:: [Explicitness] -> Nest CBinder o any -> Nest UOptAnnBinder i i'
:: [Explicitness] -> Nest CBinder o any -> Nest UAnnBinder i i'
-> InfererCPSB2 (Nest CBinder) i i' o a
checkLamBinders [] Empty Empty cont = withDistinct $ cont Empty
checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do
Expand Down Expand Up @@ -1342,8 +1367,7 @@ inferUForExpr (UForExpr b body) = do

inferULam :: ULamExpr i -> InfererM i o (CoreLamExpr o)
inferULam (ULamExpr bs appExpl effs resultTy body) = do
bs' <- forEachNestItemM bs requireAnn
withUBinders bs' \(ZipB expls bs'') -> do
Abs (ZipB expls bs') (PairE effTy body') <- inferUBinders bs \_ -> do
effs' <- fromMaybe Pure <$> mapM checkUEffRow effs
resultTy' <- mapM checkUType resultTy
body' <- buildDeclsInf $ withAllowedEffects (sink effs') do
Expand All @@ -1353,7 +1377,8 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do
Just resultTy'' -> checkSigma noHint result (sink resultTy'')
resultTy'' <- blockTy body'
let effTy = EffTy effs' resultTy''
return $ CoreLamExpr (CorePiType appExpl expls bs'' effTy) (LamExpr bs'' body')
return $ PairE effTy body'
return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body')

checkUForExpr :: Emits o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o)
checkUForExpr (UForExpr bFor body) tabPi@(TabPiType _ bPi _) = do
Expand Down
4 changes: 2 additions & 2 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,10 @@ prettyBinderNest bs = nest 6 $ line' <> (sep $ map p $ fromNest bs)
instance Pretty (UDataDefTrail n) where
pretty (UDataDefTrail bs) = p $ fromNest bs

instance Pretty (UAnnBinder req n l) where
instance Pretty (UAnnBinder n l) where
pretty (UAnnBinder _ b ty) = p b <> ":" <> p ty

instance Pretty (UAnn req n) where
instance Pretty (UAnn n) where
pretty (UAnn ty) = ":" <> p ty
pretty UNoAnn = mempty

Expand Down
4 changes: 2 additions & 2 deletions src/lib/SourceRename.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ instance (SourceRenamableE e, SourceRenamableB b) => SourceRenamableE (Abs b e)
instance SourceRenamableB (UBinder (AtomNameC CoreIR)) where
sourceRenameB b cont = sourceRenameUBinder UAtomVar b cont

instance SourceRenamableE (UAnn req) where
instance SourceRenamableE UAnn where
sourceRenameE UNoAnn = return UNoAnn
sourceRenameE (UAnn ann) = UAnn <$> sourceRenameE ann

instance SourceRenamableB (UAnnBinder req) where
instance SourceRenamableB UAnnBinder where
sourceRenameB (UAnnBinder expl b ann) cont = do
ann' <- sourceRenameE ann
sourceRenameB b \b' -> cont $ UAnnBinder expl b' ann'
Expand Down
Loading

0 comments on commit f955da4

Please sign in to comment.