Skip to content

Commit

Permalink
Improved Log inference
Browse files Browse the repository at this point in the history
Normalise `Log base a + Log base b + ... + Log base z` + `Log base (a * b * ... * z)`
  • Loading branch information
rowanG077 committed Feb 10, 2022
1 parent fe80881 commit e9abf39
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
.stack-work
dist
dist-newstyle
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package

# Unreleased
* Normalise `Log base a + Log base b + ... + Log base z` to `Log base (a * b * ... * z)`

# 0.4.3 *June 18th 2021*
* Add support for GHC 9.2.0.20210422

Expand Down
40 changes: 36 additions & 4 deletions src/GHC/TypeLits/Extra/Solver/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ module GHC.TypeLits.Extra.Solver.Operations
, NormaliseResult
, mergeNormalised
, reifyEOP
, mergeAdd
, mergeMul
, mergeMax
, mergeMin
, mergeDiv
Expand All @@ -31,8 +33,12 @@ where
-- external
import Control.Monad.Trans.Writer.Strict
#if MIN_VERSION_ghc_typelits_natnormalise(0,7,0)
import Data.Set as Set
import qualified Data.Set as Set
#endif
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup (Semigroup (..))
#endif


import GHC.Base (isTrue#,(==#),(+#))
import GHC.Integer (smallInteger)
Expand All @@ -41,13 +47,15 @@ import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon)
import GHC.Builtin.Types.Literals ( typeNatExpTyCon, typeNatSubTyCon
, typeNatAddTyCon, typeNatMulTyCon)
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text)
#else
import Outputable (Outputable (..), (<+>), integer, text)
import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon)
import TcTypeNats ( typeNatExpTyCon, typeNatSubTyCon, typeNatAddTyCon
, typeNatMulTyCon)
import TyCon (TyCon)
import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
#endif
Expand All @@ -65,6 +73,15 @@ mergeNormalised Normalised _ = Normalised
mergeNormalised _ Normalised = Normalised
mergeNormalised _ _ = Untouched

instance Semigroup Normalised where
(<>) = mergeNormalised

instance Monoid Normalised where
mempty = Untouched
#if !MIN_VERSION_base(4,11,0)
mappend = mergeNormalised
#endif

-- | A normalise result contains the ExtraOp and a flag that indicates whether any expression
-- | was normalised within the ExtraOp.
type NormaliseResult = (ExtraOp, Normalised)
Expand All @@ -73,6 +90,8 @@ data ExtraOp
= I Integer
| V TyVar
| C CType
| Add ExtraOp ExtraOp
| Mul ExtraOp ExtraOp
| Max ExtraOp ExtraOp
| Min ExtraOp ExtraOp
| Div ExtraOp ExtraOp
Expand All @@ -83,12 +102,14 @@ data ExtraOp
| GCD ExtraOp ExtraOp
| LCM ExtraOp ExtraOp
| Exp ExtraOp ExtraOp
deriving Eq
deriving (Eq, Ord)

instance Outputable ExtraOp where
ppr (I i) = integer i
ppr (V v) = ppr v
ppr (C c) = ppr c
ppr (Add x y) = text "Add (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Mul x y) = text "Mul (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Max x y) = text "Max (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Min x y) = text "Min (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Div x y) = text "Div (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
Expand Down Expand Up @@ -117,6 +138,10 @@ reifyEOP :: ExtraDefs -> ExtraOp -> Type
reifyEOP _ (I i) = mkNumLitTy i
reifyEOP _ (V v) = mkTyVarTy v
reifyEOP _ (C (CType c)) = c
reifyEOP defs (Add x y) = mkTyConApp typeNatAddTyCon [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Mul x y) = mkTyConApp typeNatMulTyCon [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Max x y) = mkTyConApp (maxTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Min x y) = mkTyConApp (minTyCon defs) [reifyEOP defs x
Expand All @@ -138,6 +163,13 @@ reifyEOP defs (LCM x y) = mkTyConApp (lcmTyCon defs) [reifyEOP defs x
reifyEOP defs (Exp x y) = mkTyConApp typeNatExpTyCon [reifyEOP defs x
,reifyEOP defs y]

mergeAdd :: ExtraOp -> ExtraOp -> NormaliseResult
mergeAdd (Log b1 x1) (Log b2 x2) | b1 == b2 = (Log b1 (Mul x1 x2), Normalised)
mergeAdd x y = (Add x y, Untouched)

mergeMul :: ExtraOp -> ExtraOp -> NormaliseResult
mergeMul x y = (Mul x y, Untouched)

mergeMax :: ExtraDefs -> ExtraOp -> ExtraOp -> NormaliseResult
mergeMax _ (I 0) y = (y, Normalised)
mergeMax _ x (I 0) = (x, Normalised)
Expand Down
59 changes: 57 additions & 2 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ where
-- external
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Control.Monad.Trans.Writer.Strict (runWriter)
import Data.Maybe (catMaybes)
import Data.Function (on)
import GHC.TypeLits.Normalise.Unify (CType (..))
import qualified GHC.TypeLits.Normalise.Unify as Normalise

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon)
import GHC.Builtin.Types.Literals
(typeNatExpTyCon, typeNatMulTyCon, typeNatAddTyCon, typeNatSubTyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (TyVar, coreView)
import GHC.Tc.Plugin (TcPluginM, tcPluginTrace)
import GHC.Tc.Types.Constraint (Ct)
Expand All @@ -35,8 +39,10 @@ import GHC.Utils.Outputable (Outputable (..), ($$), text)
#else
import Outputable (Outputable (..), ($$), text)
import TcPluginM (TcPluginM, tcPluginTrace)
import TcTypeNats (typeNatExpTyCon)
import TcTypeNats
(typeNatExpTyCon, typeNatAddTyCon, typeNatMulTyCon, typeNatSubTyCon)
import Type (TyVar, coreView)
import TyCon (TyCon)
import TyCoRep (Type (..), TyLit (..))
import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
#if MIN_VERSION_ghc(8,10,0)
Expand All @@ -60,12 +66,57 @@ mergeNormResWith f x y = do
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)

-- | First normalise the expression using the solver from ghc-typelits-natnormalise
-- before calling the mergeWith operation from this ghc-typelits-extra solver.
withSOPNormalize ::
ExtraDefs ->
-- The mergeWith operation
(ExtraOp -> ExtraOp -> (ExtraOp, Normalised)) ->
-- The (+), (-), or (*) type constructor
TyCon ->
Type ->
Type ->
MaybeT TcPluginM NormaliseResult
withSOPNormalize defs mergeWith tc x y = do
(x1, n1) <- normaliseNat defs x
(y1, n2) <- normaliseNat defs y
let x2 = reifyEOP defs x1
y2 = reifyEOP defs y1
(q,subtractions) = runWriter (Normalise.normaliseNat (TyConApp tc [x2, y2]))
-- Currently we don't use the result of ghc-typelits-natnormalise when there
-- are "inner" subtractions, as we would have to emit inequality constraints
-- for the arguments of those subtractions to guarantee that the result of
-- those subtractions would result in a natural number.
--
-- TODO: emit inequality constraints for the inner subtractions
noSubtractions
| tc == typeNatSubTyCon
, [(s1, s2)] <- subtractions
= CType s1 == CType x2 && CType s2 == CType y2
| otherwise
= null subtractions
if noSubtractions then case Normalise.reifySOP q of
TyConApp tcQ [xQ,yQ]
| tcQ == tc -> do
(x3, n3) <- normaliseNat defs xQ
(y3, n4) <- normaliseNat defs yQ
let (res,n5) = mergeWith x3 y3
lift (tcPluginTrace "called mergeWith on: " (ppr (x3, y3, res)))
return (res, mconcat [n1,n2,n3,n4,n5])
resSOP -> do
(res, n3) <- normaliseNat defs resSOP
return (res, mconcat [n1,n2,n3])
else do
let (res, n3) = mergeWith x1 y1
return (res, mconcat [n1,n2,n3])

normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
normaliseNat _ (TyVarTy v) = pure (V v, Untouched)
normaliseNat _ (LitTy (NumTyLit i)) = pure (I i, Untouched)
normaliseNat defs (TyConApp tc [x,y])
| tc == typeNatAddTyCon = withSOPNormalize defs mergeAdd tc x y
| tc == typeNatMulTyCon = withSOPNormalize defs mergeMul tc x y
| tc == maxTyCon defs = mergeNormResWith (\x' y' -> return (mergeMax defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
Expand Down Expand Up @@ -152,6 +203,8 @@ fvOP :: ExtraOp -> UniqSet TyVar
fvOP (I _) = emptyUniqSet
fvOP (V v) = unitUniqSet v
fvOP (C _) = emptyUniqSet
fvOP (Add x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Mul x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Max x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Min x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Div x y) = fvOP x `unionUniqSets` fvOP y
Expand All @@ -170,6 +223,8 @@ containsConstants :: ExtraOp -> Bool
containsConstants (I _) = False
containsConstants (V _) = False
containsConstants (C _) = True
containsConstants (Add _ _) = True
containsConstants (Mul _ _) = True
containsConstants (Max x y) = containsConstants x || containsConstants y
containsConstants (Min x y) = containsConstants x || containsConstants y
containsConstants (Div x y) = containsConstants x || containsConstants y
Expand Down
6 changes: 6 additions & 0 deletions tests/ErrorTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6)
testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True
testFail27 _ = id

testFail28 :: Proxy (Log 10 50 + Log 10 2) -> Proxy 2
testFail28 = id

#if __GLASGOW_HASKELL__ >= 900
testFail1Errors =
["Expected: Proxy (GCD 6 8) -> Proxy 4"
Expand Down Expand Up @@ -345,3 +348,6 @@ testFail26Errors =
["Could not deduce: Max x y ~ n"
,"from the context: (x <=? n) ~ 'True"
]

testFail28Errors =
["Couldn't match type ‘Log 10 50 + Log 10 2’ with ‘2’"]
41 changes: 40 additions & 1 deletion tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,42 @@ test57
-> Proxy True
test57 _ _ = id

test58
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy (Log b (n * p * q))
-> Proxy (Log b n + Log b p + Log b q)
test58 _ _ _ _ = id

test59
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy r
-> Proxy (Log b (n * p * q) + r)
-> Proxy (Log b n + r + Log b p + Log b q)
test59 _ _ _ _ _ = id

test60
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy r
-> Proxy (Log b (Max p q * Min r q * q))
-> Proxy (Log b (Max p q) + Log b (Min r q) + Log b q)
test60 _ _ _ _ _ = id

main :: IO ()
main = defaultMain tests

tests :: TestTree
tests = testGroup "ghc-typelits-natnormalise"
[ testGroup "Basic functionality"
[ testCase "GCD 6 8 ~ 2" $
[testCase "GCD 6 8 ~ 2" $
show (test1 Proxy) @?=
"Proxy"
, testCase "forall x . GCD 6 8 + x ~ x + GCD 10 8" $
Expand Down Expand Up @@ -400,6 +429,15 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q . Log b (n * p * q) ~ Log b n + Log b p + Log b q" $
show (test58 Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q r . r + Log b (n * p * q) ~ Log b n + r + Log b p + Log b q" $
show (test59 Proxy Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q r . r + Log b (Max p q * Min r q * q) ~ Log b (Max p q) + Log b (Min r q) + Log b q" $
show (test60 Proxy Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down Expand Up @@ -429,6 +467,7 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "(x+1 <=? Max x y) /~ True" $ testFail25 `throws` testFail25Errors
, testCase "(x <= n) /=> (Max x y) ~ n" $ testFail26 `throws` testFail26Errors
, testCase "n + 2 <=? Max (n + 1) 1 /~ True" $ testFail27 `throws` testFail27Errors
, testCase "Log 10 50 + Log 10 2 /= 2" $ testFail28 `throws` testFail28Errors
]
]

Expand Down

0 comments on commit e9abf39

Please sign in to comment.