Skip to content

Commit

Permalink
Add CLog with well-defined zero case
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinreact committed Sep 1, 2024
1 parent 4dadc82 commit d2ecb7a
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 1 deletion.
1 change: 1 addition & 0 deletions src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ lookupExtraDefs = do
<*> look ''GHC.TypeLits.Extra.LCM
<*> look ''Data.Type.Ord.OrdCond
<*> look ''GHC.TypeError.Assert
<*> look ''GHC.TypeLits.Extra.CLogWZ
where
look nm = tcLookupTyCon =<< lookupTHName nm

Expand Down
1 change: 1 addition & 0 deletions src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ lookupExtraDefs = do
<*> pure typeNatLeqTyCon
<*> pure typeNatLeqTyCon
#endif
<*> look md "CLogWZ"
where
look md s = tcLookupTyCon =<< lookupName md (mkTcOcc s)
myModule = mkModuleName "GHC.TypeLits.Extra"
Expand Down
29 changes: 28 additions & 1 deletion src/GHC/TypeLits/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ module GHC.TypeLits.Extra
-- ** Logarithm
, FLog
, CLog
, CLogWZ
-- *** Exact logarithm
, Log
-- Numeric
Expand Down Expand Up @@ -101,7 +102,8 @@ import GHC.TypeLits as N
#if MIN_VERSION_ghc(8,4,0)
import GHC.TypeLits (Div, Mod)
#endif
import GHC.TypeLits.KnownNat (KnownNat2 (..), SNatKn (..), nameToSymbol)
import GHC.TypeLits.KnownNat (KnownNat2 (..), KnownNat3 (..)
,SNatKn (..), nameToSymbol)

