From 0ef43b61c011d3afa607efdc9ad6c5a153c31f8a Mon Sep 17 00:00:00 2001 From: Rowan Goemans Date: Sat, 13 Apr 2024 00:30:28 +0200 Subject: [PATCH] Reduce `Mod n p <= q` to `p <= q + 1` --- CHANGELOG.md | 1 + src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 11 ++++++++--- src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 17 +++++++++++------ src/GHC/TypeLits/Extra/Solver/Unify.hs | 7 +++++-- tests/Main.hs | 20 ++++++++++++++++++++ 5 files changed, 45 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2954a5e..0db36c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ # Unreleased * Fix faulty lookup for `Mod` and `Div` in GHC >= 9.2 +* Reduce `Mod n p <= q` to `p <= q + 1` # 0.4.7 * Fix Plugin silently fails when normalizing <= in GHC 9.4+ [#50](https://github.com/clash-lang/ghc-typelits-extra/issues/50) diff --git a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 8fe3db5..ef6a2f1 100644 --- a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -32,14 +32,14 @@ import GHC.TcPluginM.Extra import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique) import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon) import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon) -import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon) +import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon) import GHC.Core.Coercion (mkUnivCo) import GHC.Core.DataCon (dataConWrapId) import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType) import GHC.Core.Reduction (Reduction(..)) import GHC.Core.TyCon (TyCon) import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv)) -import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind) +import GHC.Core.Type (Kind, mkTyConApp, mkNumLitTy, splitTyConApp_maybe, typeKind) #if MIN_VERSION_ghc(9,6,0) import GHC.Core.TyCo.Compare (eqType) #else @@ -181,7 +181,12 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [ | otherwise -> return (Impossible eq) (p, Max x y) | b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs' - + -- transform: Mod n p <= q + -- to: p <= q + 1 + (Mod _ p, q) | isWantedCt ct -> do + let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1] + newCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm) + simples (((,) <$> evMagic ct <*> pure ct):evs) (newCt:news) eqs' -- transform: q ~ Max x y => (p <=? q ~ True) -- to: (p <=? Max x y) ~ True -- and try to solve that along with the rest of the eqs' diff --git a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 3d9a9d3..d1ac638 100644 --- a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -42,7 +42,7 @@ import GHC.Builtin.Types (boolTy, naturalTy) #else import GHC.Builtin.Types (typeNatKind) #endif -import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon) +import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon) #if MIN_VERSION_ghc(9,2,0) import GHC.Builtin.Types.Literals (typeNatCmpTyCon) #else @@ -50,7 +50,7 @@ import GHC.Builtin.Types.Literals (typeNatLeqTyCon) #endif import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType) import GHC.Core.TyCo.Rep (Type (..)) -import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind) +import GHC.Core.Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe, typeKind) import GHC.Data.FastString (fsLit) import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin) import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace) @@ -77,12 +77,12 @@ import PrelNames (eqPrimTyConKey, hasKey) import TcEvidence (EvTerm) import TcPluginM (TcPluginM, tcLookupTyCon, tcPluginTrace) import TcRnTypes (TcPlugin(..), TcPluginResult (..)) -import Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe) +import Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe) import TyCoRep (Type (..)) import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon) -import TcTypeNats (typeNatLeqTyCon) +import TcTypeNats (typeNatAddTyCon, typeNatLeqTyCon) #if MIN_VERSION_ghc(8,4,0) -import TcTypeNats (typeNatDivTyCon, typeNatModTyCon) +import TcTypeNats (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon) #else import TcPluginM (zonkCt) #endif @@ -209,7 +209,12 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [ | otherwise -> return (Impossible eq) (p, Max x y) | b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs' - + -- transform: Mod n p <= q + -- to: p <= q + 1 + (Mod _ p, q) | isWantedCt ct -> do + let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1] + newCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm) + simples (((,) <$> evMagic ct <*> pure ct):evs) (newCt:news) eqs' -- transform: q ~ Max x y => (p <=? q ~ True) -- to: (p <=? Max x y) ~ True -- and try to solve that along with the rest of the eqs' diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..801eb0e 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -11,6 +11,7 @@ module GHC.TypeLits.Extra.Solver.Unify ( ExtraDefs (..) , UnifyResult (..) , NormaliseResult + , toCType , normaliseNat , unifyExtra ) @@ -60,6 +61,8 @@ mergeNormResWith f x y = do (res, n3) <- f x' y' pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3) +toCType :: Type -> ExtraOp +toCType ty = C $ CType ty normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1 @@ -105,9 +108,9 @@ normaliseNat defs (TyConApp tc tys) = do normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys)) let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults) let tys' = mergeExtraOp (zip normResults tys) - pure (C (CType (TyConApp tc tys')), anyNormalised) + pure (toCType $ TyConApp tc tys', anyNormalised) -normaliseNat _ t = return (C (CType t), Untouched) +normaliseNat _ t = return (toCType t, Untouched) -- | Result of comparing two 'SOP' terms, returning a potential substitution -- list under which the two terms are equal. diff --git a/tests/Main.hs b/tests/Main.hs index d260275..a2d414e 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -234,6 +234,20 @@ test58b -> Proxy (Max (n+2) 1) test58b = test58a +test59 + :: Proxy n + -> Proxy p + -> Proxy (Mod n (p + 1) <=? p) + -> Proxy True +test59 _ _ x = x + +test60 + :: Proxy n + -> Proxy p + -> Proxy (Mod n (3 * p + 5) <=? (4 + p * 3)) + -> Proxy True +test60 _ _ x = x + main :: IO () main = defaultMain tests @@ -411,6 +425,12 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "forall n p . Mod n (p + 1) <= p" $ + show (test59 Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall n p . Mod n (3 * p + 5) <= (4 + p * 3)" $ + show (test60 Proxy Proxy Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors