From 6bfd0ea25d95e84572fcee0cfb449e9bd4a63828 Mon Sep 17 00:00:00 2001 From: Rowan Goemans Date: Sat, 13 Apr 2024 00:30:28 +0200 Subject: [PATCH] Reduce `Mod n p <= q` to `p <= q + 1` and `1 <= p` --- CHANGELOG.md | 1 + src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 12 ++++++++--- src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs | 16 ++++++++++----- src/GHC/TypeLits/Extra/Solver/Unify.hs | 7 +++++-- tests-ghc-9.4/ErrorTests.hs | 9 +++++++++ tests-pre-ghc-9.4/ErrorTests.hs | 13 ++++++++++++ tests/Main.hs | 21 ++++++++++++++++++++ 7 files changed, 69 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2954a5e..462149b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ # Unreleased * Fix faulty lookup for `Mod` and `Div` in GHC >= 9.2 +* Reduce `Mod n p <= q` to `p <= q + 1` and `1 <= p` # 0.4.7 * Fix Plugin silently fails when normalizing <= in GHC 9.4+ [#50](https://github.com/clash-lang/ghc-typelits-extra/issues/50) diff --git a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 8fe3db5..07988f7 100644 --- a/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -32,14 +32,14 @@ import GHC.TcPluginM.Extra import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique) import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon) import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon) -import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon) +import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon) import GHC.Core.Coercion (mkUnivCo) import GHC.Core.DataCon (dataConWrapId) import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType) import GHC.Core.Reduction (Reduction(..)) import GHC.Core.TyCon (TyCon) import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv)) -import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind) +import GHC.Core.Type (Kind, mkTyConApp, mkNumLitTy, splitTyConApp_maybe, typeKind) #if MIN_VERSION_ghc(9,6,0) import GHC.Core.TyCo.Compare (eqType) #else @@ -181,7 +181,13 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [ | otherwise -> return (Impossible eq) (p, Max x y) | b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs' - + -- transform: Mod n p <= q + -- to: p <= q + 1, 1 <= p + (Mod _ p, q) | isWantedCt ct -> do + let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1] + modCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm) + gteOneCt <- createWantedFromNormalised defs (NatInequality ct (I 1) p b norm) + simples (((,) <$> evMagic ct <*> pure ct):evs) (modCt:gteOneCt:news) eqs' -- transform: q ~ Max x y => (p <=? q ~ True) -- to: (p <=? Max x y) ~ True -- and try to solve that along with the rest of the eqs' diff --git a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs index 3d9a9d3..e0c1f09 100644 --- a/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs +++ b/src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs @@ -42,7 +42,7 @@ import GHC.Builtin.Types (boolTy, naturalTy) #else import GHC.Builtin.Types (typeNatKind) #endif -import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon) +import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon) #if MIN_VERSION_ghc(9,2,0) import GHC.Builtin.Types.Literals (typeNatCmpTyCon) #else @@ -50,7 +50,7 @@ import GHC.Builtin.Types.Literals (typeNatLeqTyCon) #endif import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType) import GHC.Core.TyCo.Rep (Type (..)) -import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind) +import GHC.Core.Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe, typeKind) import GHC.Data.FastString (fsLit) import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin) import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace) @@ -77,10 +77,10 @@ import PrelNames (eqPrimTyConKey, hasKey) import TcEvidence (EvTerm) import TcPluginM (TcPluginM, tcLookupTyCon, tcPluginTrace) import TcRnTypes (TcPlugin(..), TcPluginResult (..)) -import Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe) +import Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe) import TyCoRep (Type (..)) import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon) -import TcTypeNats (typeNatLeqTyCon) +import TcTypeNats (typeNatAddTyCon, typeNatLeqTyCon) #if MIN_VERSION_ghc(8,4,0) import TcTypeNats (typeNatDivTyCon, typeNatModTyCon) #else @@ -209,7 +209,13 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [ | otherwise -> return (Impossible eq) (p, Max x y) | b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs' - + -- transform: Mod n p <= q + -- to: p <= q + 1, 1 <= p + (Mod _ p, q) | isWantedCt ct -> do + let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1] + modCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm) + gteOneCt <- createWantedFromNormalised defs (NatInequality ct (I 1) p b norm) + simples (((,) <$> evMagic ct <*> pure ct):evs) (modCt:gteOneCt:news) eqs' -- transform: q ~ Max x y => (p <=? q ~ True) -- to: (p <=? Max x y) ~ True -- and try to solve that along with the rest of the eqs' diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..801eb0e 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -11,6 +11,7 @@ module GHC.TypeLits.Extra.Solver.Unify ( ExtraDefs (..) , UnifyResult (..) , NormaliseResult + , toCType , normaliseNat , unifyExtra ) @@ -60,6 +61,8 @@ mergeNormResWith f x y = do (res, n3) <- f x' y' pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3) +toCType :: Type -> ExtraOp +toCType ty = C $ CType ty normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1 @@ -105,9 +108,9 @@ normaliseNat defs (TyConApp tc tys) = do normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys)) let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults) let tys' = mergeExtraOp (zip normResults tys) - pure (C (CType (TyConApp tc tys')), anyNormalised) + pure (toCType $ TyConApp tc tys', anyNormalised) -normaliseNat _ t = return (C (CType t), Untouched) +normaliseNat _ t = return (toCType t, Untouched) -- | Result of comparing two 'SOP' terms, returning a potential substitution -- list under which the two terms are equal. diff --git a/tests-ghc-9.4/ErrorTests.hs b/tests-ghc-9.4/ErrorTests.hs index 4cc5fbe..19dcd6d 100644 --- a/tests-ghc-9.4/ErrorTests.hs +++ b/tests-ghc-9.4/ErrorTests.hs @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6) testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True testFail27 _ = id +testFail28 :: Proxy n -> Proxy (Mod n p <=? p) -> Proxy True +testFail28 _ = id + testFail1Errors = ["Expected: Proxy (GCD 6 8) -> Proxy 4" ," Actual: Proxy 4 -> Proxy 4" @@ -231,3 +234,9 @@ testFail26Errors = ,"from the context: (x <=? n) ~ 'True" ] #endif + +testFail28Errors = + ["Couldn't match type ‘Data.Type.Ord.OrdCond" + ,"(CmpNat 1 p) True True False" + , "with ‘True’" + ] diff --git a/tests-pre-ghc-9.4/ErrorTests.hs b/tests-pre-ghc-9.4/ErrorTests.hs index 26e1cd5..5353418 100644 --- a/tests-pre-ghc-9.4/ErrorTests.hs +++ b/tests-pre-ghc-9.4/ErrorTests.hs @@ -100,6 +100,9 @@ testFail26 = testFail26' (Proxy @4) (Proxy @6) (Proxy @6) testFail27 :: Proxy n -> Proxy (n + 2 <=? Max (n + 1) 1) -> Proxy True testFail27 _ = id +testFail28 :: Proxy n -> Proxy (Mod n p <=? p) -> Proxy True +testFail28 _ = id + #if __GLASGOW_HASKELL__ >= 900 testFail1Errors = ["Expected: Proxy (GCD 6 8) -> Proxy 4" @@ -345,3 +348,13 @@ testFail26Errors = ["Could not deduce: Max x y ~ n" ,"from the context: (x <=? n) ~ 'True" ] + +testFail28Errors = +#if __GLASGOW_HASKELL__ >= 902 + ["Couldn't match type ‘Data.Type.Ord.OrdCond" + ,"(CmpNat 1 p) True True False" + , "with ‘True’" + ] +#else + ["Couldn't match type ‘1 <=? p’ with ‘'True’"] +#endif diff --git a/tests/Main.hs b/tests/Main.hs index d260275..89ec3ec 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -234,6 +234,20 @@ test58b -> Proxy (Max (n+2) 1) test58b = test58a +test59 + :: Proxy n + -> Proxy p + -> Proxy (Mod n (p + 1) <=? p) + -> Proxy True +test59 _ _ x = x + +test60 + :: Proxy n + -> Proxy p + -> Proxy (Mod n (3 * p + 5) <=? (4 + p * 3)) + -> Proxy True +test60 _ _ x = x + main :: IO () main = defaultMain tests @@ -411,6 +425,12 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "forall n p . n + 1 <= Max (n + p + 1) p" $ show (test57 Proxy Proxy Proxy) @?= "Proxy" + , testCase "forall n p . Mod n (p + 1) <= p" $ + show (test59 Proxy Proxy Proxy) @?= + "Proxy" + , testCase "forall n p . Mod n (3 * p + 5) <= (4 + p * 3)" $ + show (test60 Proxy Proxy Proxy) @?= + "Proxy" ] , testGroup "errors" [ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors @@ -440,6 +460,7 @@ tests = testGroup "ghc-typelits-natnormalise" , testCase "(x+1 <=? Max x y) /~ True" $ testFail25 `throws` testFail25Errors , testCase "(x <= n) /=> (Max x y) ~ n" $ testFail26 `throws` testFail26Errors , testCase "n + 2 <=? Max (n + 1) 1 /~ True" $ testFail27 `throws` testFail27Errors + , testCase "Mod n p <=? p" $ testFail28 `throws` testFail28Errors ] ]