Skip to content

Commit

Permalink
Split normalisation and EOP conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Jul 10, 2022
1 parent fe80881 commit bbc0d62
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 50 deletions.
8 changes: 4 additions & 4 deletions src/GHC/TypeLits/Extra/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
EqPred NomEq t1 t2
| isNatKind (typeKind t1) || isNatKind (typeKind t2)
-> do
(t1', n1) <- normaliseNat defs t1
(t2', n2) <- normaliseNat defs t2
(t1', n1) <- MaybeT (pure (toExtraOp defs t1 >>= normaliseNat defs))
(t2', n2) <- MaybeT (pure (toExtraOp defs t2 >>= normaliseNat defs))
pure (NatEquality ct t1' t2' (mergeNormalised n1 n2))
#if MIN_VERSION_ghc(9,2,0)
| TyConApp tc [_,cmpNat,TyConApp tt1 [],TyConApp tt2 [],TyConApp ff1 []] <- t1
Expand All @@ -259,8 +259,8 @@ toSolverConstraint defs ct = case classifyPredType $ ctEvPred $ ctEvidence ct of
#endif
, TyConApp tc' [] <- t2
-> do
(x', n1) <- normaliseNat defs x
(y', n2) <- normaliseNat defs y
(x', n1) <- MaybeT (pure (toExtraOp defs x >>= normaliseNat defs))
(y', n2) <- MaybeT (pure (toExtraOp defs y >>= normaliseNat defs))
let res | tc' == promotedTrueDataCon = pure (NatInequality ct x' y' True (mergeNormalised n1 n2))
| tc' == promotedFalseDataCon = pure (NatInequality ct x' y' False (mergeNormalised n1 n2))
| otherwise = fail "Nothing"
Expand Down
104 changes: 58 additions & 46 deletions src/GHC/TypeLits/Extra/Solver/Unify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ module GHC.TypeLits.Extra.Solver.Unify
, UnifyResult (..)
, NormaliseResult
, normaliseNat
, toExtraOp
, unifyExtra
)
where

-- external
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Maybe (catMaybes)
import Data.Function (on)
import GHC.TypeLits.Normalise.Unify (CType (..))
Expand Down Expand Up @@ -50,64 +49,77 @@ import TcRnMonad (Ct)
import GHC.TypeLits.Extra.Solver.Operations

mergeNormResWith
:: (ExtraOp -> ExtraOp -> MaybeT TcPluginM NormaliseResult)
-> MaybeT TcPluginM NormaliseResult
-> MaybeT TcPluginM NormaliseResult
-> MaybeT TcPluginM NormaliseResult
:: (ExtraOp -> ExtraOp -> Maybe NormaliseResult)
-> Maybe NormaliseResult
-> Maybe NormaliseResult
-> Maybe NormaliseResult
mergeNormResWith f x y = do
(x', n1) <- x
(y', n2) <- y
(res, n3) <- f x' y'
pure (res, n1 `mergeNormalised` n2 `mergeNormalised` n3)


normaliseNat :: ExtraDefs -> Type -> MaybeT TcPluginM NormaliseResult
normaliseNat defs ty | Just ty1 <- coreView ty = normaliseNat defs ty1
normaliseNat _ (TyVarTy v) = pure (V v, Untouched)
normaliseNat _ (LitTy (NumTyLit i)) = pure (I i, Untouched)
normaliseNat defs (TyConApp tc [x,y])
| tc == maxTyCon defs = mergeNormResWith (\x' y' -> return (mergeMax defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == minTyCon defs = mergeNormResWith (\x' y' -> return (mergeMin defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == divTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeDiv x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == modTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeMod x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == flogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeFLog x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == clogTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeCLog x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == logTyCon defs = mergeNormResWith (\x' y' -> MaybeT (return (mergeLog x' y')))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == gcdTyCon defs = mergeNormResWith (\x' y' -> return (mergeGCD x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == lcmTyCon defs = mergeNormResWith (\x' y' -> return (mergeLCM x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
| tc == typeNatExpTyCon = mergeNormResWith (\x' y' -> return (mergeExp x' y'))
(normaliseNat defs x)
(normaliseNat defs y)

normaliseNat defs (TyConApp tc tys) = do
toExtraOp :: ExtraDefs -> Type -> Maybe ExtraOp
toExtraOp defs ty | Just ty1 <- coreView ty = toExtraOp defs ty1
toExtraOp _ (TyVarTy v) = pure (V v)
toExtraOp _ (LitTy (NumTyLit i)) = pure (I i)
toExtraOp defs (TyConApp tc [x,y])
| tc == maxTyCon defs = Max <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == minTyCon defs = Min <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == divTyCon defs = Div <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == modTyCon defs = Mod <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == flogTyCon defs = FLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == clogTyCon defs = CLog <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == logTyCon defs = Log <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == gcdTyCon defs = GCD <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == lcmTyCon defs = LCM <$> (toExtraOp defs x) <*> (toExtraOp defs y)
| tc == typeNatExpTyCon = Exp <$> (toExtraOp defs x) <*> (toExtraOp defs y)
toExtraOp _ t = Just (C (CType t))

normaliseNat :: ExtraDefs -> ExtraOp -> Maybe NormaliseResult
normaliseNat _ (I i) = pure (I i, Untouched)
normaliseNat _ (V v) = pure (V v, Untouched)
normaliseNat defs (Max x y) = mergeNormResWith (\x' y' -> pure (mergeMax defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Min x y) = mergeNormResWith (\x' y' -> pure (mergeMin defs x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Div x y) = mergeNormResWith (\x' y' -> (mergeDiv x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Mod x y) = mergeNormResWith (\x' y' -> (mergeMod x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (FLog x y) = mergeNormResWith (\x' y' -> (mergeFLog x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (CLog x y) = mergeNormResWith (\x' y' -> (mergeCLog x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Log x y) = mergeNormResWith (\x' y' -> (mergeLog x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (GCD x y) = mergeNormResWith (\x' y' -> pure (mergeGCD x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (LCM x y) = mergeNormResWith (\x' y' -> pure (mergeLCM x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (Exp x y) = mergeNormResWith (\x' y' -> pure (mergeExp x' y'))
(normaliseNat defs x)
(normaliseNat defs y)
normaliseNat defs (C (CType (TyConApp tc tys))) = do
let mergeExtraOp [] = []
mergeExtraOp ((Just (op, Normalised), _):xs) = reifyEOP defs op:mergeExtraOp xs
mergeExtraOp ((_, ty):xs) = ty:mergeExtraOp xs

normResults <- lift (sequence (runMaybeT . normaliseNat defs <$> tys))
let normResults = map (\t -> toExtraOp defs t >>= normaliseNat defs) tys
let anyNormalised = foldr mergeNormalised Untouched (snd <$> catMaybes normResults)
let tys' = mergeExtraOp (zip normResults tys)
pure (C (CType (TyConApp tc tys')), anyNormalised)

normaliseNat _ t = return (C (CType t), Untouched)
normaliseNat _ eop = Just (eop, Untouched)

-- | Result of comparing two 'SOP' terms, returning a potential substitution
-- list under which the two terms are equal.
Expand Down

0 comments on commit bbc0d62

Please sign in to comment.