From e6640a13a56d60fb5e58f3b8aae6aae46f2e75e9 Mon Sep 17 00:00:00 2001 From: rowanG077 Date: Wed, 9 Feb 2022 18:07:41 +0100 Subject: [PATCH] Improved Log inference Normalise `Log base (a * b * ... * z)` to `Log base a + Log base b + ... + Log base z` Normalise `Log b (x^y)` to `y * Log b` --- .gitignore | 2 ++ CHANGELOG.md | 4 +++ src/GHC/TypeLits/Extra/Solver/Operations.hs | 40 ++++++++++++++++----- src/GHC/TypeLits/Extra/Solver/Unify.hs | 2 +- tests/Main.hs | 24 +++++++++++++ 5 files changed, 62 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 8ee1bf9..1ee91d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ .stack-work +dist +dist-newstyle diff --git a/CHANGELOG.md b/CHANGELOG.md index e3c128f..ec1cf1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package +# Unreleased +* Normalise `Log base (a * b * ... * z)` to `Log base a + Log base b + ... + Log base z` +* Normalise `Log b (x^y)` to `y * Log b x` + # 0.4.3 *June 18th 2021* * Add support for GHC 9.2.0.20210422 diff --git a/src/GHC/TypeLits/Extra/Solver/Operations.hs b/src/GHC/TypeLits/Extra/Solver/Operations.hs index c63e940..57c083e 100644 --- a/src/GHC/TypeLits/Extra/Solver/Operations.hs +++ b/src/GHC/TypeLits/Extra/Solver/Operations.hs @@ -31,23 +31,28 @@ where -- external import Control.Monad.Trans.Writer.Strict #if MIN_VERSION_ghc_typelits_natnormalise(0,7,0) -import Data.Set as Set +import qualified Data.Set as Set #endif + import GHC.Base (isTrue#,(==#),(+#)) import GHC.Integer (smallInteger) import GHC.Integer.Logarithms (integerLogBase#) -import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural) +import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..)) +import GHC.TypeLits.Normalise.Unify ( CType (..), normaliseNat, isNatural + , reifySOP) -- GHC API #if MIN_VERSION_ghc(9,0,0) -import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon) +import GHC.Builtin.Types.Literals ( typeNatExpTyCon, typeNatMulTyCon + , typeNstAddTyCon, typeNatSubTyCon) import GHC.Core.TyCon (TyCon) import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text) #else import Outputable (Outputable (..), (<+>), integer, text) -import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon) +import TcTypeNats ( typeNatExpTyCon, typeNatMulTyCon, typeNatAddTyCon + , typeNatSubTyCon) import TyCon (TyCon) import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) #endif @@ -194,11 +199,28 @@ mergeCLog i (Exp j k) | i == j = Just (k, Normalised) mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j) mergeCLog x y = Just (CLog x y, Untouched) -mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult -mergeLog (I i) _ | i < 2 = Nothing -mergeLog b (Exp b' y) | b == b' = Just (y, Normalised) -mergeLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j) -mergeLog x y = Just (Log x y, Untouched) +mergeLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult +mergeLog _ (I i) _ | i < 2 = Nothing +mergeLog _ b (Exp b' y) | b == b' = Just (y, Normalised) +mergeLog defs b (Exp x y) = + let t = mkTyConApp + typeNatMulTyCon + [ reifyEOP defs y + , reifyEOP defs (Log b x) + ] + in Just (C (CType t), Normalised) +mergeLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j) +mergeLog defs b x = case (fst (runWriter (normaliseNat (reifyEOP defs x)))) of + S [P [_]] -> Just (Log b x, Untouched) + S [P p] -> + let si a = [a] + toLog body = reifyEOP defs (Log b (C (CType body))) + logs = map (toLog . reifySOP . S . si . P . si) p + merged = foldr (\t1 t2 -> mkTyConApp typeNatAddTyCon [t1, t2]) + (head logs) + (tail logs) + in Just (C (CType merged), Normalised) + _ -> Just (Log b x, Untouched) mergeGCD :: ExtraOp -> ExtraOp -> NormaliseResult mergeGCD (I i) (I j) = (I (gcd i j), Normalised) diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..a3c6fbd 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -84,7 +84,7 @@ normaliseNat defs (TyConApp tc [x,y]) | tc == clogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeCLog x' y'))) (normaliseNat defs x) (normaliseNat defs y) - | tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog x' y'))) + | tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog defs x' y'))) (normaliseNat defs x) (normaliseNat defs y) | tc == gcdTyCon defs = mergeNormResWith (\x' y' -> return (mergeGCD x' y')) diff --git a/tests/Main.hs b/tests/Main.hs index 4fa8382..041743e 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -223,6 +223,24 @@ test57 -> Proxy True test57 _ _ = id +test58 + :: Proxy b + -> Proxy x + -> Proxy y + -> Proxy (Log b (x^y)) + -> Proxy (y * (Log b x)) +test58 _ _ _ = id + +test59 + :: Proxy b + -> Proxy n + -> Proxy p + -> Proxy q + -> Proxy (Log b (n * p * q)) + -> Proxy (Log b n + Log b p + Log b q) +test59 _ _ _ _ = id + + main :: IO () main = defaultMain tests @@ -400,6 +418,12 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "forall b x y . Log b (x^y) ~ y * (Log b x)" $ + show (test58 Proxy Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall b n p q . Log b (n * p * q) ~ Log b n + Log b p + Log b q" $ + show (test59 Proxy Proxy Proxy Proxy Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors