From bbc0d62d587da8a4d48d87614d501bf6a29fd18b Mon Sep 17 00:00:00 2001 From: rowanG077 Date: Sun, 10 Jul 2022 12:13:17 +0200 Subject: [PATCH] Split normalisation and EOP conversion --- src/GHC/TypeLits/Extra/Solver.hs | 8 +- src/GHC/TypeLits/Extra/Solver/Unify.hs | 104 ++++++++++++++----------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/src/GHC/TypeLits/Extra/Solver.hs b/src/GHC/TypeLits/Extra/Solver.hs index 9b5e414..6ed92fb 100644 --- a/src/GHC/TypeLits/Extra/Solver.hs +++ b/src/GHC/TypeLits/Extra/Solver.hs @@ -242,8 +242,8 @@ toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of EqPred NomEq t1 t2 | isNatKind (typeKind t1) || isNatKind (typeKind t2) -> do - (t1', n1) <- normaliseNat defs t1 - (t2', n2) <- normaliseNat defs t2 + (t1', n1) <- MaybeT (pure (toExtraOp defs t1 >>= normaliseNat defs)) + (t2', n2) <- MaybeT (pure (toExtraOp defs t2 >>= normaliseNat defs)) pure (NatEquality ct t1' t2' (mergeNormalised n1 n2)) #if MIN_VERSION_ghc(9,2,0) | TyConApp tc [_,cmpNat,TyConApp tt1 [],TyConApp tt2 [],TyConApp ff1 []] <- t1 @@ -259,8 +259,8 @@ toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of #endif , TyConApp tc' [] <- t2 -> do - (x', n1) <- normaliseNat defs x - (y', n2) <- normaliseNat defs y + (x', n1) <- MaybeT (pure (toExtraOp defs x >>= normaliseNat defs)) + (y', n2) <- MaybeT (pure (toExtraOp defs y >>= normaliseNat defs)) let res | tc' == promotedTrueDataCon = pure (NatInequality ct x' y' True (mergeNormalised n1 n2)) | tc' == promotedFalseDataCon = pure (NatInequality ct x' y' False (mergeNormalised n1 n2)) | otherwise = fail "Nothing" diff --git a/src/GHC/TypeLits/Extra/Solver/Unify.hs b/src/GHC/TypeLits/Extra/Solver/Unify.hs index b365044..6776d86 100644 --- a/src/GHC/TypeLits/Extra/Solver/Unify.hs +++ b/src/GHC/TypeLits/Extra/Solver/Unify.hs @@ -12,13 +12,12 @@ module GHC.TypeLits.Extra.Solver.Unify , UnifyResult (..) , NormaliseResult , normaliseNat + , toExtraOp , unifyExtra ) where -- external -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Maybe (MaybeT (..)) import Data.Maybe (catMaybes) import Data.Function (on) import GHC.TypeLits.Normalise.Unify (CType (..)) @@ -50,64 +49,77 @@ import TcRnMonad (Ct) import GHC.TypeLits.Extra.Solver.Operations mergeNormResWith - :: (ExtraOp -> ExtraOp -> MaybeT TcPluginM NormaliseResult) - -> MaybeT TcPluginM NormaliseResult - -> MaybeT TcPluginM NormaliseResult - -> MaybeT TcPluginM NormaliseResult + :: (ExtraOp -> ExtraOp -> Maybe NormaliseResult) + -> Maybe NormaliseResult + -> Maybe NormaliseResult + -> Maybe NormaliseResult mergeNormResWith f x y = do (x', n1) <- x (y', n2) <- y (res, n3) <- f x' y' pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3) - -normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult -normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1 -normaliseNat _ (TyVarTy v) = pure (V v, Untouched) -normaliseNat _ (LitTy (NumTyLit i)) = pure (I i, Untouched) -normaliseNat defs (TyConApp tc [x,y]) - | tc == maxTyCon defs = mergeNormResWith (\x' y' -> return (mergeMax defs x' y')) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == minTyCon defs = mergeNormResWith (\x' y' -> return (mergeMin defs x' y')) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == divTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeDiv x' y'))) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == modTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeMod x' y'))) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == flogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeFLog x' y'))) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == clogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeCLog x' y'))) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog x' y'))) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == gcdTyCon defs = mergeNormResWith (\x' y' -> return (mergeGCD x' y')) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == lcmTyCon defs = mergeNormResWith (\x' y' -> return (mergeLCM x' y')) - (normaliseNat defs x) - (normaliseNat defs y) - | tc == typeNatExpTyCon = mergeNormResWith (\x' y' -> return (mergeExp x' y')) - (normaliseNat defs x) - (normaliseNat defs y) - -normaliseNat defs (TyConApp tc tys) = do +toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp +toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1 +toExtraOp _ (TyVarTy v) = pure (V v) +toExtraOp _ (LitTy (NumTyLit i)) = pure (I i) +toExtraOp defs (TyConApp tc [x,y]) + | tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y) + | tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y) +toExtraOp _ t = Just (C (CType t)) + +normaliseNat :: ExtraDefs -> ExtraOp -> Maybe NormaliseResult +normaliseNat _ (I i) = pure (I i, Untouched) +normaliseNat _ (V v) = pure (V v, Untouched) +normaliseNat defs (Max x y) = mergeNormResWith (\x' y' -> pure (mergeMax defs x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (Min x y) = mergeNormResWith (\x' y' -> pure (mergeMin defs x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (Div x y) = mergeNormResWith (\x' y' -> (mergeDiv x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (Mod x y) = mergeNormResWith (\x' y' -> (mergeMod x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (FLog x y) = mergeNormResWith (\x' y' -> (mergeFLog x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (CLog x y) = mergeNormResWith (\x' y' -> (mergeCLog x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (Log x y) = mergeNormResWith (\x' y' -> (mergeLog x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (GCD x y) = mergeNormResWith (\x' y' -> pure (mergeGCD x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (LCM x y) = mergeNormResWith (\x' y' -> pure (mergeLCM x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (Exp x y) = mergeNormResWith (\x' y' -> pure (mergeExp x' y')) + (normaliseNat defs x) + (normaliseNat defs y) +normaliseNat defs (C (CType (TyConApp tc tys))) = do let mergeExtraOp [] = [] mergeExtraOp ((Just (op, Normalised), _):xs) = reifyEOP defs op:mergeExtraOp xs mergeExtraOp ((_, ty):xs) = ty:mergeExtraOp xs - normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys)) + let normResults = map (\t -> toExtraOp defs t >>= normaliseNat defs) tys let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults) let tys' = mergeExtraOp (zip normResults tys) pure (C (CType (TyConApp tc tys')), anyNormalised) -normaliseNat _ t = return (C (CType t), Untouched) +normaliseNat _ eop = Just (eop, Untouched) -- | Result of comparing two 'SOP' terms, returning a potential substitution -- list under which the two terms are equal.