Skip to content

Commit

Permalink
Reduce *Log b (n * b^f) to *Log b n + f
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Jul 12, 2022
1 parent bbc0d62 commit 3a7a1ba
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 58 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package

# 0.4.4 Unreleased
* Reduce `*Log b (n * b^f)` to `*Log b n + f`

# 0.4.3 *June 18th 2021*
* Add support for GHC 9.2.0.20210422

Expand Down
2 changes: 1 addition & 1 deletion ghc-typelits-extra.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: ghc-typelits-extra
version: 0.4.3
version: 0.4.4
synopsis: Additional type-level operations on GHC.TypeLits.Nat
description:
Additional type-level operations on @GHC.TypeLits.Nat@:
Expand Down
132 changes: 102 additions & 30 deletions src/GHC/TypeLits/Extra/Solver/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module GHC.TypeLits.Extra.Solver.Operations
, Normalised (..)
, NormaliseResult
, mergeNormalised
, toExtraOp
, reifyEOP
, mergeMax
, mergeMin
Expand All @@ -37,19 +38,26 @@ import Data.Set as Set
import GHC.Base (isTrue#,(==#),(+#))
import GHC.Integer (smallInteger)
import GHC.Integer.Logarithms (integerLogBase#)
import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural)

import qualified GHC.TypeLits.Normalise.SOP as N
(SOP (..), Product(..), Symbol(..))
import qualified GHC.TypeLits.Normalise.Unify as N
(CType (..), normaliseNat, isNatural, reifySOP)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon)
import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatAddTyCon
, typeNatSubTyCon)
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..))
import GHC.Core.Type (TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, coreView)
import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text)
#else
import Outputable (Outputable (..), (<+>), integer, text)
import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon)
import TcTypeNats (typeNatExpTyCon, typeNatAddTyCon, typeNatSubTyCon)
import TyCon (TyCon)
import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
import TyCoRep (Type (..), TyLit (..))
import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy, coreView)
#endif

-- | Indicates whether normalisation has occured
Expand All @@ -65,14 +73,31 @@ mergeNormalised Normalised _ = Normalised
mergeNormalised _ Normalised = Normalised
mergeNormalised _ _ = Untouched

toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp
toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1
toExtraOp _ (TyVarTy v) = pure (V v)
toExtraOp _ (LitTy (NumTyLit i)) = pure (I i)
toExtraOp defs (TyConApp tc [x,y])
| tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y)
toExtraOp _ t = Just (C (N.CType t))

-- | A normalise result contains the ExtraOp and a flag that indicates whether any expression
-- | was normalised within the ExtraOp.
type NormaliseResult = (ExtraOp, Normalised)

data ExtraOp
= I Integer
| V TyVar
| C CType
| C N.CType
| Max ExtraOp ExtraOp
| Min ExtraOp ExtraOp
| Div ExtraOp ExtraOp
Expand Down Expand Up @@ -116,7 +141,7 @@ data ExtraDefs = ExtraDefs
reifyEOP :: ExtraDefs -> ExtraOp -> Type
reifyEOP _ (I i) = mkNumLitTy i
reifyEOP _ (V v) = mkTyVarTy v
reifyEOP _ (C (CType c)) = c
reifyEOP _ (C (N.CType c)) = c
reifyEOP defs (Max x y) = mkTyConApp (maxTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Min x y) = mkTyConApp (minTyCon defs) [reifyEOP defs x
Expand Down Expand Up @@ -144,13 +169,13 @@ mergeMax _ x (I 0) = (x, Normalised)
mergeMax defs x y =
let x' = reifyEOP defs x
y' = reifyEOP defs y
z = fst (runWriter (normaliseNat (mkTyConApp typeNatSubTyCon [y',x'])))
z = fst (runWriter (N.normaliseNat (mkTyConApp typeNatSubTyCon [y',x'])))
#if MIN_VERSION_ghc_typelits_natnormalise(0,7,0)
in case runWriterT (isNatural z) of
in case runWriterT (N.isNatural z) of
Just (True , cs) | Set.null cs -> (y, Normalised)
Just (False, cs) | Set.null cs -> (x, Normalised)
#else
in case isNatural z of
in case N.isNatural z of
Just True -> (y, Normalised)
Just False -> (x, Normalised)
#endif
Expand All @@ -160,13 +185,13 @@ mergeMin :: ExtraDefs -> ExtraOp -> ExtraOp -> NormaliseResult
mergeMin defs x y =
let x' = reifyEOP defs x
y' = reifyEOP defs y
z = fst (runWriter (normaliseNat (mkTyConApp typeNatSubTyCon [y',x'])))
z = fst (runWriter (N.normaliseNat (mkTyConApp typeNatSubTyCon [y',x'])))
#if MIN_VERSION_ghc_typelits_natnormalise(0,7,0)
in case runWriterT (isNatural z) of
in case runWriterT (N.isNatural z) of
Just (True, cs) | Set.null cs -> (x, Normalised)
Just (False,cs) | Set.null cs -> (y, Normalised)
#else
in case isNatural z of
in case N.isNatural z of
Just True -> (x, Normalised)
Just False -> (y, Normalised)
#endif
Expand All @@ -182,23 +207,70 @@ mergeMod _ (I 0) = Nothing
mergeMod (I i) (I j) = Just (I (mod i j), Normalised)
mergeMod x y = Just (Mod x y, Untouched)

mergeFLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeFLog (I i) _ | i < 2 = Nothing
mergeFLog i (Exp j k) | i == j = Just (k, Normalised)
mergeFLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (flogBase i j)
mergeFLog x y = Just (FLog x y, Untouched)

mergeCLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeCLog (I i) _ | i < 2 = Nothing
mergeCLog i (Exp j k) | i == j = Just (k, Normalised)
mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j)
mergeCLog x y = Just (CLog x y, Untouched)

mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeLog (I i) _ | i < 2 = Nothing
mergeLog b (Exp b' y) | b == b' = Just (y, Normalised)
mergeLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j)
mergeLog x y = Just (Log x y, Untouched)
-- | Try to factor out terms in the logarithm. In essence
-- it does the following transformation `Log b (n * b^f)` -> `(Log b n) + f`.
tryFactorLog
:: (ExtraOp -> ExtraOp -> ExtraOp)
-> ExtraDefs
-> ExtraOp
-> ExtraOp
-> Maybe NormaliseResult
tryFactorLog logConstr defs x y = result
where
-- Get SOP from Natnormalise plugins and check if the sum of products
-- contains forall v. base^v in all products. If this is the case
-- we can extract the v outside of the log and eliminate the base*v from
-- the log. I.e., reduce `Log b (n * b^f)` to `(Log b n) * f`.
--
-- TODO: We could go even further and find the smallest common factor
-- and extract it out. Probably only worth it if it is a literal
mkProduct [] = N.P [N.I 1]
mkProduct xs = N.P xs
extractFactor _ _ [] = Nothing
extractFactor ac b (c:cs)
| b == (N.CType (N.reifySOP (N.S [N.P [c]])))
= Just (mkProduct ((reverse ac) ++ cs), mkNumLitTy 1)
extractFactor ac b (((N.E c v)):cs)
| b == N.CType (N.reifySOP c)
= Just (mkProduct ((reverse ac) ++ cs), N.reifySOP (N.S [v]))
extractFactor ac b (c:cs) = extractFactor (c:ac) b cs
allSame [] = True
allSame (v:vs) = all (v ==) vs

x1 = N.CType (reifyEOP defs x)
(ySOP, ltCts) = runWriter (N.normaliseNat (reifyEOP defs y))
resultM = do
newProductsAndFactors <- sequence (fmap (extractFactor [] x1 . N.unP) (N.unS ySOP))
let (newProducts, factors) = unzip newProductsAndFactors
let factor = head factors
newLogOf <- toExtraOp defs (N.reifySOP (N.S newProducts))
let newLog = reifyEOP defs (logConstr x newLogOf)
let normalisedTy = mkTyConApp typeNatAddTyCon [newLog, factor]
if allSame (fmap N.CType factors) && Prelude.null ltCts
then toExtraOp defs normalisedTy
else Nothing

result = case resultM of
Just norm -> Just (norm, Normalised)
Nothing -> Just (logConstr x y, Untouched)

mergeFLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeFLog _ (I i) _ | i < 2 = Nothing
mergeFLog _ i (Exp j k) | i == j = Just (k, Normalised)
mergeFLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (flogBase i j)
mergeFLog defs x y = tryFactorLog FLog defs x y

mergeCLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeCLog _ (I i) _ | i < 2 = Nothing
mergeCLog _ i (Exp j k) | i == j = Just (k, Normalised)
mergeCLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j)
mergeCLog defs x y = tryFactorLog CLog defs x y

mergeLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeLog _ (I i) _ | i < 2 = Nothing
mergeLog _ b (Exp b' y) | b == b' = Just (y, Normalised)
mergeLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j)
mergeLog defs x y = tryFactorLog Log defs x y

