From e9abf394d209d852dac71f6e4bba0ceda0dfd0e2 Mon Sep 17 00:00:00 2001 From: rowanG077 Date: Wed, 9 Feb 2022 18:07:41 +0100 Subject: [PATCH] Improved Log inference Normalise `Log base a + Log base b + ... + Log base z` + `Log base (a * b * ... * z)` --- .gitignore | 2 + CHANGELOG.md | 3 ++ src/GHC/TypeLits/Extra/Solver/Operations.hs | 40 ++++++++++++-- src/GHC/TypeLits/Extra/Solver/Unify.hs | 59 ++++++++++++++++++++- tests/ErrorTests.hs | 6 +++ tests/Main.hs | 41 +++++++++++++- 6 files changed, 144 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 8ee1bf9..1ee91d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ .stack-work +dist +dist-newstyle diff --git a/CHANGELOG.md b/CHANGELOG.md index e3c128f..aa141cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package +# Unreleased +* Normalise `Log base a + Log base b + ... + Log base z` to `Log base (a * b * ... * z)` + # 0.4.3 *June 18th 2021* * Add support for GHC 9.2.0.20210422 diff --git a/src/GHC/TypeLits/Extra/Solver/Operations.hs b/src/GHC/TypeLits/Extra/Solver/Operations.hs index c63e940..28feb90 100644 --- a/src/GHC/TypeLits/Extra/Solver/Operations.hs +++ b/src/GHC/TypeLits/Extra/Solver/Operations.hs @@ -15,6 +15,8 @@ module GHC.TypeLits.Extra.Solver.Operations , NormaliseResult , mergeNormalised , reifyEOP + , mergeAdd + , mergeMul , mergeMax , mergeMin , mergeDiv @@ -31,8 +33,12 @@ where -- external import Control.Monad.Trans.Writer.Strict #if MIN_VERSION_ghc_typelits_natnormalise(0,7,0) -import Data.Set as Set +import qualified Data.Set as Set #endif +#if !MIN_VERSION_base(4,11,0) +import Data.Semigroup (Semigroup (..)) +#endif + import GHC.Base (isTrue#,(==#),(+#)) import GHC.Integer (smallInteger) @@ -41,13 +47,15 @@ import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural) -- GHC API #if MIN_VERSION_ghc(9,0,0) -import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon) +import GHC.Builtin.Types.Literals ( typeNatExpTyCon, typeNatSubTyCon + , typeNatAddTyCon, typeNatMulTyCon) import GHC.Core.TyCon (TyCon) import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text) #else import Outputable (Outputable (..), (<+>), integer, text) -import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon) +import TcTypeNats ( typeNatExpTyCon, typeNatSubTyCon, typeNatAddTyCon + , typeNatMulTyCon) import TyCon (TyCon) import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) #endif @@ -65,6 +73,15 @@ mergeNormalised Normalised _ = Normalised mergeNormalised _ Normalised = Normalised mergeNormalised _ _ = Untouched +instance Semigroup Normalised where + (<>) = mergeNormalised + +instance Monoid Normalised where + mempty = Untouched +#if !MIN_VERSION_base(4,11,0) + mappend = mergeNormalised +#endif + -- | A normalise result contains the ExtraOp and a flag that indicates whether any expression -- | was normalised within the ExtraOp. type NormaliseResult = (ExtraOp, Normalised) @@ -73,6 +90,8 @@ data ExtraOp = I Integer | V TyVar | C CType + | Add ExtraOp ExtraOp + | Mul ExtraOp ExtraOp | Max ExtraOp ExtraOp | Min ExtraOp ExtraOp | Div ExtraOp ExtraOp @@ -83,12 +102,14 @@ data ExtraOp | GCD ExtraOp ExtraOp | LCM ExtraOp ExtraOp | Exp ExtraOp ExtraOp - deriving Eq + deriving (Eq, Ord) instance Outputable ExtraOp where ppr (I i) = integer i ppr (V v) = ppr v ppr (C c) = ppr c + ppr (Add x y) = text "Add (" <+> ppr x <+> text "," <+> ppr y <+> text ")" + ppr (Mul x y) = text "Mul (" <+> ppr x <+> text "," <+> ppr y <+> text ")" ppr (Max x y) = text "Max (" <+> ppr x <+> text "," <+> ppr y <+> text ")" ppr (Min x y) = text "Min (" <+> ppr x <+> text "," <+> ppr y <+> text ")" ppr (Div x y) = text "Div (" <+> ppr x <+> text "," <+> ppr y <+> text ")" @@ -117,6 +138,10 @@ reifyEOP :: ExtraDefs -> ExtraOp -> Type reifyEOP _ (I i) = mkNumLitTy i reifyEOP _ (V v) = mkTyVarTy v reifyEOP _ (C (CType c)) = c +reifyEOP defs (Add x y) = mkTyConApp typeNatAddTyCon [reifyEOP defs x + ,reifyEOP defs y] +reifyEOP defs (Mul x y) = mkTyConApp typeNatMulTyCon [reifyEOP defs x + ,reifyEOP defs y] reifyEOP defs (Max x y) = mkTyConApp (maxTyCon defs) [reifyEOP defs x ,reifyEOP defs y] reifyEOP defs (Min x y) = mkTyConApp (minTyCon defs) [reifyEOP defs x @@ -138,6 +163,13 @@ reifyEOP defs (LCM x y) = mkTyConApp (lcmTyCon defs) [reifyEOP defs x reifyEOP defs (Exp x y) = mkTyConApp typeNatExpTyCon [reifyEOP defs x ,reifyEOP defs y] +mergeAdd :: ExtraOp -> ExtraOp -> NormaliseResult +mergeAdd (Log b1 x1) (Log b2 x2) | b1 == b2 = (Log b1 (Mul x1 x2), Normalised) +mergeAdd x y = (Add x y, Untouched) + +mergeMul :: ExtraOp -> ExtraOp -> NormaliseResult +mergeMul x y = (Mul x y, Untouched) + mergeMax :: ExtraDefs -> ExtraOp -> ExtraOp -> NormaliseResult mergeMax _ (I 0) y = (y, Normalised) mergeMax _ x (I 0) = (x, Normalised) diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..69110c0 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -19,14 +19,18 @@ where -- external import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Maybe (MaybeT (..)) +import Control.Monad.Trans.Writer.Strict (runWriter) import Data.Maybe (catMaybes) import Data.Function (on) import GHC.TypeLits.Normalise.Unify (CType (..)) +import qualified GHC.TypeLits.Normalise.Unify as Normalise -- GHC API #if MIN_VERSION_ghc(9,0,0) -import GHC.Builtin.Types.Literals (typeNatExpTyCon) +import GHC.Builtin.Types.Literals + (typeNatExpTyCon, typeNatMulTyCon, typeNatAddTyCon, typeNatSubTyCon) import GHC.Core.TyCo.Rep (Type (..), TyLit (..)) +import GHC.Core.TyCon (TyCon) import GHC.Core.Type (TyVar, coreView) import GHC.Tc.Plugin (TcPluginM, tcPluginTrace) import GHC.Tc.Types.Constraint (Ct) @@ -35,8 +39,10 @@ import GHC.Utils.Outputable (Outputable (..), ($$), text) #else import Outputable (Outputable (..), ($$), text) import TcPluginM (TcPluginM, tcPluginTrace) -import TcTypeNats (typeNatExpTyCon) +import TcTypeNats + (typeNatExpTyCon, typeNatAddTyCon, typeNatMulTyCon, typeNatSubTyCon) import Type (TyVar, coreView) +import TyCon (TyCon) import TyCoRep (Type (..), TyLit (..)) import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet) #if MIN_VERSION_ghc(8,10,0) @@ -60,12 +66,57 @@ mergeNormResWith f x y = do (res, n3) <- f x' y' pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3) +-- | First normalise the expression using the solver from ghc-typelits-natnormalise +-- before calling the mergeWith operation from this ghc-typelits-extra solver. +withSOPNormalize :: + ExtraDefs -> + -- The mergeWith operation + (ExtraOp -> ExtraOp -> (ExtraOp, Normalised)) -> + -- The (+), (-), or (*) type constructor + TyCon -> + Type -> + Type -> + MaybeT TcPluginM NormaliseResult +withSOPNormalize defs mergeWith tc x y = do + (x1, n1) <- normaliseNat defs x + (y1, n2) <- normaliseNat defs y + let x2 = reifyEOP defs x1 + y2 = reifyEOP defs y1 + (q,subtractions) = runWriter (Normalise.normaliseNat (TyConApp tc [x2, y2])) + -- Currently we don't use the result of ghc-typelits-natnormalise when there + -- are "inner" subtractions, as we would have to emit inequality constraints + -- for the arguments of those subtractions to guarantee that the result of + -- those subtractions would result in a natural number. + -- + -- TODO: emit inequality constraints for the inner subtractions + noSubtractions + | tc == typeNatSubTyCon + , [(s1, s2)] <- subtractions + = CType s1 == CType x2 && CType s2 == CType y2 + | otherwise + = null subtractions + if noSubtractions then case Normalise.reifySOP q of + TyConApp tcQ [xQ,yQ] + | tcQ == tc -> do + (x3, n3) <- normaliseNat defs xQ + (y3, n4) <- normaliseNat defs yQ + let (res,n5) = mergeWith x3 y3 + lift (tcPluginTrace "called mergeWith on: " (ppr (x3, y3, res))) + return (res, mconcat [n1,n2,n3,n4,n5]) + resSOP -> do + (res, n3) <- normaliseNat defs resSOP + return (res, mconcat [n1,n2,n3]) + else do + let (res, n3) = mergeWith x1 y1 + return (res, mconcat [n1,n2,n3]) normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1 normaliseNat _ (TyVarTy v) = pure (V v, Untouched) normaliseNat _ (LitTy (NumTyLit i)) = pure (I i, Untouched) normaliseNat defs (TyConApp tc [x,y]) + | tc == typeNatAddTyCon = withSOPNormalize defs mergeAdd tc x y + | tc == typeNatMulTyCon = withSOPNormalize defs mergeMul tc x y | tc == maxTyCon defs = mergeNormResWith (\x' y' -> return (mergeMax defs x' y')) (normaliseNat defs x) (normaliseNat defs y) @@ -152,6 +203,8 @@ fvOP :: ExtraOp -> UniqSet TyVar fvOP (I _) = emptyUniqSet fvOP (V v) = unitUniqSet v fvOP (C _) = emptyUniqSet +fvOP (Add x y) = fvOP x `unionUniqSets` fvOP y +fvOP (Mul x y) = fvOP x `unionUniqSets` fvOP y fvOP (Max x y) = fvOP x `unionUniqSets` fvOP y fvOP (Min x y) = fvOP x `unionUniqSets` fvOP y fvOP (Div x y) = fvOP x `unionUniqSets` fvOP y @@ -170,6 +223,8 @@ containsConstants :: ExtraOp -> Bool containsConstants (I _) = False containsConstants (V _) = False containsConstants (C _) = True +containsConstants (Add _ _) = True +containsConstants (Mul _ _) = True containsConstants (Max x y) = containsConstants x || containsConstants y containsConstants (Min x y) = containsConstants x || containsConstants y containsConstants (Div x y) = containsConstants x || containsConstants y diff --git a/tests/ErrorTests.hs b/tests/ErrorTests.hs index 26e1cd5..c38501c 100644 --- a/tests/ErrorTests.hs +++ b/tests/ErrorTests.hs @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6) testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True testFail27 _ = id +testFail28 :: Proxy (Log 10 50 + Log 10 2) -> Proxy 2 +testFail28 = id + #if __GLASGOW_HASKELL__ >= 900 testFail1Errors = ["Expected: Proxy (GCD 6 8) -> Proxy 4" @@ -345,3 +348,6 @@ testFail26Errors = ["Could not deduce: Max x y ~ n" ,"from the context: (x <=? n) ~ 'True" ] + +testFail28Errors = + ["Couldn't match type ‘Log 10 50 + Log 10 2’ with ‘2’"] \ No newline at end of file diff --git a/tests/Main.hs b/tests/Main.hs index 4fa8382..c05f54c 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -223,13 +223,42 @@ test57 -> Proxy True test57 _ _ = id +test58 + :: Proxy b + -> Proxy n + -> Proxy p + -> Proxy q + -> Proxy (Log b (n * p * q)) + -> Proxy (Log b n + Log b p + Log b q) +test58 _ _ _ _ = id + +test59 + :: Proxy b + -> Proxy n + -> Proxy p + -> Proxy q + -> Proxy r + -> Proxy (Log b (n * p * q) + r) + -> Proxy (Log b n + r + Log b p + Log b q) +test59 _ _ _ _ _ = id + +test60 + :: Proxy b + -> Proxy n + -> Proxy p + -> Proxy q + -> Proxy r + -> Proxy (Log b (Max p q * Min r q * q)) + -> Proxy (Log b (Max p q) + Log b (Min r q) + Log b q) +test60 _ _ _ _ _ = id + main :: IO () main = defaultMain tests tests :: TestTree tests = testGroup "ghc-typelits-natnormalise" [ testGroup "Basic functionality" - [ testCase "GCD 6 8 ~ 2" $ + [testCase "GCD 6 8 ~ 2" $ show (test1 Proxy) @?= "Proxy" , testCase "forall x . GCD 6 8 + x ~ x + GCD 10 8" $ @@ -400,6 +429,15 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "forall b n p q . Log b (n * p * q) ~ Log b n + Log b p + Log b q" $ + show (test58 Proxy Proxy Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall b n p q r . r + Log b (n * p * q) ~ Log b n + r + Log b p + Log b q" $ + show (test59 Proxy Proxy Proxy Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall b n p q r . r + Log b (Max p q * Min r q * q) ~ Log b (Max p q) + Log b (Min r q) + Log b q" $ + show (test60 Proxy Proxy Proxy Proxy Proxy Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors @@ -429,6 +467,7 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "(x+1 <=? Max x y) /~ True" $ testFail25 `throws` testFail25Errors , testCase "(x <= n) /=> (Max x y) ~ n" $ testFail26 `throws` testFail26Errors , testCase "n + 2 <=? Max (n + 1) 1 /~ True" $ testFail27 `throws` testFail27Errors + , testCase "Log 10 50 + Log 10 2 /= 2" $ testFail28 `throws` testFail28Errors ] ]