diff --git a/src/Juvix/Analysis/TypeChecking/Inference.hs b/src/Juvix/Analysis/TypeChecking/Inference.hs index 4c83170180..2198fecb60 100644 --- a/src/Juvix/Analysis/TypeChecking/Inference.hs +++ b/src/Juvix/Analysis/TypeChecking/Inference.hs @@ -326,20 +326,14 @@ addIdens idens = modify (HashMap.union idens) -- | Assumes the given function has been type checked -- | NOTE: Registers the function *only* if the result type is Type functionDefEval :: forall r'. Member (Reader FunctionsTable) r' => FunctionDef -> Sem r' (Maybe Expression) -functionDefEval funDef = do - x <- runError (goTop funDef) - case x of - Left () -> return Nothing - Right r -> return (Just r) +functionDefEval = runFail . goTop where - goTop :: forall r. (Members '[Error (), Reader FunctionsTable] r) => FunctionDef -> Sem r Expression + goTop :: forall r. (Members '[Fail, Reader FunctionsTable] r) => FunctionDef -> Sem r Expression goTop f = case f ^. funDefClauses of c :| [] -> goClause c - _ -> nothing + _ -> fail where - nothing :: Sem r a - nothing = throw () goClause :: FunctionClause -> Sem r Expression goClause c = do let pats = c ^. clausePatterns @@ -350,8 +344,8 @@ functionDefEval funDef = do splitNExplicitParams :: Int -> Expression -> Sem r ([Expression], Expression) splitNExplicitParams n fun = do let (params, r) = unfoldFunType fun - unlessM (isUniverse r) nothing - (nfirst, rest) <- note () (splitAtExactMay n params) + unlessM (isUniverse r) fail + (nfirst, rest) <- failMaybe (splitAtExactMay n params) sparams <- mapM simpleExplicitParam nfirst let r' = foldFunType rest r return (sparams, r') @@ -364,16 +358,16 @@ functionDefEval funDef = do simpleExplicitParam :: FunctionParameter -> Sem r Expression simpleExplicitParam = \case FunctionParameter Nothing Explicit ty -> return ty - _ -> nothing + _ -> fail goPattern :: (Pattern, Expression) -> Expression -> Sem r Expression goPattern (p, ty) = case p of PatternVariable v -> return . ExpressionLambda . Lambda v ty - _ -> const nothing + _ -> const fail go :: [(PatternArg, Expression)] -> Sem r Expression go = \case [] -> return (c ^. clauseBody) ((p, ty) : ps) - | Implicit <- p ^. patternArgIsImplicit -> nothing + | Implicit <- p ^. patternArgIsImplicit -> fail | otherwise -> go ps >>= goPattern (p ^. patternArgPattern, ty) registerFunctionDef :: Member (State FunctionsTable) r => FunctionDef -> Sem r () diff --git a/src/Juvix/Prelude.hs b/src/Juvix/Prelude.hs index 827bba4e56..1adcbefb6f 100644 --- a/src/Juvix/Prelude.hs +++ b/src/Juvix/Prelude.hs @@ -5,10 +5,12 @@ module Juvix.Prelude module Juvix.Prelude.Lens, module Juvix.Prelude.Loc, module Juvix.Prelude.Trace, + module Juvix.Prelude.Effect.Fail, ) where import Juvix.Prelude.Base +import Juvix.Prelude.Effect.Fail import Juvix.Prelude.Error import Juvix.Prelude.Files import Juvix.Prelude.Lens diff --git a/src/Juvix/Prelude/Base.hs b/src/Juvix/Prelude/Base.hs index 2b6a58180e..cfbb4f46c7 100644 --- a/src/Juvix/Prelude/Base.hs +++ b/src/Juvix/Prelude/Base.hs @@ -67,7 +67,7 @@ module Juvix.Prelude.Base where import Control.Applicative -import Control.Monad.Extra +import Control.Monad.Extra hiding (fail) import Control.Monad.Fix import Data.Bifunctor hiding (first, second) import Data.Bitraversable diff --git a/src/Juvix/Prelude/Effect/Fail.hs b/src/Juvix/Prelude/Effect/Fail.hs new file mode 100644 index 0000000000..7343b950bf --- /dev/null +++ b/src/Juvix/Prelude/Effect/Fail.hs @@ -0,0 +1,20 @@ +-- | An effect similar to Polysemy Fail but wihout an error message +module Juvix.Prelude.Effect.Fail where + +import Juvix.Prelude.Base + +data Fail m a = Fail + +makeSem ''Fail + +-- | Run a 'Fail' effect purely. +runFail :: + Sem (Fail ': r) a -> + Sem r (Maybe a) +runFail = fmap (^? _Right) . runError @() . reinterpret (\Fail -> throw ()) +{-# INLINE runFail #-} + +failMaybe :: Member Fail r => Maybe a -> Sem r a +failMaybe = \case + Nothing -> fail + Just x -> return x diff --git a/test/Base.hs b/test/Base.hs index c70c10279c..639776fe91 100644 --- a/test/Base.hs +++ b/test/Base.hs @@ -6,6 +6,7 @@ module Base ) where +import Control.Monad.Extra as Monad import Data.Algorithm.Diff import Data.Algorithm.DiffOutput import Juvix.Prelude @@ -41,7 +42,7 @@ assertEqDiff msg a b | otherwise = do putStrLn (pack $ ppDiff (getGroupedDiff pa pb)) putStrLn "End diff" - fail msg + Monad.fail msg where pa = lines $ ppShow a pb = lines $ ppShow b