mergeGCD :: ExtraOp -> ExtraOp -> NormaliseResult
mergeGCD (I i) (I j) = (I (gcd i j), Normalised)
Expand Down
35 changes: 8 additions & 27 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ module GHC.TypeLits.Extra.Solver.Unify
, UnifyResult (..)
, NormaliseResult
, normaliseNat
, toExtraOp
, unifyExtra
)
where
Expand All @@ -24,9 +23,8 @@ import GHC.TypeLits.Normalise.Unify (CType (..))

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..))
import GHC.Core.Type (TyVar, coreView)
import GHC.Core.TyCo.Rep (Type (..))
import GHC.Core.Type (TyVar)
import GHC.Tc.Plugin (TcPluginM, tcPluginTrace)
import GHC.Tc.Types.Constraint (Ct)
import GHC.Types.Unique.Set (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
Expand All @@ -35,7 +33,7 @@ import GHC.Utils.Outputable (Outputable (..), ($$), text)
import Outputable (Outputable (..), ($$), text)
import TcPluginM (TcPluginM, tcPluginTrace)
import TcTypeNats (typeNatExpTyCon)
import Type (TyVar, coreView)
import Type (TyVar)
import TyCoRep (Type (..), TyLit (..))
import UniqSet (UniqSet, emptyUniqSet, unionUniqSets, unitUniqSet)
#if MIN_VERSION_ghc(8,10,0)
Expand All @@ -59,23 +57,6 @@ mergeNormResWith f x y = do
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)

toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp
toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1
toExtraOp _ (TyVarTy v) = pure (V v)
toExtraOp _ (LitTy (NumTyLit i)) = pure (I i)
toExtraOp defs (TyConApp tc [x,y])
| tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y)
toExtraOp _ t = Just (C (CType t))

normaliseNat :: ExtraDefs -> ExtraOp -> Maybe NormaliseResult
normaliseNat _ (I i) = pure (I i, Untouched)
normaliseNat _ (V v) = pure (V v, Untouched)
Expand All @@ -85,19 +66,19 @@ normaliseNat defs (Max x y) = mergeNormResWith (\x' y' -> pure (mergeMax defs x'
normaliseNat defs (Min x y) = mergeNormResWith (\x' y' -> pure (mergeMin defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Div x y) = mergeNormResWith (\x' y' -> (mergeDiv x' y'))
normaliseNat defs (Div x y) = mergeNormResWith mergeDiv
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Mod x y) = mergeNormResWith (\x' y' -> (mergeMod x' y'))
normaliseNat defs (Mod x y) = mergeNormResWith mergeMod
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (FLog x y) = mergeNormResWith (\x' y' -> (mergeFLog x' y'))
normaliseNat defs (FLog x y) = mergeNormResWith (mergeFLog defs)
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (CLog x y) = mergeNormResWith (\x' y' -> (mergeCLog x' y'))
normaliseNat defs (CLog x y) = mergeNormResWith (mergeCLog defs)
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Log x y) = mergeNormResWith (\x' y' -> (mergeLog x' y'))
normaliseNat defs (Log x y) = mergeNormResWith (mergeLog defs)
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (GCD x y) = mergeNormResWith (\x' y' -> pure (mergeGCD x' y'))
Expand Down
39 changes: 39 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,33 @@ test57
-> Proxy True
test57 _ _ = id

test58
:: Proxy (CLog 2 (n * 2))
-> Proxy (CLog 2 n + 1)
test58 = id

test59
:: Proxy n
-> Proxy b
-> Proxy (CLog b (n * b))
-> Proxy (CLog b n + 1)
test59 _ _ = id

test60
:: Proxy n
-> Proxy b
-> Proxy (CLog b (n * b * b))
-> Proxy (CLog b n + 2)
test60 _ _ = id

test61
:: Proxy n
-> Proxy b
-> Proxy f
-> Proxy (CLog (b ^ n) ((f * (b ^ n)) + (f * (b ^ n))))
-> Proxy (CLog (b ^ n) (2 * f) + 1)
test61 _ _ _ = id

main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -400,6 +427,18 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n . CLog 2 (n * 2) ~ CLog 2 n + 1" $
show (test58 Proxy) @?=
"Proxy"
, testCase "forall n b. CLog b (n * b) ~ CLog b n + 1" $
show (test59 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n b. CLog b (n * b * b) ~ CLog b n + 2" $
show (test60 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n b f. CLog (b ^ n) ((f * (b ^ n)) + (f * (b ^ n)))) ~ CLog (b ^ n) (2 * f) + 1" $
show (test61 Proxy Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down

0 comments on commit 3a7a1ba

Please sign in to comment.