Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved Log inference #40

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
.stack-work
dist
dist-newstyle
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package

# Unreleased
* Normalise `Log base a + Log base b + ... + Log base z` to `Log base (a * b * ... * z)`

# 0.4.3 *June 18th 2021*
* Add support for GHC 9.2.0.20210422

Expand Down
38 changes: 35 additions & 3 deletions src/GHC/TypeLits/Extra/Solver/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ module GHC.TypeLits.Extra.Solver.Operations
, NormaliseResult
, mergeNormalised
, reifyEOP
, mergeAdd
, mergeMul
, mergeMax
, mergeMin
, mergeDiv
Expand All @@ -31,8 +33,12 @@ where
-- external
import Control.Monad.Trans.Writer.Strict
#if MIN_VERSION_ghc_typelits_natnormalise(0,7,0)
import Data.Set as Set
import qualified Data.Set as Set
#endif
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup (Semigroup (..))
#endif


import GHC.Base (isTrue#,(==#),(+#))
import GHC.Integer (smallInteger)
Expand All @@ -41,13 +47,15 @@ import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon)
import GHC.Builtin.Types.Literals ( typeNatExpTyCon, typeNatSubTyCon
, typeNatAddTyCon, typeNatMulTyCon)
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text)
#else
import Outputable (Outputable (..), (<+>), integer, text)
import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon)
import TcTypeNats ( typeNatExpTyCon, typeNatSubTyCon, typeNatAddTyCon
, typeNatMulTyCon)
import TyCon (TyCon)
import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
#endif
Expand All @@ -65,6 +73,15 @@ mergeNormalised Normalised _ = Normalised
mergeNormalised _ Normalised = Normalised
mergeNormalised _ _ = Untouched

instance Semigroup Normalised where
(<>) = mergeNormalised

instance Monoid Normalised where
mempty = Untouched
#if !MIN_VERSION_base(4,11,0)
mappend = mergeNormalised
#endif

-- | A normalise result contains the ExtraOp and a flag that indicates whether any expression
-- | was normalised within the ExtraOp.
type NormaliseResult = (ExtraOp, Normalised)
Expand All @@ -73,6 +90,8 @@ data ExtraOp
= I Integer
| V TyVar
| C CType
| Add ExtraOp ExtraOp
| Mul ExtraOp ExtraOp
| Max ExtraOp ExtraOp
| Min ExtraOp ExtraOp
| Div ExtraOp ExtraOp
Expand All @@ -89,6 +108,8 @@ instance Outputable ExtraOp where
ppr (I i) = integer i
ppr (V v) = ppr v
ppr (C c) = ppr c
ppr (Add x y) = text "Add (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Mul x y) = text "Mul (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Max x y) = text "Max (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Min x y) = text "Min (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Div x y) = text "Div (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
Expand Down Expand Up @@ -117,6 +138,10 @@ reifyEOP :: ExtraDefs -> ExtraOp -> Type
reifyEOP _ (I i) = mkNumLitTy i
reifyEOP _ (V v) = mkTyVarTy v
reifyEOP _ (C (CType c)) = c
reifyEOP defs (Add x y) = mkTyConApp typeNatAddTyCon [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Mul x y) = mkTyConApp typeNatMulTyCon [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Max x y) = mkTyConApp (maxTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Min x y) = mkTyConApp (minTyCon defs) [reifyEOP defs x
Expand All @@ -138,6 +163,13 @@ reifyEOP defs (LCM x y) = mkTyConApp (lcmTyCon defs) [reifyEOP defs x
reifyEOP defs (Exp x y) = mkTyConApp typeNatExpTyCon [reifyEOP defs x
,reifyEOP defs y]

mergeAdd :: ExtraOp -> ExtraOp -> NormaliseResult
mergeAdd (Log b1 x1) (Log b2 x2) | b1 == b2 = (Log b1 (Mul x1 x2), Normalised)
mergeAdd x y = (Add x y, Untouched)

mergeMul :: ExtraOp -> ExtraOp -> NormaliseResult
mergeMul x y = (Mul x y, Untouched)

mergeMax :: ExtraDefs -> ExtraOp -> ExtraOp -> NormaliseResult
mergeMax _ (I 0) y = (y, Normalised)
mergeMax _ x (I 0) = (x, Normalised)
Expand Down
58 changes: 56 additions & 2 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ where
-- external
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict (runWriter)
import Data.Maybe (catMaybes)
import Data.Function (on)
import GHC.TypeLits.Normalise.Unify (CType (..))
import qualified GHC.TypeLits.Normalise.Unify as Normalise

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon)
import GHC.Builtin.Types.Literals
(typeNatExpTyCon, typeNatMulTyCon, typeNatAddTyCon, typeNatSubTyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (TyVar, coreView)
import GHC.Tc.Plugin (TcPluginM, tcPluginTrace)
import GHC.Tc.Types.Constraint (Ct)
Expand All @@ -35,8 +39,10 @@ import GHC.Utils.Outputable (Outputable (..), ($$), text)
#else
import Outputable (Outputable (..), ($$), text)
import TcPluginM (TcPluginM, tcPluginTrace)
import TcTypeNats (typeNatExpTyCon)
import TcTypeNats
(typeNatExpTyCon, typeNatAddTyCon, typeNatMulTyCon, typeNatSubTyCon)
import Type (TyVar, coreView)
import TyCon (TyCon)
import TyCoRep (Type (..), TyLit (..))
import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
#if MIN_VERSION_ghc(8,10,0)
Expand All @@ -60,12 +66,56 @@ mergeNormResWith f x y = do
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)

-- | First normalise the expression using the solver from ghc-typelits-natnormalise
-- before calling the mergeWith operation from this ghc-typelits-extra solver.
withSOPNormalize ::
ExtraDefs ->
-- The mergeWith operation
(ExtraOp -> ExtraOp -> (ExtraOp, Normalised)) ->
-- The (+), (-), or (*) type constructor
TyCon ->
Type ->
Type ->
MaybeT TcPluginM NormaliseResult
withSOPNormalize defs mergeWith tc x y = do
(x1, n1) <- normaliseNat defs x
(y1, n2) <- normaliseNat defs y
let x2 = reifyEOP defs x1
y2 = reifyEOP defs y1
(q,subtractions) = runWriter (Normalise.normaliseNat (TyConApp tc [x2, y2]))
-- Currently we don't use the result of ghc-typelits-natnormalise when there
-- are "inner" subtractions, as we would have to emit inequality constraints
-- for the arguments of those subtractions to guarantee that the result of
-- those subtractions would result in a natural number.
--
-- TODO: emit inequality constraints for the inner subtractions
noSubtractions
| tc == typeNatSubTyCon
, [(s1, s2)] <- subtractions
= CType s1 == CType x2 && CType s2 == CType y2
| otherwise
= null subtractions
if noSubtractions then case Normalise.reifySOP q of
TyConApp tcQ [xQ,yQ]
| tcQ == tc -> do
(x3, n3) <- normaliseNat defs xQ
(y3, n4) <- normaliseNat defs yQ
let (res,n5) = mergeWith x3 y3
return (res, mconcat [n1,n2,n3,n4,n5])
resSOP -> do
(res, n3) <- normaliseNat defs resSOP
return (res, mconcat [n1,n2,n3])
else do
let (res, n3) = mergeWith x1 y1
return (res, mconcat [n1,n2,n3])

normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
normaliseNat _ (TyVarTy v) = pure (V v, Untouched)
normaliseNat _ (LitTy (NumTyLit i)) = pure (I i, Untouched)
normaliseNat defs (TyConApp tc [x,y])
| tc == typeNatAddTyCon = withSOPNormalize defs mergeAdd tc x y
| tc == typeNatMulTyCon = withSOPNormalize defs mergeMul tc x y
| tc == maxTyCon defs = mergeNormResWith (\x' y' -> return (mergeMax defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
Expand Down Expand Up @@ -152,6 +202,8 @@ fvOP :: ExtraOp -> UniqSet TyVar
fvOP (I _) = emptyUniqSet
fvOP (V v) = unitUniqSet v
fvOP (C _) = emptyUniqSet
fvOP (Add x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Mul x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Max x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Min x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Div x y) = fvOP x `unionUniqSets` fvOP y
Expand All @@ -170,6 +222,8 @@ containsConstants :: ExtraOp -> Bool
containsConstants (I _) = False
containsConstants (V _) = False
containsConstants (C _) = True
containsConstants (Add _ _) = True
containsConstants (Mul _ _) = True
containsConstants (Max x y) = containsConstants x || containsConstants y
containsConstants (Min x y) = containsConstants x || containsConstants y
containsConstants (Div x y) = containsConstants x || containsConstants y
Expand Down
6 changes: 6 additions & 0 deletions tests/ErrorTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6)
testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True
testFail27 _ = id

testFail28 :: Proxy (Log 10 50 + Log 10 2) -> Proxy 2
testFail28 = id

#if __GLASGOW_HASKELL__ >= 900
testFail1Errors =
["Expected: Proxy (GCD 6 8) -> Proxy 4"
Expand Down Expand Up @@ -345,3 +348,6 @@ testFail26Errors =
["Could not deduce: Max x y ~ n"
,"from the context: (x <=? n) ~ 'True"
]

testFail28Errors =
["Couldn't match type ‘Log 10 50 + Log 10 2’ with ‘2’"]
39 changes: 39 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,35 @@ test57
-> Proxy True
test57 _ _ = id

test58
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy (Log b (n * p * q))
-> Proxy (Log b n + Log b p + Log b q)
test58 _ _ _ _ = id

test59
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy r
-> Proxy (Log b (n * p * q) + r)
-> Proxy (Log b n + r + Log b p + Log b q)
test59 _ _ _ _ _ = id

test60
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy r
-> Proxy (Log b (Max p q * Min r q * q))
-> Proxy (Log b (Max p q) + Log b (Min r q) + Log b q)
test60 _ _ _ _ _ = id

main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -400,6 +429,15 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q . Log b (n * p * q) ~ Log b n + Log b p + Log b q" $
show (test58 Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q r . r + Log b (n * p * q) ~ Log b n + r + Log b p + Log b q" $
show (test59 Proxy Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q r . r + Log b (Max p q * Min r q * q) ~ Log b (Max p q) + Log b (Min r q) + Log b q" $
show (test60 Proxy Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down Expand Up @@ -429,6 +467,7 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "(x+1 <=? Max x y) /~ True" $ testFail25 `throws` testFail25Errors
, testCase "(x <= n) /=> (Max x y) ~ n" $ testFail26 `throws` testFail26Errors
, testCase "n + 2 <=? Max (n + 1) 1 /~ True" $ testFail27 `throws` testFail27Errors
, testCase "Log 10 50 + Log 10 2 /= 2" $ testFail28 `throws` testFail28Errors
]
]

Expand Down