Skip to content

Commit

Permalink
Improved Log inference
Browse files Browse the repository at this point in the history
Normalise `Log base (a * b * ... * z)` to `Log base a + Log base b + ... + Log base z`
Normalise `Log b (x^y)` to `y * Log b`
  • Loading branch information
rowanG077 committed Feb 9, 2022
1 parent fe80881 commit e6640a1
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
.stack-work
dist
dist-newstyle
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog for the [`ghc-typelits-extra`](http://hackage.haskell.org/package/ghc-typelits-extra) package

# Unreleased
* Normalise `Log base (a * b * ... * z)` to `Log base a + Log base b + ... + Log base z`
* Normalise `Log b (x^y)` to `y * Log b x`

# 0.4.3 *June 18th 2021*
* Add support for GHC 9.2.0.20210422

Expand Down
40 changes: 31 additions & 9 deletions src/GHC/TypeLits/Extra/Solver/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,28 @@ where
-- external
import Control.Monad.Trans.Writer.Strict
#if MIN_VERSION_ghc_typelits_natnormalise(0,7,0)
import Data.Set as Set
import qualified Data.Set as Set
#endif


import GHC.Base (isTrue#,(==#),(+#))
import GHC.Integer (smallInteger)
import GHC.Integer.Logarithms (integerLogBase#)
import GHC.TypeLits.Normalise.Unify (CType (..), normaliseNat, isNatural)
import GHC.TypeLits.Normalise.SOP (SOP (..), Product (..))
import GHC.TypeLits.Normalise.Unify ( CType (..), normaliseNat, isNatural
, reifySOP)

-- GHC API
#if MIN_VERSION_ghc(9,0,0)
import GHC.Builtin.Types.Literals (typeNatExpTyCon, typeNatSubTyCon)
import GHC.Builtin.Types.Literals ( typeNatExpTyCon, typeNatMulTyCon
, typeNstAddTyCon, typeNatSubTyCon)
import GHC.Core.TyCon (TyCon)
import GHC.Core.Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
import GHC.Utils.Outputable (Outputable (..), (<+>), integer, text)
#else
import Outputable (Outputable (..), (<+>), integer, text)
import TcTypeNats (typeNatExpTyCon, typeNatSubTyCon)
import TcTypeNats ( typeNatExpTyCon, typeNatMulTyCon, typeNatAddTyCon
, typeNatSubTyCon)
import TyCon (TyCon)
import Type (Type, TyVar, mkNumLitTy, mkTyConApp, mkTyVarTy)
#endif
Expand Down Expand Up @@ -194,11 +199,28 @@ mergeCLog i (Exp j k) | i == j = Just (k, Normalised)
mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j)
mergeCLog x y = Just (CLog x y, Untouched)

mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeLog (I i) _ | i < 2 = Nothing
mergeLog b (Exp b' y) | b == b' = Just (y, Normalised)
mergeLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j)
mergeLog x y = Just (Log x y, Untouched)
mergeLog :: ExtraDefs -> ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeLog _ (I i) _ | i < 2 = Nothing
mergeLog _ b (Exp b' y) | b == b' = Just (y, Normalised)
mergeLog defs b (Exp x y) =
let t = mkTyConApp
typeNatMulTyCon
[ reifyEOP defs y
, reifyEOP defs (Log b x)
]
in Just (C (CType t), Normalised)
mergeLog _ (I i) (I j) = fmap (\r -> (I r, Normalised)) (exactLogBase i j)
mergeLog defs b x = case (fst (runWriter (normaliseNat (reifyEOP defs x)))) of
S [P [_]] -> Just (Log b x, Untouched)
S [P p] ->
let si a = [a]
toLog body = reifyEOP defs (Log b (C (CType body)))
logs = map (toLog . reifySOP . S . si . P . si) p
merged = foldr (\t1 t2 -> mkTyConApp typeNatAddTyCon [t1, t2])
(head logs)
(tail logs)
in Just (C (CType merged), Normalised)
_ -> Just (Log b x, Untouched)

mergeGCD :: ExtraOp -> ExtraOp -> NormaliseResult
mergeGCD (I i) (I j) = (I (gcd i j), Normalised)
Expand Down
2 changes: 1 addition & 1 deletion src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ normaliseNat defs (TyConApp tc [x,y])
| tc == clogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeCLog x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog x' y')))
| tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog defs x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == gcdTyCon defs = mergeNormResWith (\x' y' -> return (mergeGCD x' y'))
Expand Down
24 changes: 24 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,24 @@ test57
-> Proxy True
test57 _ _ = id

test58
:: Proxy b
-> Proxy x
-> Proxy y
-> Proxy (Log b (x^y))
-> Proxy (y * (Log b x))
test58 _ _ _ = id

test59
:: Proxy b
-> Proxy n
-> Proxy p
-> Proxy q
-> Proxy (Log b (n * p * q))
-> Proxy (Log b n + Log b p + Log b q)
test59 _ _ _ _ = id


main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -400,6 +418,12 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b x y . Log b (x^y) ~ y * (Log b x)" $
show (test58 Proxy Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall b n p q . Log b (n * p * q) ~ Log b n + Log b p + Log b q" $
show (test59 Proxy Proxy Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down

0 comments on commit e6640a1

Please sign in to comment.