diff --git a/CHANGELOG.md b/CHANGELOG.md index e3c128f..efae382 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package +# 0.4.4 Unreleased +* Reduce `*Log b (n * b^f)` to `*Log b n + f` + # 0.4.3 *June 18th 2021* * Add support for GHC 9.2.0.20210422 diff --git a/ghc-typelits-extra.cabal b/ghc-typelits-extra.cabal index 6f789d9..5fed383 100644 --- a/ghc-typelits-extra.cabal +++ b/ghc-typelits-extra.cabal @@ -1,5 +1,5 @@ name: ghc-typelits-extra -version: 0.4.3 +version: 0.4.4 synopsis: Additional type-level operations on GHC.TypeLits.Nat description: Additional type-level operations on @GHC.TypeLits.Nat@: diff --git a/src/GHC/TypeLits/Extra/Solver/Operations.hs b/src/GHC/TypeLits/Extra/Solver/Operations.hs index c63e940..edef2d4 100644 --- a/src/GHC/TypeLits/Extra/Solver/Operations.hs +++ b/src/GHC/TypeLits/Extra/Solver/Operations.hs @@ -14,6 +14,7 @@ module GHC.TypeLits.Extra.Solver.Operations , Normalised (..) , NormaliseResult , mergeNormalised + , toExtraOp , reifyEOP , mergeMax , mergeMin @@ -37,19 +38,26 @@ import Data.Set as Set import GHC.Base (isTrue#,(==#),(+#)) import GHC.Integer (smallInteger) import GHC.Integer.Logarithms (integerLogBase#) -import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural) + +import qualified GHC.TypeLits.Normalise.SOP as N + (SOP (..), Product(..), Symbol(..)) +import qualified GHC.TypeLits.Normalise.Unify as N + (CType (..), normaliseNat, isNatural, reifySOP) -- GHC API #if MIN_VERSION_ghc(9,0,0) -import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon) +import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatAddTyCon + , typeNatSubTyCon) import GHC.Core.TyCon (TyCon) -import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) +import GHC.Core.TyCo.Rep (Type (..), TyLit (..)) +import GHC.Core.Type (TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, coreView) import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text) #else import Outputable (Outputable (..), (<+>), integer, text) -import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon) +import TcTypeNats (typeNatExpTyCon, typeNatAddTyCon, typeNatSubTyCon) import TyCon (TyCon) -import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy) +import TyCoRep (Type (..), TyLit (..)) +import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, coreView) #endif -- | Indicates whether normalisation has occured @@ -65,6 +73,23 @@ mergeNormalised Normalised _ = Normalised mergeNormalised _ Normalised = Normalised mergeNormalised _ _ = Untouched +toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp +toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1 +toExtraOp _ (TyVarTy v) = pure (V v) +toExtraOp _ (LitTy (NumTyLit i)) = pure (I i) +toExtraOp defs (TyConApp tc [x,y]) + | tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y) +toExtraOp _ t = Just (C (N.CType t)) + -- | A normalise result contains the ExtraOp and a flag that indicates whether any expression -- | was normalised within the ExtraOp. type NormaliseResult = (ExtraOp, Normalised) @@ -72,7 +97,7 @@ type NormaliseResult = (ExtraOp, Normalised) data ExtraOp = I Integer | V TyVar - | C CType + | C N.CType | Max ExtraOp ExtraOp | Min ExtraOp ExtraOp | Div ExtraOp ExtraOp @@ -116,7 +141,7 @@ data ExtraDefs = ExtraDefs reifyEOP :: ExtraDefs -> ExtraOp -> Type reifyEOP _ (I i) = mkNumLitTy i reifyEOP _ (V v) = mkTyVarTy v -reifyEOP _ (C (CType c)) = c +reifyEOP _ (C (N.CType c)) = c reifyEOP defs (Max x y) = mkTyConApp (maxTyCon defs) [reifyEOP defs x ,reifyEOP defs y] reifyEOP defs (Min x y) = mkTyConApp (minTyCon defs) [reifyEOP defs x @@ -144,13 +169,13 @@ mergeMax _ x (I 0) = (x, Normalised) mergeMax defs x y = let x' = reifyEOP defs x y' = reifyEOP defs y - z = fst (runWriter (normaliseNat (mkTyConApp typeNatSubTyCon [y',x']))) + z = fst (runWriter (N.normaliseNat (mkTyConApp typeNatSubTyCon [y',x']))) #if MIN_VERSION_ghc_typelits_natnormalise(0,7,0) - in case runWriterT (isNatural z) of + in case runWriterT (N.isNatural z) of Just (True , cs) | Set.null cs -> (y, Normalised) Just (False, cs) | Set.null cs -> (x, Normalised) #else - in case isNatural z of + in case N.isNatural z of Just True -> (y, Normalised) Just False -> (x, Normalised) #endif @@ -160,13 +185,13 @@ mergeMin :: ExtraDefs -> ExtraOp -> ExtraOp -> NormaliseResult mergeMin defs x y = let x' = reifyEOP defs x y' = reifyEOP defs y - z = fst (runWriter (normaliseNat (mkTyConApp typeNatSubTyCon [y',x']))) + z = fst (runWriter (N.normaliseNat (mkTyConApp typeNatSubTyCon [y',x']))) #if MIN_VERSION_ghc_typelits_natnormalise(0,7,0) - in case runWriterT (isNatural z) of + in case runWriterT (N.isNatural z) of Just (True, cs) | Set.null cs -> (x, Normalised) Just (False,cs) | Set.null cs -> (y, Normalised) #else - in case isNatural z of + in case N.isNatural z of Just True -> (x, Normalised) Just False -> (y, Normalised) #endif @@ -182,23 +207,70 @@ mergeMod _ (I 0) = Nothing mergeMod (I i) (I j) = Just (I (mod i j), Normalised) mergeMod x y = Just (Mod x y, Untouched) -mergeFLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult -mergeFLog (I i) _ | i < 2 = Nothing -mergeFLog i (Exp j k) | i == j = Just (k, Normalised) -mergeFLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (flogBase i j) -mergeFLog x y = Just (FLog x y, Untouched) - -mergeCLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult -mergeCLog (I i) _ | i < 2 = Nothing -mergeCLog i (Exp j k) | i == j = Just (k, Normalised) -mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j) -mergeCLog x y = Just (CLog x y, Untouched) - -mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult -mergeLog (I i) _ | i < 2 = Nothing -mergeLog b (Exp b' y) | b == b' = Just (y, Normalised) -mergeLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j) -mergeLog x y = Just (Log x y, Untouched) +-- | Try to factor out terms in the logarithm. In essence +-- it does the following transformation `Log b (n * b^f)` -> `(Log b n) + f`. +tryFactorLog + :: (ExtraOp -> ExtraOp -> ExtraOp) + -> ExtraDefs + -> ExtraOp + -> ExtraOp + -> Maybe NormaliseResult +tryFactorLog logConstr defs x y = result + where + -- Get SOP from Natnormalise plugins and check if the sum of products + -- contains forall v. base^v in all products. If this is the case + -- we can extract the v outside of the log and eliminate the base*v from + -- the log. I.e., reduce `Log b (n * b^f)` to `(Log b n) * f`. + -- + -- TODO: We could go even further and find the smallest common factor + -- and extract it out. Probably only worth it if it is a literal + mkProduct [] = N.P [N.I 1] + mkProduct xs = N.P xs + extractFactor _ _ [] = Nothing + extractFactor ac b (c:cs) + | b == (N.CType (N.reifySOP (N.S [N.P [c]]))) + = Just (mkProduct ((reverse ac) ++ cs), mkNumLitTy 1) + extractFactor ac b (((N.E c v)):cs) + | b == N.CType (N.reifySOP c) + = Just (mkProduct ((reverse ac) ++ cs), N.reifySOP (N.S [v])) + extractFactor ac b (c:cs) = extractFactor (c:ac) b cs + allSame [] = True + allSame (v:vs) = all (v ==) vs + + x1 = N.CType (reifyEOP defs x) + (ySOP, ltCts) = runWriter (N.normaliseNat (reifyEOP defs y)) + resultM = do + newProductsAndFactors <- sequence (fmap (extractFactor [] x1 . N.unP) (N.unS ySOP)) + let (newProducts, factors) = unzip newProductsAndFactors + let factor = head factors + newLogOf <- toExtraOp defs (N.reifySOP (N.S newProducts)) + let newLog = reifyEOP defs (logConstr x newLogOf) + let normalisedTy = mkTyConApp typeNatAddTyCon [newLog, factor] + if allSame (fmap N.CType factors) && Prelude.null ltCts + then toExtraOp defs normalisedTy + else Nothing + + result = case resultM of + Just norm -> Just (norm, Normalised) + Nothing -> Just (logConstr x y, Untouched) + +mergeFLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult +mergeFLog _ (I i) _ | i < 2 = Nothing +mergeFLog _ i (Exp j k) | i == j = Just (k, Normalised) +mergeFLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (flogBase i j) +mergeFLog defs x y = tryFactorLog FLog defs x y + +mergeCLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult +mergeCLog _ (I i) _ | i < 2 = Nothing +mergeCLog _ i (Exp j k) | i == j = Just (k, Normalised) +mergeCLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j) +mergeCLog defs x y = tryFactorLog CLog defs x y + +mergeLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult +mergeLog _ (I i) _ | i < 2 = Nothing +mergeLog _ b (Exp b' y) | b == b' = Just (y, Normalised) +mergeLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j) +mergeLog defs x y = tryFactorLog Log defs x y mergeGCD :: ExtraOp -> ExtraOp -> NormaliseResult mergeGCD (I i) (I j) = (I (gcd i j), Normalised) diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index 6776d86..ed15780 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -12,7 +12,6 @@ module GHC.TypeLits.Extra.Solver.Unify , UnifyResult (..) , NormaliseResult , normaliseNat - , toExtraOp , unifyExtra ) where @@ -24,9 +23,8 @@ import GHC.TypeLits.Normalise.Unify (CType (..)) -- GHC API #if MIN_VERSION_ghc(9,0,0) -import GHC.Builtin.Types.Literals (typeNatExpTyCon) -import GHC.Core.TyCo.Rep (Type (..), TyLit (..)) -import GHC.Core.Type (TyVar, coreView) +import GHC.Core.TyCo.Rep (Type (..)) +import GHC.Core.Type (TyVar) import GHC.Tc.Plugin (TcPluginM, tcPluginTrace) import GHC.Tc.Types.Constraint (Ct) import GHC.Types.Unique.Set (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet) @@ -35,7 +33,7 @@ import GHC.Utils.Outputable (Outputable (..), ($$), text) import Outputable (Outputable (..), ($$), text) import TcPluginM (TcPluginM, tcPluginTrace) import TcTypeNats (typeNatExpTyCon) -import Type (TyVar, coreView) +import Type (TyVar) import TyCoRep (Type (..), TyLit (..)) import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet) #if MIN_VERSION_ghc(8,10,0) @@ -59,23 +57,6 @@ mergeNormResWith f x y = do (res, n3) <- f x' y' pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3) -toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp -toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1 -toExtraOp _ (TyVarTy v) = pure (V v) -toExtraOp _ (LitTy (NumTyLit i)) = pure (I i) -toExtraOp defs (TyConApp tc [x,y]) - | tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y) - | tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y) -toExtraOp _ t = Just (C (CType t)) - normaliseNat :: ExtraDefs -> ExtraOp -> Maybe NormaliseResult normaliseNat _ (I i) = pure (I i, Untouched) normaliseNat _ (V v) = pure (V v, Untouched) @@ -85,19 +66,19 @@ normaliseNat defs (Max x y) = mergeNormResWith (\x' y' -> pure (mergeMax defs x' normaliseNat defs (Min x y) = mergeNormResWith (\x' y' -> pure (mergeMin defs x' y')) (normaliseNat defs x) (normaliseNat defs y) -normaliseNat defs (Div x y) = mergeNormResWith (\x' y' -> (mergeDiv x' y')) +normaliseNat defs (Div x y) = mergeNormResWith mergeDiv (normaliseNat defs x) (normaliseNat defs y) -normaliseNat defs (Mod x y) = mergeNormResWith (\x' y' -> (mergeMod x' y')) +normaliseNat defs (Mod x y) = mergeNormResWith mergeMod (normaliseNat defs x) (normaliseNat defs y) -normaliseNat defs (FLog x y) = mergeNormResWith (\x' y' -> (mergeFLog x' y')) +normaliseNat defs (FLog x y) = mergeNormResWith (mergeFLog defs) (normaliseNat defs x) (normaliseNat defs y) -normaliseNat defs (CLog x y) = mergeNormResWith (\x' y' -> (mergeCLog x' y')) +normaliseNat defs (CLog x y) = mergeNormResWith (mergeCLog defs) (normaliseNat defs x) (normaliseNat defs y) -normaliseNat defs (Log x y) = mergeNormResWith (\x' y' -> (mergeLog x' y')) +normaliseNat defs (Log x y) = mergeNormResWith (mergeLog defs) (normaliseNat defs x) (normaliseNat defs y) normaliseNat defs (GCD x y) = mergeNormResWith (\x' y' -> pure (mergeGCD x' y')) diff --git a/tests/Main.hs b/tests/Main.hs index 4fa8382..3e85122 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -223,6 +223,33 @@ test57 -> Proxy True test57 _ _ = id +test58 + :: Proxy (CLog 2 (n * 2)) + -> Proxy (CLog 2 n + 1) +test58 = id + +test59 + :: Proxy n + -> Proxy b + -> Proxy (CLog b (n * b)) + -> Proxy (CLog b n + 1) +test59 _ _ = id + +test60 + :: Proxy n + -> Proxy b + -> Proxy (CLog b (n * b * b)) + -> Proxy (CLog b n + 2) +test60 _ _ = id + +test61 + :: Proxy n + -> Proxy b + -> Proxy f + -> Proxy (CLog (b ^ n) ((f * (b ^ n)) + (f * (b ^ n)))) + -> Proxy (CLog (b ^ n) (2 * f) + 1) +test61 _ _ _ = id + main :: IO () main = defaultMain tests @@ -400,6 +427,18 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "forall n . CLog 2 (n * 2) ~ CLog 2 n + 1" $ + show (test58 Proxy) @?= + "Proxy" + , testCase "forall n b. CLog b (n * b) ~ CLog b n + 1" $ + show (test59 Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall n b. CLog b (n * b * b) ~ CLog b n + 2" $ + show (test60 Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall n b f. CLog (b ^ n) ((f * (b ^ n)) + (f * (b ^ n)))) ~ CLog (b ^ n) (2 * f) + 1" $ + show (test61 Proxy Proxy Proxy Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors