diff --git a/Jikka.cabal b/Jikka.cabal index 57a9142d..a47ffbcc 100644 --- a/Jikka.cabal +++ b/Jikka.cabal @@ -76,6 +76,7 @@ library Jikka.Core.Convert.ConvexHullTrick Jikka.Core.Convert.CumulativeSum Jikka.Core.Convert.Eta + Jikka.Core.Convert.KubaruToMorau Jikka.Core.Convert.MakeScanl Jikka.Core.Convert.MatrixExponentiation Jikka.Core.Convert.PropagateMod @@ -94,6 +95,7 @@ library Jikka.Core.Language.BuiltinPatterns Jikka.Core.Language.Expr Jikka.Core.Language.FreeVars + Jikka.Core.Language.LambdaPatterns Jikka.Core.Language.Lint Jikka.Core.Language.RewriteRules Jikka.Core.Language.Runtime diff --git a/examples/README.md b/examples/README.md index 4913f03f..82722ad2 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,10 +24,13 @@ - Educational DP Contest [Q - Flowers](https://atcoder.jp/contests/dp/tasks/dp_q) - A lazy propagation segment tree / 遅延伝播セグメント木 - submission at v5.0.11.0: -- :heavy_check_mark: AC `dp_z.py` +- :heavy_check_mark: AC `dp_z-morau.py` - Educational DP Contest [Z - Frog 3](https://atcoder.jp/contests/dp/tasks/dp_z) - Convex hull trick - submission at v5.0.11.0: +- :heavy_check_mark: AC `dp_z-kubaru.py` + - The Kubaru DP version of `dp_z-morau.py`. + - submission at v5.1.0.0: - :hourglass: TLE `abc134_c.py` - AtCoder Beginner Contest 134 [C - Exception Handling](https://atcoder.jp/contests/abc134/tasks/abc134_c) - Cumulative sums from both sides / 両側からの累積和 diff --git a/examples/data/aliases.json b/examples/data/aliases.json new file mode 100644 index 00000000..b2648b9b --- /dev/null +++ b/examples/data/aliases.json @@ -0,0 +1 @@ +{ "dp_z-kubaru": "dp_z-morau" } diff --git a/examples/data/dp_z.large.generator.py b/examples/data/dp_z-morau.large.generator.py similarity index 100% rename from examples/data/dp_z.large.generator.py rename to examples/data/dp_z-morau.large.generator.py diff --git a/examples/data/dp_z.sample-1.in b/examples/data/dp_z-morau.sample-1.in similarity index 100% rename from examples/data/dp_z.sample-1.in rename to examples/data/dp_z-morau.sample-1.in diff --git a/examples/data/dp_z.sample-1.out b/examples/data/dp_z-morau.sample-1.out similarity index 100% rename from examples/data/dp_z.sample-1.out rename to examples/data/dp_z-morau.sample-1.out diff --git a/examples/data/dp_z.sample-2.in b/examples/data/dp_z-morau.sample-2.in similarity index 100% rename from examples/data/dp_z.sample-2.in rename to examples/data/dp_z-morau.sample-2.in diff --git a/examples/data/dp_z.sample-2.out b/examples/data/dp_z-morau.sample-2.out similarity index 100% rename from examples/data/dp_z.sample-2.out rename to examples/data/dp_z-morau.sample-2.out diff --git a/examples/data/dp_z.sample-3.in b/examples/data/dp_z-morau.sample-3.in similarity index 100% rename from examples/data/dp_z.sample-3.in rename to examples/data/dp_z-morau.sample-3.in diff --git a/examples/data/dp_z.sample-3.out b/examples/data/dp_z-morau.sample-3.out similarity index 100% rename from examples/data/dp_z.sample-3.out rename to examples/data/dp_z-morau.sample-3.out diff --git a/examples/data/dp_z.solver.cpp b/examples/data/dp_z-morau.solver.cpp similarity index 100% rename from examples/data/dp_z.solver.cpp rename to examples/data/dp_z-morau.solver.cpp diff --git a/examples/dp_z-kubaru.py b/examples/dp_z-kubaru.py new file mode 100644 index 00000000..6ca2256e --- /dev/null +++ b/examples/dp_z-kubaru.py @@ -0,0 +1,27 @@ +# https://atcoder.jp/contests/dp/tasks/dp_z +from typing import * + +INF = 10 ** 18 + +def solve(n: int, c: int, h: List[int]) -> int: + assert 2 <= n <= 10 ** 5 + assert 1 <= c <= 10 ** 12 + assert len(h) == n + assert all(1 <= h_i <= 10 ** 6 for h_i in h) + + dp = [INF for _ in range(n)] + dp[0] = 0 + for i in range(n): + for j in range(i + 1, n): + dp[j] = min(dp[j], dp[i] + (h[i] - h[j]) ** 2 + c) + return dp[n - 1] + +def main() -> None: + n, c = map(int, input().split()) + h = list(map(int, input().split())) + assert len(h) == n + ans = solve(n, c, h) + print(ans) + +if __name__ == '__main__': + main() diff --git a/examples/dp_z.py b/examples/dp_z-morau.py similarity index 100% rename from examples/dp_z.py rename to examples/dp_z-morau.py diff --git a/scripts/integration_tests.py b/scripts/integration_tests.py index 8f51e455..814833ed 100644 --- a/scripts/integration_tests.py +++ b/scripts/integration_tests.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 import argparse import concurrent.futures +import functools import glob +import json import os import pathlib import platform @@ -26,6 +28,12 @@ def compile_cxx(src_path: pathlib.Path, dst_path: pathlib.Path): subprocess.check_call(command, timeout=5 * TIMEOUT_FACTOR) +@functools.lru_cache(maxsize=None) +def get_alias_mapping() -> Dict[str, str]: + with open(pathlib.Path('examples', 'data', 'aliases.json')) as fh: + return json.load(fh) + + def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[pathlib.Path]: inputcases: List[pathlib.Path] = [] @@ -79,6 +87,14 @@ def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[ return [] inputcases.append(inputcase) + # resolve alias + aliases = get_alias_mapping() + if script.stem in aliases: + if inputcases: + logger.error("%s: there must not be test cases when it uses an alias: %s", str(script), list(map(str, inputcases))) + return [] + return collect_input_cases(script.parent / (aliases[script.stem] + script.suffix), tempdir=tempdir) + return inputcases @@ -170,6 +186,8 @@ def find_unused_test_cases() -> List[pathlib.Path]: errors = list(pathlib.Path('examples', 'errors').glob('*.py')) unused = [] for path in pathlib.Path('examples', 'data').glob('*'): + if path.name == 'aliases.json': + continue name = path.name[:-len(''.join(path.suffixes))] if name not in [script.stem for script in scripts + errors]: unused.append(path) diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index 0de04008..93d6d652 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -77,7 +77,9 @@ runLiteral env = \case case stmts of [] -> return e _ -> throwInternalError "now builtin values don't use statements" - X.LitInt n -> return $ Y.Lit (Y.LitInt64 n) + X.LitInt n + | - (2 ^ 63) <= n && n < 2 ^ 63 -> return $ Y.Lit (Y.LitInt64 n) + | otherwise -> throwInternalError $ "integer value is too large for int64_t: " ++ show n X.LitBool p -> return $ Y.Lit (Y.LitBool p) X.LitNil t -> do t <- runType t diff --git a/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs b/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs index ac4a3715..8aa2e44b 100644 --- a/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs +++ b/src/Jikka/CPlusPlus/Convert/MoveSemantics.hs @@ -51,10 +51,21 @@ runAssignExpr = \case AssignIncr e -> AssignIncr <$> runLeftExpr e AssignDecr e -> AssignDecr <$> runLeftExpr e -isMovable :: VarName -> [[Statement]] -> Bool -isMovable x cont = - let ReadWriteList rs _ = analyzeStatements (concat cont) - in x `S.notMember` rs +isMovableTo :: VarName -> VarName -> [[Statement]] -> Bool +isMovableTo x y cont + | x `S.notMember` readList' (analyzeStatements (concat cont)) = True + | otherwise = + let go = \case + [] -> False + (Assign (AssignExpr SimpleAssign (LeftVar x') (Var y')) : cont') + | x' == x && y' == y -> + let ReadWriteList _ ws' = analyzeStatements cont' + ReadWriteList rs ws = analyzeStatements (concat (tail cont)) + in y `S.notMember` S.unions [ws', rs, ws] + (stmt : cont) -> + let ReadWriteList rs ws = analyzeStatement stmt + in x `S.notMember` S.unions [rs, ws] && go cont + in go (head cont) runStatement :: MonadState (M.Map VarName VarName) m => Statement -> [[Statement]] -> m [Statement] runStatement stmt cont = case stmt of @@ -88,16 +99,16 @@ runStatement stmt cont = case stmt of DeclareCopy e -> DeclareCopy <$> runExpr e DeclareInitialize es -> DeclareInitialize <$> mapM runExpr es case init of - DeclareCopy (Var x) | x `isMovable` cont -> do + DeclareCopy (Var x) | (x `isMovableTo` y) cont -> do modify' (M.insert y x) return [] DeclareCopy (Call ConvexHullTrickCtor []) -> return [Declare t y DeclareDefault] DeclareCopy (Call ConvexHullTrickCopyAddLine [Var x, a, b]) - | x `isMovable` cont -> do + | (x `isMovableTo` y) cont -> do modify' (M.insert y x) return [callMethod' (Var x) "add_line" [a, b]] DeclareCopy (Call (SegmentTreeCopySetPoint _) [Var x, i, a]) - | x `isMovable` cont -> do + | (x `isMovableTo` y) cont -> do modify' (M.insert y x) return [callMethod' (Var x) "set" [i, a]] _ -> do @@ -111,13 +122,13 @@ runStatement stmt cont = case stmt of AssignExpr SimpleAssign (LeftVar y) (Var x) | x == y -> return [] AssignExpr SimpleAssign (LeftVar y) (Call ConvexHullTrickCopyAddLine [Var x, a, b]) | x == y -> return [callMethod' (Var x) "add_line" [a, b]] - | x `isMovable` cont -> do + | (x `isMovableTo` y) cont -> do modify' (M.insert y x) return [callMethod' (Var x) "add_line" [a, b]] | otherwise -> return [Assign e] AssignExpr SimpleAssign (LeftVar y) (Call (SegmentTreeCopySetPoint _) [Var x, i, a]) | x == y -> return [callMethod' (Var x) "set" [i, a]] - | x `isMovable` cont -> do + | (x `isMovableTo` y) cont -> do modify' (M.insert y x) return [callMethod' (Var x) "set" [i, a]] | otherwise -> return [Assign e] diff --git a/src/Jikka/CPlusPlus/Format.hs b/src/Jikka/CPlusPlus/Format.hs index 1d6e549d..6ff4485b 100644 --- a/src/Jikka/CPlusPlus/Format.hs +++ b/src/Jikka/CPlusPlus/Format.hs @@ -194,7 +194,9 @@ formatType = \case formatLiteral :: Literal -> Code formatLiteral = \case LitInt32 n -> show n - LitInt64 n -> show n + LitInt64 n + | - (2 ^ 31) <= n && n < 2 ^ 31 -> show n + | otherwise -> show n ++ "ll" LitBool p -> if p then "true" else "false" LitChar c -> show c LitString s -> show s diff --git a/src/Jikka/CPlusPlus/Language/VariableAnalysis.hs b/src/Jikka/CPlusPlus/Language/VariableAnalysis.hs index 941b7f5b..09299cdc 100644 --- a/src/Jikka/CPlusPlus/Language/VariableAnalysis.hs +++ b/src/Jikka/CPlusPlus/Language/VariableAnalysis.hs @@ -6,7 +6,7 @@ import qualified Data.Set as S import Jikka.CPlusPlus.Language.Expr data ReadWriteList = ReadWriteList - { readList :: S.Set VarName, + { readList' :: S.Set VarName, writeList :: S.Set VarName } deriving (Eq, Ord, Show, Read) diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs index a5d2af1e..33bddc7a 100644 --- a/src/Jikka/Core/Convert.hs +++ b/src/Jikka/Core/Convert.hs @@ -29,6 +29,7 @@ import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation import qualified Jikka.Core.Convert.ConvexHullTrick as ConvexHullTrick import qualified Jikka.Core.Convert.CumulativeSum as CumulativeSum import qualified Jikka.Core.Convert.Eta as Eta +import qualified Jikka.Core.Convert.KubaruToMorau as KubaruToMorau import qualified Jikka.Core.Convert.MakeScanl as MakeScanl import qualified Jikka.Core.Convert.MatrixExponentiation as MatrixExponentiation import qualified Jikka.Core.Convert.PropagateMod as PropagateMod @@ -56,6 +57,7 @@ run'' prog = do prog <- CloseSum.run prog prog <- CloseAll.run prog prog <- CloseMin.run prog + prog <- KubaruToMorau.run prog prog <- CumulativeSum.run prog prog <- SegmentTree.run prog prog <- BubbleLet.run prog diff --git a/src/Jikka/Core/Convert/CloseMin.hs b/src/Jikka/Core/Convert/CloseMin.hs index 821b4767..12cb9295 100644 --- a/src/Jikka/Core/Convert/CloseMin.hs +++ b/src/Jikka/Core/Convert/CloseMin.hs @@ -53,7 +53,7 @@ reduceMin = simpleRewriteRule $ \case Lam x _ (Min2' _ e1 e2) -> Just $ Min2' t (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs))) Lam x _ (Negate' e) -> Just $ Negate' (Max1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs))) Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs))) - Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e2 + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Cons' t (Minus' e0 e2) (Map' t1 t2 (Lam x IntTy e1) xs))) e2 _ -> Nothing _ -> Nothing @@ -80,7 +80,7 @@ reduceMax = simpleRewriteRule $ \case Lam x _ (Max2' _ e1 e2) -> Just $ Max2' t (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs))) Lam x _ (Negate' e) -> Just $ Negate' (Min1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs))) Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs))) - Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e2 + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t (Cons' t (Minus' e0 e2) (Map' t1 t2 (Lam x IntTy e1) xs))) e2 _ -> Nothing _ -> Nothing diff --git a/src/Jikka/Core/Convert/ConvexHullTrick.hs b/src/Jikka/Core/Convert/ConvexHullTrick.hs index e8353b48..8031eb39 100644 --- a/src/Jikka/Core/Convert/ConvexHullTrick.hs +++ b/src/Jikka/Core/Convert/ConvexHullTrick.hs @@ -37,9 +37,6 @@ import Jikka.Core.Language.Lint import Jikka.Core.Language.RewriteRules import Jikka.Core.Language.Util -hoistMaybe :: Applicative m => Maybe a -> MaybeT m a -hoistMaybe = MaybeT . pure - -- | This is something commutative because only one kind of @c@ is allowed. plusPair :: (ArithmeticalExpr, ArithmeticalExpr) -> (ArithmeticalExpr, ArithmeticalExpr) -> Maybe (ArithmeticalExpr, ArithmeticalExpr) plusPair (a1, c1) (a2, _) | isZeroArithmeticalExpr a2 = Just (a1, c1) @@ -112,47 +109,48 @@ parseLinearFunctionBody' f i j e = result <$> go e _ -> Nothing _ -> Nothing -parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr)) +parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr)) parseLinearFunctionBody f i k = runMaybeT . go where + goMin e j step size = case unNPlusKPattern (parseArithmeticalExpr size) of + Just (i', k') | i' == i && k' == k -> do + (a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step + -- raname @j@ to @i@ + a <- lift $ substitute j (Var i) a + c <- lift $ substitute j (Var i) c + return (LitInt' 1, a, b, c, d, (`Minus'` d) <$> e) + _ -> hoistMaybe Nothing + goMax e j step size = do + (sign, a, b, c, d, e) <- goMin e j step size + return (Negate' sign, a, Negate' b, Negate' c, d, Negate' <$> e) go = \case - Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> case unNPlusKPattern (parseArithmeticalExpr size) of - Just (i', k') | i' == i && k' == k -> do - (a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step - -- raname @j@ to @i@ - a <- lift $ substitute j (Var i) a - c <- lift $ substitute j (Var i) c - return (LitInt' 1, a, b, c, d) - _ -> hoistMaybe Nothing - Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> case unNPlusKPattern (parseArithmeticalExpr size) of - Just (i', k') | i' == i && k' == k -> do - (a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step - -- raname @j@ to @i@ - a <- lift $ substitute j (Var i) a - c <- lift $ substitute j (Var i) c - return (LitInt' (-1), a, Negate' b, Negate' c, d) - _ -> hoistMaybe Nothing + Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMin Nothing j step size + Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMax Nothing j step size + Min1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMin (Just e) j step size + Max1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMax (Just e) j step size + Min1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMin (Just e) j step size + Max1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMax (Just e) j step size Negate' e -> do - (sign, a, b, c, d) <- go e - return (Negate' sign, a, b, c, Negate' d) + (sign, a, b, c, d, e) <- go e + return (Negate' sign, a, b, c, Negate' d, e) Plus' e1 e2 | isConstantTimeExpr e2 -> do - (sign, a, b, c, d) <- go e1 - return (sign, a, b, c, Plus' d e2) + (sign, a, b, c, d, e) <- go e1 + return (sign, a, b, c, Plus' d e2, e) Plus' e1 e2 | isConstantTimeExpr e1 -> do - (sign, a, b, c, d) <- go e2 - return (sign, a, b, c, Plus' e1 d) + (sign, a, b, c, d, e) <- go e2 + return (sign, a, b, c, Plus' e1 d, e) Minus' e1 e2 | isConstantTimeExpr e2 -> do - (sign, a, b, c, d) <- go e1 - return (sign, a, b, c, Minus' d e2) + (sign, a, b, c, d, e) <- go e1 + return (sign, a, b, c, Minus' d e2, e) Minus' e1 e2 | isConstantTimeExpr e1 -> do - (sign, a, b, c, d) <- go e2 - return (Negate' sign, a, b, c, Minus' e1 d) + (sign, a, b, c, d, e) <- go e2 + return (Negate' sign, a, b, c, Minus' e1 d, e) Mult' e1 e2 | isConstantTimeExpr e2 -> do - (sign, a, b, c, d) <- go e1 - return (Mult' sign e2, a, b, c, Mult' d e2) + (sign, a, b, c, d, e) <- go e1 + return (Mult' sign e2, a, b, c, Mult' d e2, e) Mult' e1 e2 | isConstantTimeExpr e1 -> do - (sign, a, b, c, d) <- go e2 - return (Mult' e1 sign, a, b, c, Mult' e1 d) + (sign, a, b, c, d, e) <- go e2 + return (Mult' e1 sign, a, b, c, Mult' e1 d, e) _ -> hoistMaybe Nothing getLength :: Expr -> Maybe Integer @@ -166,27 +164,57 @@ rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m rule = RewriteRule $ \_ -> \case -- build (fun f -> step(f)) base n Build' IntTy (Lam f _ step) base n -> runMaybeT $ do + let ts = [ConvexHullTrickTy, ListTy IntTy] i <- lift genVarName' k <- hoistMaybe $ getLength base step <- replaceLenF f i k step - -- step(f) = sign(f) * min (map (fun j -> a(f, j) c(f) + b(f, j)) (range (i + k))) + d(f) - (sign, a, c, b, d) <- MaybeT $ parseLinearFunctionBody f i k step - x <- lift genVarName' - y <- lift genVarName' - f' <- lift $ genVarName f - let ts = [ConvexHullTrickTy, ListTy IntTy] - -- base' = (empty, base) - let base' = uncurryApp (Tuple' ts) [ConvexHullTrickInit', base] + -- step(f) = sign() * min (cons e(f, i) (map (fun j -> a(f, j) c(f, i) + b(f, j)) (range (i + k)))) + d(f, i) + (sign, a, c, b, d, e) <- MaybeT $ parseLinearFunctionBody f i k step + -- Update base when k = 0. If user's program has no bugs, it uses min(cons(x, xs)) when k = 0. + (base, n, k, c, d, e) <- case (e, k) of + (Just e, 0) -> do + e0 <- lift $ substitute i (LitInt' 0) e + d0 <- lift $ substitute i (LitInt' 0) d + let base' = Let f (ListTy IntTy) base $ Snoc' IntTy base (Plus' (Mult' sign e0) d0) + c <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) c + d <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) d + e <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) e + return (base', Minus' n (LitInt' 1), k + 1, c, d, Just e) + _ -> return (base, n, k, c, d, e) + -- base' = (cht, base) + base' <- do + x <- lift genVarName' + f' <- lift $ genVarName f + i' <- lift $ genVarName i + a <- lift $ substitute f (Var f') a + b <- lift $ substitute f (Var f') b + a <- lift $ substitute i (Var i') a + b <- lift $ substitute i (Var i') b + -- cht for base[0], ..., base[k - 1] + let cht = Foldl' IntTy ConvexHullTrickTy (Lam2 x ConvexHullTrickTy i' IntTy (ConvexHullTrickInsert' (Var x) a b)) ConvexHullTrickInit' (Range1' (LitInt' k)) + return $ + Let f' (ListTy IntTy) base $ + uncurryApp (Tuple' ts) [cht, Var f'] -- step' = fun (cht, f) i -> - -- let f' = setat f index(i) (min cht f[i + k] + c(i)) + -- let f' = setat f index(i) value(..) -- in let cht' = update cht a(i) b(i) -- in (cht', f') - let step' = - Lam2 x (TupleTy ts) i IntTy $ - Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $ + step' <- do + x <- lift genVarName' + -- value(..) = (min e (min cht f[i + k] + c(i))) + let value = Plus' (Mult' sign (maybe id (\e -> Min2' IntTy e) e (ConvexHullTrickGetMin' (Proj' ts 0 (Var x)) c))) d + y <- lift genVarName' + f' <- lift $ genVarName f + a <- lift $ substitute f (Var f') a + b <- lift $ substitute f (Var f') b + a <- lift $ substitute i (Plus' (Var i) (LitInt' k)) a + b <- lift $ substitute i (Plus' (Var i) (LitInt' k)) b + return $ + Lam2 x (TupleTy ts) i IntTy $ + Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $ + Let f' (ListTy IntTy) (Snoc' IntTy (Var f) value) $ Let y ConvexHullTrickTy (ConvexHullTrickInsert' (Proj' ts 0 (Var x)) a b) $ - Let f' (ListTy IntTy) (Snoc' IntTy (Var f) (Plus' (Mult' sign (ConvexHullTrickGetMin' (Var y) c)) d)) $ - uncurryApp (Tuple' ts) [Var y, Var f'] + uncurryApp (Tuple' ts) [Var y, Var f'] -- proj 1 (foldl step' base' (range (n - 1))) return $ Proj' ts 1 (Foldl' IntTy (TupleTy ts) step' base' (Range1' n)) _ -> return Nothing diff --git a/src/Jikka/Core/Convert/KubaruToMorau.hs b/src/Jikka/Core/Convert/KubaruToMorau.hs new file mode 100644 index 00000000..90d4fa03 --- /dev/null +++ b/src/Jikka/Core/Convert/KubaruToMorau.hs @@ -0,0 +1,113 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module : Jikka.Core.Convert.KubaruToMorau +-- Description : converts Kubaru DP to Morau DP. / 配る DP を貰う DP に変換します。 +-- Copyright : (c) Kimiyuki Onaka, 2021 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Convert.KubaruToMorau + ( run, + + -- * internal rules + rule, + runFunctionBody, + ) +where + +import Control.Monad.Trans.Maybe +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Language.ArithmeticalExpr +import Jikka.Core.Language.Beta +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.FreeVars +import Jikka.Core.Language.Lint +import Jikka.Core.Language.RewriteRules +import Jikka.Core.Language.Util + +-- | @runFunctionBody c i j step y x k@ returns @step'(y, x, i, k)@ s.t. @step(c, i, j) = step'(c[i + j + 1], c[i], i, i + j + 1)@ +runFunctionBody :: (MonadAlpha m, MonadError Error m) => VarName -> VarName -> VarName -> Expr -> VarName -> VarName -> VarName -> MaybeT m Expr +runFunctionBody c i j step y x k = do + step <- lift $ substitute j (Minus' (Minus' (Var k) (Var i)) (LitInt' 1)) step + let go = \case + Var x + | x == c -> hoistMaybe Nothing + | otherwise -> return $ Var x + Lit lit -> return $ Lit lit + At' _ (Var c') index | c' == c -> case () of + () | parseArithmeticalExpr index == parseArithmeticalExpr (Var i) -> return $ Var x + () | parseArithmeticalExpr index == parseArithmeticalExpr (Var k) -> return $ Var y + () | otherwise -> hoistMaybe Nothing + App e1 e2 -> App <$> go e1 <*> go e2 + Let x t e1 e2 + | x == c || x == i || x == j -> throwRuntimeError "name confliction found" + | otherwise -> Let x t <$> go e1 <*> go e2 + Lam x t e + | x == c || x == i || x == j -> throwRuntimeError "name confliction found" + | otherwise -> Lam x t <$> go e + go step + +-- | TODO: remove the assumption that the length of @a@ is equals to @n@ +rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m +rule = RewriteRule $ \_ -> \case + -- foldl (fun b i -> foldl (fun c j -> setAt c index(i, j) step(c, i, j)) b (range m(i))) a (range n) + Foldl' IntTy (ListTy t2) (Lam2 b _ i _ (Foldl' IntTy (ListTy t2') (Lam2 c _ j _ (SetAt' _ (Var c') index step)) (Var b') (Range1' m))) a (Range1' n) + | t2' == t2 && b' == b && c == c' && b `isUnusedVar` m && b `isUnusedVar` index && b `isUnusedVar` step && c `isUnusedVar` index -> runMaybeT $ do + -- m(i) = n - i - 1 + guard $ parseArithmeticalExpr m == parseArithmeticalExpr (Minus' (Minus' n (Var i)) (LitInt' 1)) + -- index(i, j) = i + j + 1 + guard $ parseArithmeticalExpr index == parseArithmeticalExpr (Plus' (Var i) (Plus' (Var j) (LitInt' 1))) + x <- lift genVarName' + y <- lift genVarName' + k <- lift genVarName' + -- get step'(y, x, i, k) s.t. step(c, i, j) = step'(c[i + j + 1], c[i], i, i + j + 1) + step <- runFunctionBody c i j step y x k + step <- lift $ substitute x (At' t2 (Var c) (Var i)) step + step <- lift $ substitute k (Len' t2 (Var c)) step + let base = At' t2 a (Len' t2 (Var c)) + return $ Build' t2 (Lam c (ListTy t2) (Foldl' IntTy t2 (Lam2 y t2 i IntTy step) base (Range1' (Len' t2 (Var c))))) (Nil' t2) n + _ -> return Nothing + +runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program +runProgram = applyRewriteRuleProgram' rule + +-- | `run` converts Kubaru DP +-- (for each \(i\), updates \( +-- \mathrm{dp}(j) \gets f(\mathrm{dp}(j), \mathrm{dp}(i)) +-- \) for each \(j \gt i\)) +-- to Morau DP +-- (for each \(i\), computes \( +-- \mathrm{dp}(i) = F(\lbrace \mathrm{dp}(j) \mid j \lt i \rbrace) +-- \)). +-- +-- == Examples +-- +-- Before: +-- +-- > foldl (fun dp i -> +-- > foldl (fun dp j -> +-- > setAt dp j ( +-- > f dp[j] dp[i]) +-- > ) dp (range (i + 1) n) +-- > ) dp (range n) +-- +-- After: +-- +-- > build (fun dp' -> +-- > foldl (fun dp_i j -> +-- > f dp_i dp'[j] +-- > ) dp[i] (range i) +-- > ) [] n +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.KubaruToMorau" $ do + precondition $ do + ensureWellTyped prog + prog <- runProgram prog + postcondition $ do + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/MakeScanl.hs b/src/Jikka/Core/Convert/MakeScanl.hs index 1aa3e9e0..62273fc6 100644 --- a/src/Jikka/Core/Convert/MakeScanl.hs +++ b/src/Jikka/Core/Convert/MakeScanl.hs @@ -32,8 +32,6 @@ where import Control.Monad.Trans.Maybe import qualified Data.Map as M -import Data.Maybe -import qualified Data.Vector as V import Jikka.Common.Alpha import Jikka.Common.Error import Jikka.Core.Language.ArithmeticalExpr @@ -57,17 +55,6 @@ reduceScanlBuild = simpleRewriteRule $ \case Scanl' t1 t2 f init (Cons' _ x xs) -> Just $ Cons' t2 init (Scanl' t1 t2 f (App2 f init x) xs) _ -> Nothing --- | `getRecurrenceFormulaBase` makes a pair @((a_0, ..., a_{k - 1}), a)@ from @setat (... (setat a 0 a_0) ...) (k - 1) a_{k - 1})@. -getRecurrenceFormulaBase :: Expr -> ([Expr], Expr) -getRecurrenceFormulaBase = go (V.replicate recurrenceLimit Nothing) - where - recurrenceLimit :: Num a => a - recurrenceLimit = 20 - go :: V.Vector (Maybe Expr) -> Expr -> ([Expr], Expr) - go base = \case - SetAt' _ e (LitInt' i) e' | 0 <= i && i < recurrenceLimit -> go (base V.// [(fromInteger i, Just e')]) e - e -> (map fromJust (takeWhile isJust (V.toList base)), e) - -- | `getRecurrenceFormulaStep1` removes `At` in @body@. getRecurrenceFormulaStep1 :: MonadAlpha m => Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr) getRecurrenceFormulaStep1 shift t a i body = do @@ -113,9 +100,6 @@ getRecurrenceFormulaStep shift size t a i body = do Just body -> Just $ Lam2 x (TupleTy ts) i IntTy (uncurryApp (Tuple' ts) (map (\i -> Proj' ts i (Var x)) [1 .. size - 1] ++ [body])) Nothing -> Nothing -hoistMaybe :: Applicative m => Maybe a -> MaybeT m a -hoistMaybe = MaybeT . pure - -- | -- * This assumes that `Range2` and `Range3` are already converted to `Range1` (`Jikka.Core.Convert.ShortCutFusion`). -- * This assumes that combinations `Foldl` and `Map` squashed (`Jikka.Core.Convert.ShortCutFusion`). diff --git a/src/Jikka/Core/Convert/ShortCutFusion.hs b/src/Jikka/Core/Convert/ShortCutFusion.hs index f78fe23a..4992bfe5 100644 --- a/src/Jikka/Core/Convert/ShortCutFusion.hs +++ b/src/Jikka/Core/Convert/ShortCutFusion.hs @@ -36,6 +36,7 @@ import Jikka.Core.Format (formatExpr) import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Jikka.Core.Language.FreeVars +import Jikka.Core.Language.LambdaPatterns import Jikka.Core.Language.Lint import Jikka.Core.Language.RewriteRules import Jikka.Core.Language.Util @@ -87,7 +88,7 @@ reduceMap = let return' = return . Just in RewriteRule $ \_ -> \case -- reduce `Map` - Map' _ _ (LamId _ _) xs -> return' xs + Map' _ _ (LamId _) xs -> return' xs -- reduce `Filter` Filter' t (Lam _ _ LitFalse) _ -> return' (Nil' t) Filter' _ (Lam _ _ LitTrue) xs -> return' xs @@ -104,7 +105,7 @@ reduceMapMap = let return' = return . Just in RewriteRule $ \_ -> \case -- reduce `Map` - Map' _ _ (LamId _ _) xs -> return' xs + Map' _ _ (LamId _) xs -> return' xs Map' _ t3 g (Map' t1 _ f xs) -> do x <- genVarName' let h = Lam x t1 (App g (App f (Var x))) @@ -127,7 +128,6 @@ reduceMapMap = -- reduce `Sorted` Sorted' t (Reversed' _ xs) -> return' $ Sorted' t xs Sorted' t (Sorted' _ xs) -> return' $ Sorted' t xs - -- others _ -> return Nothing reduceFoldMap :: MonadAlpha m => RewriteRule m @@ -149,6 +149,9 @@ reduceFoldMap = x1 <- genVarName' return' $ Foldl' t1 t3 (Lam2 x3 t3 x1 t1 (App2 g (Var x3) (App f (Var x1)))) init xs -- others + Len' t (SetAt' _ xs _ _) -> return' $ Len' t xs + Len' t (Scanl' _ _ _ _ xs) -> return' $ Plus' (Len' t xs) (LitInt' 1) + At' t (SetAt' _ xs i' x) i -> return' $ If' t (Equal' IntTy i' i) x (At' t xs i) _ -> return Nothing reduceFold :: Monad m => RewriteRule m @@ -176,6 +179,7 @@ reduceFoldBuild = Elem' t y (Cons' _ x xs) -> return' $ And' (Equal' t x y) (Elem' t y xs) Elem' _ x (Range1' n) -> return' $ And' (LessEqual' IntTy Lit0 x) (LessThan' IntTy x n) -- others + Len' t (Build' _ _ base n) -> return' $ Plus' (Len' t base) n _ -> return Nothing rule :: MonadAlpha m => RewriteRule m @@ -196,7 +200,7 @@ runProgram = applyRewriteRuleProgram' rule -- | `run` does short cut fusion. -- -- * This function is mainly for polymorphic reductions. This dosn't do much about concrete things, e.g., arithmetical operations. --- * This doesn't do nothing about `Scanl` or `SetAt`. +-- * This does nothing about `Build`, `Scanl` or `SetAt` except combinations with `Len` or `At`. -- -- == Example -- diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index 2143806a..bdab8eec 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -275,7 +275,9 @@ evaluateExpr env = \case Var x -> case lookup x env of Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just val -> return val - Lit lit -> literalToValue lit + Lit lit -> case lit of + LitBuiltin ConvexHullTrickInit -> callBuiltin ConvexHullTrickInit [] + _ -> literalToValue lit If' _ p e1 e2 -> do p <- valueToBool =<< evaluateExpr env p if p diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index edfed1a0..0147971a 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -1,4 +1,5 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} -- | -- Module : Jikka.Core.Format @@ -24,25 +25,102 @@ import Data.Char (toLower) import Data.List (intercalate) import Data.Text (Text, pack) import Jikka.Common.Format.AutoIndent +import Jikka.Core.Language.BuiltinPatterns (pattern Range1') import Jikka.Core.Language.Expr +import Jikka.Core.Language.FreeVars (isUnusedVar) +import Jikka.Core.Language.LambdaPatterns import Jikka.Core.Language.Util +-- | See also Table 2 of . +newtype Prec = Prec Int + deriving (Eq, Ord, Show, Read) + +instance Enum Prec where + toEnum n = Prec n + fromEnum (Prec n) = n + +identPrec = Prec 12 + +funCallPrec = Prec 11 + +unaryPrec = Prec 10 + +powerPrec = Prec 8 + +multPrec = Prec 7 + +addPrec = Prec 6 + +appendPrec = Prec 5 + +comparePrec = Prec 4 + +andPrec = Prec 3 + +orPrec = Prec 2 + +impliesPrec = Prec 1 + +commaPrec = Prec 0 + +lambdaPrec = Prec (-1) + +parenPrec = Prec (-2) + +data Assoc + = NoAssoc + | LeftToRight + | RightToLeft + deriving (Eq, Ord, Enum, Show, Read) + paren :: String -> String paren s = "(" ++ s ++ ")" -formatType :: Type -> String -formatType = \case - VarTy (TypeName a) -> a - IntTy -> "int" - BoolTy -> "bool" - ListTy t -> formatType t ++ " list" +-- | `resolvePrec` inserts parens to the given string if required. +-- +-- >>> resolvePrec multPrec ("1 + 2", addPrec) ++ " * 3" +-- "(1 + 2) * 3" +-- +-- >>> resolvePrec addPrec ("1 * 2", multPrec) ++ " + 3" +-- "1 * 2 + 3" +resolvePrec :: Prec -> (String, Prec) -> String +resolvePrec cur (s, prv) + | cur > prv = paren s + | otherwise = s + +-- | `resolvePrecLeft` inserts parens to the given string if required. +-- +-- >>> resolvePrecLeft addPrec LeftToRight ("1 - 2", addPrec) ++ " - 3" +-- "1 - 2 - 3" +resolvePrecLeft :: Prec -> Assoc -> (String, Prec) -> String +resolvePrecLeft cur assoc (s, prv) + | cur > prv || (cur == prv && assoc /= LeftToRight) = paren s + | otherwise = s + +-- | `resolvePrecRight` inserts parens to the given string if required. +-- +-- >>> "1 - " ++ resolvePrecRight addPrec LeftToRight ("2 - 3", addPrec) +-- "1 - (2 - 3)" +resolvePrecRight :: Prec -> Assoc -> (String, Prec) -> String +resolvePrecRight cur assoc (s, prv) + | cur > prv || (cur == prv && assoc /= RightToLeft) = paren s + | otherwise = s + +formatType' :: Type -> (String, Prec) +formatType' = \case + VarTy (TypeName a) -> (a, identPrec) + IntTy -> ("int", identPrec) + BoolTy -> ("bool", identPrec) + ListTy t -> (resolvePrec funCallPrec (formatType' t) ++ " list", funCallPrec) TupleTy ts -> case ts of - [t] -> paren $ formatType t ++ "," - _ -> paren $ intercalate " * " (map formatType ts) - t@(FunTy _ _) -> - let (ts, ret) = uncurryFunTy t - in paren $ intercalate " -> " (map formatType (ts ++ [ret])) - DataStructureTy ds -> formatDataStructure ds + [t] -> (resolvePrec (pred multPrec) (formatType' t) ++ ",", multPrec) + _ -> (intercalate " * " (map (resolvePrec (pred multPrec) . formatType') ts), multPrec) + FunTy t1 t2 -> + (resolvePrecLeft impliesPrec RightToLeft (formatType' t1) ++ " -> " ++ resolvePrecRight impliesPrec RightToLeft (formatType' t2), impliesPrec) + DataStructureTy ds -> (formatDataStructure ds, identPrec) + +formatType :: Type -> String +formatType = resolvePrec parenPrec . formatType' formatDataStructure :: DataStructure -> String formatDataStructure = \case @@ -57,49 +135,52 @@ formatSemigroup = \case data Builtin' = Fun [Type] String - | PrefixOp String - | InfixOp [Type] String + | PrefixOp [Type] String + | InfixOp [Type] String Prec Assoc | At' Type + | SetAt' Type + | Tuple' [Type] + | Proj' [Type] Integer | If' Type deriving (Eq, Ord, Show, Read) fun :: String -> Builtin' fun = Fun [] -infixOp :: String -> Builtin' +infixOp :: String -> Prec -> Assoc -> Builtin' infixOp = InfixOp [] analyzeBuiltin :: Builtin -> Builtin' analyzeBuiltin = \case -- arithmetical functions - Negate -> PrefixOp "negate" - Plus -> infixOp "+" - Minus -> infixOp "-" - Mult -> infixOp "*" - FloorDiv -> infixOp "/" - FloorMod -> infixOp "%" - CeilDiv -> fun "ceildiv" - CeilMod -> fun "ceilmod" - Pow -> infixOp "**" + Negate -> PrefixOp [] "-" + Plus -> infixOp "+" addPrec LeftToRight + Minus -> infixOp "-" addPrec LeftToRight + Mult -> infixOp "*" multPrec LeftToRight + FloorDiv -> infixOp "/" multPrec LeftToRight + FloorMod -> infixOp "%" multPrec LeftToRight + CeilDiv -> infixOp "/^" multPrec LeftToRight + CeilMod -> infixOp "%^" multPrec LeftToRight + Pow -> infixOp "**" powerPrec RightToLeft -- advanced arithmetical functions Abs -> fun "abs" Gcd -> fun "gcd" Lcm -> fun "lcm" - Min2 t -> Fun [t] "min" - Max2 t -> Fun [t] "max" + Min2 t -> InfixOp [t] " InfixOp [t] ">?" appendPrec LeftToRight -- logical functions - Not -> PrefixOp "not" - And -> infixOp "and" - Or -> infixOp "or" - Implies -> infixOp "implies" + Not -> PrefixOp [] "not" + And -> infixOp "and" andPrec RightToLeft + Or -> infixOp "or" orPrec RightToLeft + Implies -> infixOp "implies" impliesPrec RightToLeft If t -> If' t -- bitwise functions - BitNot -> PrefixOp "~" - BitAnd -> infixOp "&" - BitOr -> infixOp "|" - BitXor -> infixOp "^" - BitLeftShift -> infixOp "<<" - BitRightShift -> infixOp ">>" + BitNot -> PrefixOp [] "~" + BitAnd -> infixOp "&" multPrec LeftToRight + BitOr -> infixOp "|" appendPrec LeftToRight + BitXor -> infixOp "^" addPrec LeftToRight + BitLeftShift -> infixOp "<<" powerPrec LeftToRight + BitRightShift -> infixOp ">>" powerPrec LeftToRight -- matrix functions MatAp _ _ -> fun "matap" MatZero _ -> fun "matzero" @@ -131,33 +212,33 @@ analyzeBuiltin = \case Map t1 t2 -> Fun [t1, t2] "map" Filter t -> Fun [t] "filter" At t -> At' t - SetAt t -> Fun [t] "setAt" + SetAt t -> SetAt' t Elem t -> Fun [t] "elem" Sum -> fun "sum" Product -> fun "product" ModSum -> fun "modsum" ModProduct -> fun "modproduct" - Min1 t -> Fun [t] "min1" - Max1 t -> Fun [t] "max1" + Min1 t -> Fun [t] "min" + Max1 t -> Fun [t] "max" ArgMin t -> Fun [t] "argmin" ArgMax t -> Fun [t] "argmax" All -> fun "all" Any -> fun "any" Sorted t -> Fun [t] "sort" Reversed t -> Fun [t] "reverse" - Range1 -> fun "range1" + Range1 -> fun "range" Range2 -> fun "range2" Range3 -> fun "range3" -- tuple functions - Tuple ts -> Fun ts "tuple" - Proj ts n -> Fun ts ("proj" ++ show n) + Tuple ts -> Tuple' ts + Proj ts n -> Proj' ts (toInteger n) -- comparison - LessThan t -> InfixOp [t] "<" - LessEqual t -> InfixOp [t] "<=" - GreaterThan t -> InfixOp [t] ">" - GreaterEqual t -> InfixOp [t] ">=" - Equal t -> InfixOp [t] "==" - NotEqual t -> InfixOp [t] "!=" + LessThan t -> InfixOp [t] "<" comparePrec NoAssoc + LessEqual t -> InfixOp [t] "<=" comparePrec NoAssoc + GreaterThan t -> InfixOp [t] ">" comparePrec NoAssoc + GreaterEqual t -> InfixOp [t] ">=" comparePrec NoAssoc + Equal t -> InfixOp [t] "==" comparePrec NoAssoc + NotEqual t -> InfixOp [t] "!=" comparePrec NoAssoc -- combinational functions Fact -> fun "fact" Choose -> fun "choose" @@ -176,33 +257,41 @@ formatTemplate = \case [] -> "" ts -> "<" ++ intercalate ", " (map formatType ts) ++ ">" -formatFunCall :: String -> [Expr] -> String +formatFunCall :: (String, Prec) -> [Expr] -> (String, Prec) formatFunCall f = \case [] -> f - args -> f ++ "(" ++ intercalate ", " (map formatExpr' args) ++ ")" + args -> (resolvePrec funCallPrec f ++ "(" ++ intercalate ", " (map (resolvePrec commaPrec . formatExpr') args) ++ ")", funCallPrec) formatBuiltinIsolated' :: Builtin' -> String formatBuiltinIsolated' = \case Fun ts name -> name ++ formatTemplate ts - PrefixOp op -> paren op - InfixOp ts op -> paren $ op ++ formatTemplate ts + PrefixOp ts op -> paren $ op ++ formatTemplate ts + InfixOp ts op _ _ -> paren $ op ++ formatTemplate ts At' t -> paren $ "at" ++ formatTemplate [t] + SetAt' t -> paren $ "set-at" ++ formatTemplate [t] + Tuple' ts -> paren $ "tuple" ++ formatTemplate ts + Proj' ts n -> paren $ "proj-" ++ show n ++ formatTemplate ts If' t -> paren $ "if-then-else" ++ formatTemplate [t] formatBuiltinIsolated :: Builtin -> String formatBuiltinIsolated = formatBuiltinIsolated' . analyzeBuiltin -formatBuiltin' :: Builtin' -> [Expr] -> String +formatBuiltin' :: Builtin' -> [Expr] -> (String, Prec) formatBuiltin' builtin args = case (builtin, args) of - (Fun _ name, _) -> formatFunCall name args - (PrefixOp op, e1 : args) -> formatFunCall (paren $ op ++ " " ++ formatExpr' e1) args - (InfixOp _ op, e1 : e2 : args) -> formatFunCall (paren $ formatExpr' e1 ++ " " ++ op ++ " " ++ formatExpr' e2) args - (At' _, e1 : e2 : args) -> formatFunCall (paren $ formatExpr' e1 ++ ")[" ++ formatExpr' e2 ++ "]") args - (If' _, e1 : e2 : e3 : args) -> formatFunCall (paren $ "if" ++ " " ++ formatExpr' e1 ++ " then " ++ formatExpr' e2 ++ " else " ++ formatExpr' e3) args - _ -> formatFunCall (formatBuiltinIsolated' builtin) args + (Fun _ "map", [Lam x IntTy e, Range1' n]) | x `isUnusedVar` e -> formatFunCall ("replicate", identPrec) [n, e] + (Fun _ name, _) -> formatFunCall (name, identPrec) args + (PrefixOp _ op, e1 : args) -> formatFunCall (op ++ " " ++ resolvePrec unaryPrec (formatExpr' e1), unaryPrec) args + (InfixOp _ op prec assoc, e1 : e2 : args) -> formatFunCall (resolvePrecLeft prec assoc (formatExpr' e1) ++ " " ++ op ++ " " ++ resolvePrecRight prec assoc (formatExpr' e2), prec) args + (At' _, e1 : e2 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ "]", identPrec) args + (SetAt' _, e1 : e2 : e3 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ " := " ++ resolvePrec parenPrec (formatExpr' e3) ++ "]", identPrec) args + (Tuple' [_], e : args) -> formatFunCall (paren (resolvePrec commaPrec (formatExpr' e) ++ ","), identPrec) args + (Tuple' ts, args) | length args >= length ts -> formatFunCall (paren (intercalate ", " (map (resolvePrec commaPrec . formatExpr') (take (length ts) args))), identPrec) (drop (length ts) args) + (Proj' _ n, e : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e) ++ "." ++ show n, identPrec) args + (If' _, e1 : e2 : e3 : args) -> formatFunCall ("if" ++ " " ++ resolvePrec parenPrec (formatExpr' e1) ++ " then " ++ resolvePrec parenPrec (formatExpr' e2) ++ " else " ++ resolvePrec lambdaPrec (formatExpr' e3), lambdaPrec) args + _ -> formatFunCall (formatBuiltinIsolated' builtin, identPrec) args formatBuiltin :: Builtin -> [Expr] -> String -formatBuiltin = formatBuiltin' . analyzeBuiltin +formatBuiltin f args = resolvePrec parenPrec (formatBuiltin' (analyzeBuiltin f) args) formatLiteral :: Literal -> String formatLiteral = \case @@ -215,33 +304,35 @@ formatLiteral = \case formatFormalArgs :: [(VarName, Type)] -> String formatFormalArgs args = unwords $ map (\(x, t) -> paren (unVarName x ++ ": " ++ formatType t)) args -formatExpr' :: Expr -> String +formatExpr' :: Expr -> (String, Prec) formatExpr' = \case - Var x -> unVarName x - Lit lit -> formatLiteral lit + Var x -> (unVarName x, identPrec) + Lit lit -> (formatLiteral lit, identPrec) e@(App _ _) -> let (f, args) = curryApp e in case f of - Var x -> formatFunCall (unVarName x) args - Lit (LitBuiltin builtin) -> formatBuiltin builtin args + Var x -> formatFunCall (unVarName x, identPrec) args + Lit (LitBuiltin builtin) -> (formatBuiltin builtin args, identPrec) _ -> formatFunCall (formatExpr' f) args + LamId _ -> ("id", identPrec) + LamConst _ e -> formatFunCall ("const", identPrec) [e] e@(Lam _ _ _) -> let (args, body) = uncurryLam e - in paren $ "fun " ++ formatFormalArgs args ++ " ->\n" ++ indent ++ "\n" ++ formatExpr' body ++ "\n" ++ dedent ++ "\n" - Let x t e1 e2 -> "let " ++ unVarName x ++ ": " ++ formatType t ++ " =\n" ++ indent ++ "\n" ++ formatExpr' e1 ++ "\n" ++ dedent ++ "\nin " ++ formatExpr' e2 + in ("fun " ++ formatFormalArgs args ++ " ->\n" ++ indent ++ "\n" ++ resolvePrec parenPrec (formatExpr' body) ++ "\n" ++ dedent ++ "\n", lambdaPrec) + Let x t e1 e2 -> ("let " ++ unVarName x ++ ": " ++ formatType t ++ " =\n" ++ indent ++ "\n" ++ resolvePrec parenPrec (formatExpr' e1) ++ "\n" ++ dedent ++ "\nin " ++ resolvePrec lambdaPrec (formatExpr' e2), lambdaPrec) formatExpr :: Expr -> String -formatExpr = unwords . makeIndentFromMarkers 4 . lines . formatExpr' +formatExpr = unlines . makeIndentFromMarkers 4 . lines . resolvePrec parenPrec . formatExpr' formatToplevelExpr :: ToplevelExpr -> [String] formatToplevelExpr = \case - ResultExpr e -> lines (formatExpr' e) + ResultExpr e -> lines (resolvePrec lambdaPrec (formatExpr' e)) ToplevelLet x t e cont -> let' (unVarName x) t e cont ToplevelLetRec f args ret e cont -> let' ("rec " ++ unVarName f ++ " " ++ formatFormalArgs args) ret e cont where let' s t e cont = ["let " ++ s ++ ": " ++ formatType t ++ " =", indent] - ++ lines (formatExpr' e) + ++ lines (resolvePrec parenPrec (formatExpr' e)) ++ [dedent, "in"] ++ formatToplevelExpr cont diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index fbb0cc42..e1110c88 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -373,11 +373,6 @@ pattern Lam2 x1 t1 x2 t2 e = Lam x1 t1 (Lam x2 t2 e) pattern Lam3 x1 t1 x2 t2 x3 t3 e = Lam x1 t1 (Lam x2 t2 (Lam x3 t3 e)) -pattern LamId x t <- - (\case Lam x t (Var y) | x == y -> Just (x, t); _ -> Nothing -> Just (x, t)) - where - LamId x t = Lam x t (Var x) - -- | `ToplevelExpr` is the toplevel exprs. In our core, "let rec" is allowed only on the toplevel. -- -- \[ diff --git a/src/Jikka/Core/Language/LambdaPatterns.hs b/src/Jikka/Core/Language/LambdaPatterns.hs new file mode 100644 index 00000000..2e642bd8 --- /dev/null +++ b/src/Jikka/Core/Language/LambdaPatterns.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} + +module Jikka.Core.Language.LambdaPatterns where + +import Jikka.Core.Language.Expr +import Jikka.Core.Language.FreeVars + +pattern LamId t <- + (\case Lam x t (Var y) | x == y -> Just t; _ -> Nothing -> Just t) + where + LamId t = Lam "x" t (Var "x") + +pattern LamConst t e <- + (\case Lam x t e | x `isUnusedVar` e -> Just (t, e); _ -> Nothing -> Just (t, e)) + where + LamConst t e = Lam (findUnusedVarName' e) t e diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index ef4726e8..841b8ece 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -3,10 +3,13 @@ module Jikka.Core.Language.Util where +import Control.Arrow import Control.Monad.Identity +import Control.Monad.Trans.Maybe import Control.Monad.Writer (execWriter, tell) -import Data.Maybe (isJust) +import Data.Maybe import Data.Monoid (Dual (..)) +import qualified Data.Vector as V import Jikka.Common.Alpha import Jikka.Common.Error import Jikka.Core.Language.BuiltinPatterns @@ -352,3 +355,23 @@ replaceLenF f i k = go Lam x t body -> Lam x t <$> (if x == f then return body else go body) Let y _ _ _ | y == i -> throwInternalError "Jikka.Core.Language.Util.replaceLenF: name conflict" Let y t e1 e2 -> Let y t <$> go e1 <*> (if y == f then return e2 else go e2) + +-- | `getRecurrenceFormulaBase` makes a pair @((a_0, ..., a_{k - 1}), a)@ from @setat (... (setat a 0 a_0) ...) (k - 1) a_{k - 1})@. +getRecurrenceFormulaBase :: Expr -> ([Expr], Expr) +getRecurrenceFormulaBase = go (V.replicate recurrenceLimit Nothing) + where + recurrenceLimit :: Num a => a + recurrenceLimit = 20 + go :: V.Vector (Maybe (Expr, Type)) -> Expr -> ([Expr], Expr) + go base = \case + SetAt' t e (LitInt' i) e' + | 0 <= i && i < recurrenceLimit -> go (base V.// [(fromInteger i, Just (e', t))]) e + | otherwise -> second (\e -> SetAt' t e (LitInt' i) e') $ go base e + e -> + let (base', base'') = span isJust (V.toList base) + base''' = map (fst . fromJust) base' + e'' = foldr (\(i, e') e -> maybe id (\(e', t) e -> SetAt' t e (LitInt' i) e') e' e) e (zip [toInteger (length base') ..] base'') + in (base''', e'') + +hoistMaybe :: Applicative m => Maybe a -> MaybeT m a +hoistMaybe = MaybeT . pure diff --git a/test/Jikka/Core/FormatSpec.hs b/test/Jikka/Core/FormatSpec.hs index 614cc792..2b1ddef8 100644 --- a/test/Jikka/Core/FormatSpec.hs +++ b/test/Jikka/Core/FormatSpec.hs @@ -39,8 +39,8 @@ spec = describe "formatExpr" $ do [ "let rec solve$0 (n$1: int): int =", " let xs$2: int list =", " map((fun (i$3: int) ->", - " (i$3 * i$3)", - " ), range1(n$1))", + " i$3 * i$3", + " ), range(n$1))", " in sum(xs$2)", "in", "solve$0" diff --git a/test/Jikka/RestrictedPython/Convert/ToCoreSpec.hs b/test/Jikka/RestrictedPython/Convert/ToCoreSpec.hs index b24a7a58..20a9f851 100644 --- a/test/Jikka/RestrictedPython/Convert/ToCoreSpec.hs +++ b/test/Jikka/RestrictedPython/Convert/ToCoreSpec.hs @@ -72,26 +72,26 @@ spec = describe "run" $ do " 0", " in let b: $1 =", " 1", - " in let $4: ($5 * $6) =", - " foldl((fun ($4: ($5 * $6)) ($3: $2) ->", + " in let $4: $5 * $6 =", + " foldl((fun ($4: $5 * $6) ($3: $2) ->", " let b: $5 =", - " proj0($4)", + " $4.0", " in let a: $6 =", - " proj1($4)", + " $4.1", " in let i: $7 =", " $3", " in let c: $8 =", - " (a + b)", + " a + b", " in let a: $9 =", " b", " in let b: $10 =", " c", - " in tuple(b, a)", - " ), tuple(b, a), range1(n))", + " in (b, a)", + " ), (b, a), range(n))", " in let b: $5 =", - " proj0($4)", + " $4.0", " in let a: $6 =", - " proj1($4)", + " $4.1", " in a", "in", "solve" @@ -115,14 +115,14 @@ spec = describe "run" $ do let expected = unlines [ "let rec solve : int =", - " let $2: ($1,) =", - " (if true then let x: $3 =", + " let $2: $1, =", + " if true then let x: $3 =", " 1", - " in tuple(x) else let x: $5 =", + " in (x,) else let x: $5 =", " 0", - " in tuple(x))", + " in (x,)", " in let x: $1 =", - " proj0($2)", + " $2.0", " in x", "in", "solve"