#if MIN_VERSION_ghc(8,2,0)
intToNumber :: Int# -> Natural
Expand Down Expand Up @@ -195,6 +197,31 @@ instance (KnownNat x, KnownNat y, 2 <= x, 1 <= y) => KnownNat2 $(nameToSymbol ''
_ | isTrue# (z1 ==# z2) -> SNatKn (intToNumber (z1 +# 1#))
| otherwise -> SNatKn (intToNumber z1)

-- | Extended version of 'CLog', which is also well-defined in case the non-base argument is zero. The additional third argument argument is returned in this particular case. dThe particular value is chosen the user.
--
-- Note that additional equations are provided by the type-checker plugin solver
-- "GHC.TypeLits.Extra.Solver".
type family CLogWZ (base :: Nat) (value :: Nat) (ifzero :: Nat) :: Nat where
CLogWZ 2 0 z = z
CLogWZ 2 1 _ = 0 -- Additional equations are provided by the custom solver

#if MIN_VERSION_ghc(9,4,0)
instance (KnownNat x, KnownNat y, KnownNat z, (2 <= x) ~ (() :: Constraint)) => KnownNat3 $(nameToSymbol ''CLogWZ) x y z where
#else
instance (KnownNat x, KnownNat y, KnownNat z, 2 <= x) => KnownNat3 $(nameToSymbol ''CLogWZ) x y z where
#endif
natSing3 = let x = natVal (Proxy @x)
y = natVal (Proxy @y)
z = natVal (Proxy @z)
z1 = integerLogBase# x y
z2 = integerLogBase# x (y-1)
in case y of
0 -> SNatKn $ fromInteger z
1 -> SNatKn 0
_ | isTrue# (z1 ==# z2) -> SNatKn (intToNumber (z1 +# 1#))
| otherwise -> SNatKn (intToNumber z1)


-- | Type-level equivalent of <https://hackage.haskell.org/package/base-4.17.0.0/docs/GHC-Integer-Logarithms.html#v:integerLogBase-35- integerLogBase#>
-- where the operation only reduces when:
--
Expand Down
18 changes: 18 additions & 0 deletions src/GHC/TypeLits/Extra/Solver/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module GHC.TypeLits.Extra.Solver.Operations
, mergeMod
, mergeFLog
, mergeCLog
, mergeCLogWZ
, mergeLog
, mergeGCD
, mergeLCM
Expand Down Expand Up @@ -83,6 +84,7 @@ data ExtraOp
| GCD ExtraOp ExtraOp
| LCM ExtraOp ExtraOp
| Exp ExtraOp ExtraOp
| CLogWZ ExtraOp ExtraOp ExtraOp
deriving Eq

instance Outputable ExtraOp where
Expand All @@ -99,6 +101,12 @@ instance Outputable ExtraOp where
ppr (GCD x y) = text "GCD (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (LCM x y) = text "GCD (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (Exp x y) = text "Exp (" <+> ppr x <+> text "," <+> ppr y <+> text ")"
ppr (CLogWZ x y z) =
text "CLogWZ "
<+> text "(" <+> ppr x
<+> text "," <+> ppr y
<+> text "," <+> ppr z
<+> text ")"

data ExtraDefs = ExtraDefs
{ maxTyCon :: TyCon
Expand All @@ -112,6 +120,7 @@ data ExtraDefs = ExtraDefs
, lcmTyCon :: TyCon
, ordTyCon :: TyCon
, assertTC :: TyCon
, clogWZTyCon :: TyCon
}

reifyEOP :: ExtraDefs -> ExtraOp -> Type
Expand All @@ -128,6 +137,8 @@ reifyEOP defs (Mod x y) = mkTyConApp (modTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (CLog x y) = mkTyConApp (clogTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (CLogWZ x y z) = mkTyConApp (clogTyCon defs)
$ reifyEOP defs <$> [x, y, z]
reifyEOP defs (FLog x y) = mkTyConApp (flogTyCon defs) [reifyEOP defs x
,reifyEOP defs y]
reifyEOP defs (Log x y) = mkTyConApp (logTyCon defs) [reifyEOP defs x
Expand Down Expand Up @@ -195,6 +206,13 @@ mergeCLog i (Exp j k) | i == j = Just (k, Normalised)
mergeCLog (I i) (I j) = fmap (\r -> (I r, Normalised)) (clogBase i j)
mergeCLog x y = Just (CLog x y, Untouched)

mergeCLogWZ :: ExtraOp -> ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeCLogWZ (I i) _ _ | i < 2 = Nothing
mergeCLogWZ _ (I 0) z = Just (z, Normalised)
mergeCLogWZ i (Exp j k) _ | i == j = Just (k, Normalised)
mergeCLogWZ (I i) (I j) _ = fmap (\r -> (I r, Normalised)) (clogBase i j)
mergeCLogWZ x y _ = Just (CLog x y, Untouched)

mergeLog :: ExtraOp -> ExtraOp -> Maybe NormaliseResult
mergeLog (I i) _ | i < 2 = Nothing
mergeLog b (Exp b' y) | b == b' = Just (y, Normalised)
Expand Down
17 changes: 17 additions & 0 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ normaliseNat defs (TyConApp tc [x,y])
(normaliseNat defs x)
(normaliseNat defs y)

normaliseNat defs (TyConApp tc [x,y,z])
| tc == clogWZTyCon defs = do
(x', n1) <- normaliseNat defs x
(y', n2) <- normaliseNat defs y
(z', n3) <- normaliseNat defs z
(res, n4) <- MaybeT $ return $ mergeCLogWZ x' y' z'
pure (res, n1 `mergeNormalised`
n2 `mergeNormalised`
n3 `mergeNormalised`
n4
)

normaliseNat defs (TyConApp tc tys) = do
let mergeExtraOp [] = []
mergeExtraOp ((Just (op, Normalised), _):xs) = reifyEOP defs op:mergeExtraOp xs
Expand Down Expand Up @@ -162,6 +174,10 @@ fvOP (Log x y) = fvOP x `unionUniqSets` fvOP y
fvOP (GCD x y) = fvOP x `unionUniqSets` fvOP y
fvOP (LCM x y) = fvOP x `unionUniqSets` fvOP y
fvOP (Exp x y) = fvOP x `unionUniqSets` fvOP y
fvOP (CLogWZ x y z) =
fvOP x `unionUniqSets`
fvOP y `unionUniqSets`
fvOP z

eqFV :: ExtraOp -> ExtraOp -> Bool
eqFV = (==) `on` fvOP
Expand All @@ -180,3 +196,4 @@ containsConstants (Log x y) = containsConstants x || containsConstants y
containsConstants (GCD x y) = containsConstants x || containsConstants y
containsConstants (LCM x y) = containsConstants x || containsConstants y
containsConstants (Exp x y) = containsConstants x || containsConstants y
containsConstants (CLogWZ x y z) = or $ map containsConstants [x, y, z]
60 changes: 60 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,36 @@ test58b
-> Proxy (Max (n+2) 1)
test58b = test58a

test59 :: Proxy (CLogWZ 3 10 9) -> Proxy 3
test59 = id

test60 :: Proxy ((CLogWZ 3 10 3) + x) -> Proxy (x + (CLogWZ 2 7 8))
test60 = id

test61 :: Proxy (CLogWZ x (x^y) 8) -> Proxy y
test61 = id

test62 :: Integer
test62 = natVal (Proxy :: Proxy (CLogWZ 6 8 3))

test63 :: Integer
test63 = natVal (Proxy :: Proxy (CLogWZ 3 10 9))

test64 :: Integer
test64 = natVal (Proxy :: Proxy ((CLogWZ 2 4 11) * (3 ^ (CLogWZ 2 4 8))))

test65 :: Integer
test65 = natVal (Proxy :: Proxy (Max (CLogWZ 2 4 8) (CLogWZ 4 20 5)))

test66 :: Proxy (CLogWZ 3 0 8) -> Proxy 8
test66 = id

test67 :: Proxy (CLogWZ 2 0 x) -> Proxy x
test67 = id

test68 :: Proxy (CLogWZ 5 0 0) -> Proxy 0
test68 = id

main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -411,6 +441,36 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "CLogWZ 3 10 9 ~ 3" $
show (test59 Proxy) @?=
"Proxy"
, testCase "forall x . CLogWZ 3 10 3 + x ~ x + CLogWZ 2 7 8" $
show (test60 Proxy) @?=
"Proxy"
, testCase "forall x>1 . CLogWZ x (x^y) 8 ~ y" $
show (test61 Proxy) @?=
"Proxy"
, testCase "KnownNat (CLogWZ 6 8 3) ~ 2" $
show test62 @?=
"2"
, testCase "KnownNat (CLogWZ 3 10 9) ~ 3" $
show test63 @?=
"3"
, testCase "KnownNat ((CLogWZ 2 4 11) * (3 ^ (CLogWZ 2 4 8)))) ~ 18" $
show test64 @?=
"18"
, testCase "KnownNat (Max (CLogWZ 2 4 8) (CLogWZ 4 20 5)) ~ 3" $
show test65 @?=
"3"
, testCase "CLogWZ 3 0 8 ~ 8" $
show (test66 Proxy) @?=
"Proxy"
, testCase "forall x. CLogWZ 2 0 x ~ x" $
show (test67 Proxy) @?=
"Proxy"
, testCase "CLogWZ 5 0 0 ~ 0" $
show (test68 Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down

0 comments on commit d2ecb7a

Please sign in to comment.