Skip to content

Commit

Permalink
Reduce Mod n p <= q to p <= q + 1
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Apr 14, 2024
1 parent 0b2507a commit beb409c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Unreleased
* Fix faulty lookup for `Mod` and `Div` in GHC >= 9.2
* Reduce `Mod n p <= q` to `p <= q + 1`

# 0.4.7
* Fix Plugin silently fails when normalizing <= in GHC 9.4+ [#50](https://github.com/clash-lang/ghc-typelits-extra/issues/50)
Expand Down
11 changes: 8 additions & 3 deletions src-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import GHC.TcPluginM.Extra
import GHC.Builtin.Names (eqPrimTyConKey, hasKey, getUnique)
import GHC.Builtin.Types (promotedTrueDataCon, promotedFalseDataCon)
import GHC.Builtin.Types (boolTy, naturalTy, cTupleDataCon, cTupleTyCon)
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon, typeNatCmpTyCon)
import GHC.Core.Coercion (mkUnivCo)
import GHC.Core.DataCon (dataConWrapId)
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred, IrredPred), classifyPredType)
import GHC.Core.Reduction (Reduction(..))
import GHC.Core.TyCon (TyCon)
import GHC.Core.TyCo.Rep (Type (..), TyLit (..), UnivCoProvenance (PluginProv))
import GHC.Core.Type (Kind, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Core.Type (Kind, mkTyConApp, mkNumLitTy, splitTyConApp_maybe, typeKind)
#if MIN_VERSION_ghc(9,6,0)
import GHC.Core.TyCo.Compare (eqType)
#else
Expand Down Expand Up @@ -181,7 +181,12 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [
| otherwise -> return (Impossible eq)
(p, Max x y)
| b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs'

-- transform: Mod n p <= q
-- to: p <= q + 1
(Mod _ p, q) | isWantedCt ct -> do
let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1]
newCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm)
simples (((,) <$> evMagic ct <*> pure ct):evs) (newCt:news) eqs'
-- transform: q ~ Max x y => (p <=? q ~ True)
-- to: (p <=? Max x y) ~ True
-- and try to solve that along with the rest of the eqs'
Expand Down
17 changes: 11 additions & 6 deletions src-pre-ghc-9.4/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ import GHC.Builtin.Types (boolTy, naturalTy)
#else
import GHC.Builtin.Types (typeNatKind)
#endif
import GHC.Builtin.Types.Literals (typeNatDivTyCon, typeNatModTyCon)
import GHC.Builtin.Types.Literals (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon)
#if MIN_VERSION_ghc(9,2,0)
import GHC.Builtin.Types.Literals (typeNatCmpTyCon)
#else
import GHC.Builtin.Types.Literals (typeNatLeqTyCon)
#endif
import GHC.Core.Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType)
import GHC.Core.TyCo.Rep (Type (..))
import GHC.Core.Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Core.Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe, typeKind)
import GHC.Data.FastString (fsLit)
import GHC.Driver.Plugins (Plugin (..), defaultPlugin, purePlugin)
import GHC.Tc.Plugin (TcPluginM, tcLookupTyCon, tcPluginTrace)
Expand All @@ -77,12 +77,12 @@ import PrelNames (eqPrimTyConKey, hasKey)
import TcEvidence (EvTerm)
import TcPluginM (TcPluginM, tcLookupTyCon, tcPluginTrace)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import Type (Kind, eqType, mkTyConApp, splitTyConApp_maybe)
import Type (Kind, eqType, mkNumLitTy, mkTyConApp, splitTyConApp_maybe)
import TyCoRep (Type (..))
import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon)
import TcTypeNats (typeNatLeqTyCon)
import TcTypeNats (typeNatAddTyCon, typeNatLeqTyCon)
#if MIN_VERSION_ghc(8,4,0)
import TcTypeNats (typeNatDivTyCon, typeNatModTyCon)
import TcTypeNats (typeNatAddTyCon, typeNatDivTyCon, typeNatModTyCon)
#else
import TcPluginM (zonkCt)
#endif
Expand Down Expand Up @@ -209,7 +209,12 @@ simplifyExtra defs eqs = tcPluginTrace "simplifyExtra" (ppr eqs) >> simples [] [
| otherwise -> return (Impossible eq)
(p, Max x y)
| b && (p == x || p == y) -> simples (((,) <$> evMagic ct <*> pure ct):evs) news eqs'

-- transform: Mod n p <= q
-- to: p <= q + 1
(Mod _ p, q) | isWantedCt ct -> do
let succQ = toCType $ TyConApp typeNatAddTyCon [reifyEOP defs q, mkNumLitTy 1]
newCt <- createWantedFromNormalised defs (NatInequality ct p succQ b norm)
simples (((,) <$> evMagic ct <*> pure ct):evs) (newCt:news) eqs'
-- transform: q ~ Max x y => (p <=? q ~ True)
-- to: (p <=? Max x y) ~ True
-- and try to solve that along with the rest of the eqs'
Expand Down
7 changes: 5 additions & 2 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module GHC.TypeLits.Extra.Solver.Unify
( ExtraDefs (..)
, UnifyResult (..)
, NormaliseResult
, toCType
, normaliseNat
, unifyExtra
)
Expand Down Expand Up @@ -60,6 +61,8 @@ mergeNormResWith f x y = do
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)

toCType :: Type -> ExtraOp
toCType ty = C $ CType ty

normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
Expand Down Expand Up @@ -105,9 +108,9 @@ normaliseNat defs (TyConApp tc tys) = do
normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys))
let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults)
let tys' = mergeExtraOp (zip normResults tys)
pure (C (CType (TyConApp tc tys')), anyNormalised)
pure (toCType $ TyConApp tc tys', anyNormalised)

normaliseNat _ t = return (C (CType t), Untouched)
normaliseNat _ t = return (toCType t, Untouched)

-- | Result of comparing two 'SOP' terms, returning a potential substitution
-- list under which the two terms are equal.
Expand Down
20 changes: 20 additions & 0 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ test58b
-> Proxy (Max (n+2) 1)
test58b = test58a

test59
:: Proxy n
-> Proxy p
-> Proxy (Mod n (p + 1) <=? p)
-> Proxy True
test59 _ _ x = x

test60
:: Proxy n
-> Proxy p
-> Proxy (Mod n (3 * p + 5) <=? (4 + p * 3))
-> Proxy True
test60 _ _ x = x

main :: IO ()
main = defaultMain tests

Expand Down Expand Up @@ -411,6 +425,12 @@ tests = testGroup "ghc-typelits-natnormalise"
, testCase "forall n p . n + 1 <= Max (n + p + 1) p" $
show (test57 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n p . Mod n (p + 1) <= p" $
show (test59 Proxy Proxy Proxy) @?=
"Proxy"
, testCase "forall n p . Mod n (3 * p + 5) <= (4 + p * 3)" $
show (test60 Proxy Proxy Proxy) @?=
"Proxy"
]
, testGroup "errors"
[ testCase "GCD 6 8 /~ 4" $ testFail1 `throws` testFail1Errors
Expand Down

0 comments on commit beb409c

Please sign in to comment.