Skip to content

Commit

Permalink
Merge pull request #117 from kmyk/kubaru-morau
Browse files Browse the repository at this point in the history
Covert Kubaru DP to Morau DP
  • Loading branch information
kmyk authored Jul 31, 2021
2 parents 17debdc + c2d7f47 commit 4860fc7
Show file tree
Hide file tree
Showing 31 changed files with 503 additions and 176 deletions.
2 changes: 2 additions & 0 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ library
Jikka.Core.Convert.ConvexHullTrick
Jikka.Core.Convert.CumulativeSum
Jikka.Core.Convert.Eta
Jikka.Core.Convert.KubaruToMorau
Jikka.Core.Convert.MakeScanl
Jikka.Core.Convert.MatrixExponentiation
Jikka.Core.Convert.PropagateMod
Expand All @@ -94,6 +95,7 @@ library
Jikka.Core.Language.BuiltinPatterns
Jikka.Core.Language.Expr
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Expand Down
5 changes: 4 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
- Educational DP Contest [Q - Flowers](https://atcoder.jp/contests/dp/tasks/dp_q)
- A lazy propagation segment tree / 遅延伝播セグメント木
- submission at v5.0.11.0: <https://atcoder.jp/contests/dp/submissions/24561829>
- :heavy_check_mark: AC `dp_z.py`
- :heavy_check_mark: AC `dp_z-morau.py`
- Educational DP Contest [Z - Frog 3](https://atcoder.jp/contests/dp/tasks/dp_z)
- Convex hull trick
- submission at v5.0.11.0: <https://atcoder.jp/contests/dp/submissions/24563891>
- :heavy_check_mark: AC `dp_z-kubaru.py`
- The Kubaru DP version of `dp_z-morau.py`.
- submission at v5.1.0.0: <https://atcoder.jp/contests/dp/submissions/24701829>
- :hourglass: TLE `abc134_c.py`
- AtCoder Beginner Contest 134 [C - Exception Handling](https://atcoder.jp/contests/abc134/tasks/abc134_c)
- Cumulative sums from both sides / 両側からの累積和
Expand Down
1 change: 1 addition & 0 deletions examples/data/aliases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{ "dp_z-kubaru": "dp_z-morau" }
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
27 changes: 27 additions & 0 deletions examples/dp_z-kubaru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# https://atcoder.jp/contests/dp/tasks/dp_z
from typing import *

INF = 10 ** 18

def solve(n: int, c: int, h: List[int]) -> int:
assert 2 <= n <= 10 ** 5
assert 1 <= c <= 10 ** 12
assert len(h) == n
assert all(1 <= h_i <= 10 ** 6 for h_i in h)

dp = [INF for _ in range(n)]
dp[0] = 0
for i in range(n):
for j in range(i + 1, n):
dp[j] = min(dp[j], dp[i] + (h[i] - h[j]) ** 2 + c)
return dp[n - 1]

def main() -> None:
n, c = map(int, input().split())
h = list(map(int, input().split()))
assert len(h) == n
ans = solve(n, c, h)
print(ans)

if __name__ == '__main__':
main()
File renamed without changes.
18 changes: 18 additions & 0 deletions scripts/integration_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python3
import argparse
import concurrent.futures
import functools
import glob
import json
import os
import pathlib
import platform
Expand All @@ -26,6 +28,12 @@ def compile_cxx(src_path: pathlib.Path, dst_path: pathlib.Path):
subprocess.check_call(command, timeout=5 * TIMEOUT_FACTOR)


@functools.lru_cache(maxsize=None)
def get_alias_mapping() -> Dict[str, str]:
with open(pathlib.Path('examples', 'data', 'aliases.json')) as fh:
return json.load(fh)


def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[pathlib.Path]:
inputcases: List[pathlib.Path] = []

Expand Down Expand Up @@ -79,6 +87,14 @@ def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[
return []
inputcases.append(inputcase)

# resolve alias
aliases = get_alias_mapping()
if script.stem in aliases:
if inputcases:
logger.error("%s: there must not be test cases when it uses an alias: %s", str(script), list(map(str, inputcases)))
return []
return collect_input_cases(script.parent / (aliases[script.stem] + script.suffix), tempdir=tempdir)

return inputcases


Expand Down Expand Up @@ -170,6 +186,8 @@ def find_unused_test_cases() -> List[pathlib.Path]:
errors = list(pathlib.Path('examples', 'errors').glob('*.py'))
unused = []
for path in pathlib.Path('examples', 'data').glob('*'):
if path.name == 'aliases.json':
continue
name = path.name[:-len(''.join(path.suffixes))]
if name not in [script.stem for script in scripts + errors]:
unused.append(path)
Expand Down
4 changes: 3 additions & 1 deletion src/Jikka/CPlusPlus/Convert/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ runLiteral env = \case
case stmts of
[] -> return e
_ -> throwInternalError "now builtin values don't use statements"
X.LitInt n -> return $ Y.Lit (Y.LitInt64 n)
X.LitInt n
| - (2 ^ 63) <= n && n < 2 ^ 63 -> return $ Y.Lit (Y.LitInt64 n)
| otherwise -> throwInternalError $ "integer value is too large for int64_t: " ++ show n
X.LitBool p -> return $ Y.Lit (Y.LitBool p)
X.LitNil t -> do
t <- runType t
Expand Down
29 changes: 20 additions & 9 deletions src/Jikka/CPlusPlus/Convert/MoveSemantics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,21 @@ runAssignExpr = \case
AssignIncr e -> AssignIncr <$> runLeftExpr e
AssignDecr e -> AssignDecr <$> runLeftExpr e

isMovable :: VarName -> [[Statement]] -> Bool
isMovable x cont =
let ReadWriteList rs _ = analyzeStatements (concat cont)
in x `S.notMember` rs
isMovableTo :: VarName -> VarName -> [[Statement]] -> Bool
isMovableTo x y cont
| x `S.notMember` readList' (analyzeStatements (concat cont)) = True
| otherwise =
let go = \case
[] -> False
(Assign (AssignExpr SimpleAssign (LeftVar x') (Var y')) : cont')
| x' == x && y' == y ->
let ReadWriteList _ ws' = analyzeStatements cont'
ReadWriteList rs ws = analyzeStatements (concat (tail cont))
in y `S.notMember` S.unions [ws', rs, ws]
(stmt : cont) ->
let ReadWriteList rs ws = analyzeStatement stmt
in x `S.notMember` S.unions [rs, ws] && go cont
in go (head cont)

runStatement :: MonadState (M.Map VarName VarName) m => Statement -> [[Statement]] -> m [Statement]
runStatement stmt cont = case stmt of
Expand Down Expand Up @@ -88,16 +99,16 @@ runStatement stmt cont = case stmt of
DeclareCopy e -> DeclareCopy <$> runExpr e
DeclareInitialize es -> DeclareInitialize <$> mapM runExpr es
case init of
DeclareCopy (Var x) | x `isMovable` cont -> do
DeclareCopy (Var x) | (x `isMovableTo` y) cont -> do
modify' (M.insert y x)
return []
DeclareCopy (Call ConvexHullTrickCtor []) -> return [Declare t y DeclareDefault]
DeclareCopy (Call ConvexHullTrickCopyAddLine [Var x, a, b])
| x `isMovable` cont -> do
| (x `isMovableTo` y) cont -> do
modify' (M.insert y x)
return [callMethod' (Var x) "add_line" [a, b]]
DeclareCopy (Call (SegmentTreeCopySetPoint _) [Var x, i, a])
| x `isMovable` cont -> do
| (x `isMovableTo` y) cont -> do
modify' (M.insert y x)
return [callMethod' (Var x) "set" [i, a]]
_ -> do
Expand All @@ -111,13 +122,13 @@ runStatement stmt cont = case stmt of
AssignExpr SimpleAssign (LeftVar y) (Var x) | x == y -> return []
AssignExpr SimpleAssign (LeftVar y) (Call ConvexHullTrickCopyAddLine [Var x, a, b])
| x == y -> return [callMethod' (Var x) "add_line" [a, b]]
| x `isMovable` cont -> do
| (x `isMovableTo` y) cont -> do
modify' (M.insert y x)
return [callMethod' (Var x) "add_line" [a, b]]
| otherwise -> return [Assign e]
AssignExpr SimpleAssign (LeftVar y) (Call (SegmentTreeCopySetPoint _) [Var x, i, a])
| x == y -> return [callMethod' (Var x) "set" [i, a]]
| x `isMovable` cont -> do
| (x `isMovableTo` y) cont -> do
modify' (M.insert y x)
return [callMethod' (Var x) "set" [i, a]]
| otherwise -> return [Assign e]
Expand Down
4 changes: 3 additions & 1 deletion src/Jikka/CPlusPlus/Format.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ formatType = \case
formatLiteral :: Literal -> Code
formatLiteral = \case
LitInt32 n -> show n
LitInt64 n -> show n
LitInt64 n
| - (2 ^ 31) <= n && n < 2 ^ 31 -> show n
| otherwise -> show n ++ "ll"
LitBool p -> if p then "true" else "false"
LitChar c -> show c
LitString s -> show s
Expand Down
2 changes: 1 addition & 1 deletion src/Jikka/CPlusPlus/Language/VariableAnalysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import qualified Data.Set as S
import Jikka.CPlusPlus.Language.Expr

data ReadWriteList = ReadWriteList
{ readList :: S.Set VarName,
{ readList' :: S.Set VarName,
writeList :: S.Set VarName
}
deriving (Eq, Ord, Show, Read)
Expand Down
2 changes: 2 additions & 0 deletions src/Jikka/Core/Convert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation
import qualified Jikka.Core.Convert.ConvexHullTrick as ConvexHullTrick
import qualified Jikka.Core.Convert.CumulativeSum as CumulativeSum
import qualified Jikka.Core.Convert.Eta as Eta
import qualified Jikka.Core.Convert.KubaruToMorau as KubaruToMorau
import qualified Jikka.Core.Convert.MakeScanl as MakeScanl
import qualified Jikka.Core.Convert.MatrixExponentiation as MatrixExponentiation
import qualified Jikka.Core.Convert.PropagateMod as PropagateMod
Expand Down Expand Up @@ -56,6 +57,7 @@ run'' prog = do
prog <- CloseSum.run prog
prog <- CloseAll.run prog
prog <- CloseMin.run prog
prog <- KubaruToMorau.run prog
prog <- CumulativeSum.run prog
prog <- SegmentTree.run prog
prog <- BubbleLet.run prog
Expand Down
4 changes: 2 additions & 2 deletions src/Jikka/Core/Convert/CloseMin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ reduceMin = simpleRewriteRule $ \case
Lam x _ (Min2' _ e1 e2) -> Just $ Min2' t (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs)))
Lam x _ (Negate' e) -> Just $ Negate' (Max1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs)))
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs)))
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e2
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Cons' t (Minus' e0 e2) (Map' t1 t2 (Lam x IntTy e1) xs))) e2
_ -> Nothing
_ -> Nothing

Expand All @@ -80,7 +80,7 @@ reduceMax = simpleRewriteRule $ \case
Lam x _ (Max2' _ e1 e2) -> Just $ Max2' t (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs)))
Lam x _ (Negate' e) -> Just $ Negate' (Min1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs)))
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs)))
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e2
Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t (Cons' t (Minus' e0 e2) (Map' t1 t2 (Lam x IntTy e1) xs))) e2
_ -> Nothing
_ -> Nothing

Expand Down
124 changes: 76 additions & 48 deletions src/Jikka/Core/Convert/ConvexHullTrick.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

hoistMaybe :: Applicative m => Maybe a -> MaybeT m a
hoistMaybe = MaybeT . pure

-- | This is something commutative because only one kind of @c@ is allowed.
plusPair :: (ArithmeticalExpr, ArithmeticalExpr) -> (ArithmeticalExpr, ArithmeticalExpr) -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
plusPair (a1, c1) (a2, _) | isZeroArithmeticalExpr a2 = Just (a1, c1)
Expand Down Expand Up @@ -112,47 +109,48 @@ parseLinearFunctionBody' f i j e = result <$> go e
_ -> Nothing
_ -> Nothing

parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr))
parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
parseLinearFunctionBody f i k = runMaybeT . go
where
goMin e j step size = case unNPlusKPattern (parseArithmeticalExpr size) of
Just (i', k') | i' == i && k' == k -> do
(a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
-- raname @j@ to @i@
a <- lift $ substitute j (Var i) a
c <- lift $ substitute j (Var i) c
return (LitInt' 1, a, b, c, d, (`Minus'` d) <$> e)
_ -> hoistMaybe Nothing
goMax e j step size = do
(sign, a, b, c, d, e) <- goMin e j step size
return (Negate' sign, a, Negate' b, Negate' c, d, Negate' <$> e)
go = \case
Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> case unNPlusKPattern (parseArithmeticalExpr size) of
Just (i', k') | i' == i && k' == k -> do
(a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
-- raname @j@ to @i@
a <- lift $ substitute j (Var i) a
c <- lift $ substitute j (Var i) c
return (LitInt' 1, a, b, c, d)
_ -> hoistMaybe Nothing
Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> case unNPlusKPattern (parseArithmeticalExpr size) of
Just (i', k') | i' == i && k' == k -> do
(a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
-- raname @j@ to @i@
a <- lift $ substitute j (Var i) a
c <- lift $ substitute j (Var i) c
return (LitInt' (-1), a, Negate' b, Negate' c, d)
_ -> hoistMaybe Nothing
Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMin Nothing j step size
Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMax Nothing j step size
Min1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMin (Just e) j step size
Max1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMax (Just e) j step size
Min1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMin (Just e) j step size
Max1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMax (Just e) j step size
Negate' e -> do
(sign, a, b, c, d) <- go e
return (Negate' sign, a, b, c, Negate' d)
(sign, a, b, c, d, e) <- go e
return (Negate' sign, a, b, c, Negate' d, e)
Plus' e1 e2 | isConstantTimeExpr e2 -> do
(sign, a, b, c, d) <- go e1
return (sign, a, b, c, Plus' d e2)
(sign, a, b, c, d, e) <- go e1
return (sign, a, b, c, Plus' d e2, e)
Plus' e1 e2 | isConstantTimeExpr e1 -> do
(sign, a, b, c, d) <- go e2
return (sign, a, b, c, Plus' e1 d)
(sign, a, b, c, d, e) <- go e2
return (sign, a, b, c, Plus' e1 d, e)
Minus' e1 e2 | isConstantTimeExpr e2 -> do
(sign, a, b, c, d) <- go e1
return (sign, a, b, c, Minus' d e2)
(sign, a, b, c, d, e) <- go e1
return (sign, a, b, c, Minus' d e2, e)
Minus' e1 e2 | isConstantTimeExpr e1 -> do
(sign, a, b, c, d) <- go e2
return (Negate' sign, a, b, c, Minus' e1 d)
(sign, a, b, c, d, e) <- go e2
return (Negate' sign, a, b, c, Minus' e1 d, e)
Mult' e1 e2 | isConstantTimeExpr e2 -> do
(sign, a, b, c, d) <- go e1
return (Mult' sign e2, a, b, c, Mult' d e2)
(sign, a, b, c, d, e) <- go e1
return (Mult' sign e2, a, b, c, Mult' d e2, e)
Mult' e1 e2 | isConstantTimeExpr e1 -> do
(sign, a, b, c, d) <- go e2
return (Mult' e1 sign, a, b, c, Mult' e1 d)
(sign, a, b, c, d, e) <- go e2
return (Mult' e1 sign, a, b, c, Mult' e1 d, e)
_ -> hoistMaybe Nothing

getLength :: Expr -> Maybe Integer
Expand All @@ -166,27 +164,57 @@ rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = RewriteRule $ \_ -> \case
-- build (fun f -> step(f)) base n
Build' IntTy (Lam f _ step) base n -> runMaybeT $ do
let ts = [ConvexHullTrickTy, ListTy IntTy]
i <- lift genVarName'
k <- hoistMaybe $ getLength base
step <- replaceLenF f i k step
-- step(f) = sign(f) * min (map (fun j -> a(f, j) c(f) + b(f, j)) (range (i + k))) + d(f)
(sign, a, c, b, d) <- MaybeT $ parseLinearFunctionBody f i k step
x <- lift genVarName'
y <- lift genVarName'
f' <- lift $ genVarName f
let ts = [ConvexHullTrickTy, ListTy IntTy]
-- base' = (empty, base)
let base' = uncurryApp (Tuple' ts) [ConvexHullTrickInit', base]
-- step(f) = sign() * min (cons e(f, i) (map (fun j -> a(f, j) c(f, i) + b(f, j)) (range (i + k)))) + d(f, i)
(sign, a, c, b, d, e) <- MaybeT $ parseLinearFunctionBody f i k step
-- Update base when k = 0. If user's program has no bugs, it uses min(cons(x, xs)) when k = 0.
(base, n, k, c, d, e) <- case (e, k) of
(Just e, 0) -> do
e0 <- lift $ substitute i (LitInt' 0) e
d0 <- lift $ substitute i (LitInt' 0) d
let base' = Let f (ListTy IntTy) base $ Snoc' IntTy base (Plus' (Mult' sign e0) d0)
c <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) c
d <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) d
e <- lift $ substitute i (Plus' (Var i) (LitInt' 1)) e
return (base', Minus' n (LitInt' 1), k + 1, c, d, Just e)
_ -> return (base, n, k, c, d, e)
-- base' = (cht, base)
base' <- do
x <- lift genVarName'
f' <- lift $ genVarName f
i' <- lift $ genVarName i
a <- lift $ substitute f (Var f') a
b <- lift $ substitute f (Var f') b
a <- lift $ substitute i (Var i') a
b <- lift $ substitute i (Var i') b
-- cht for base[0], ..., base[k - 1]
let cht = Foldl' IntTy ConvexHullTrickTy (Lam2 x ConvexHullTrickTy i' IntTy (ConvexHullTrickInsert' (Var x) a b)) ConvexHullTrickInit' (Range1' (LitInt' k))
return $
Let f' (ListTy IntTy) base $
uncurryApp (Tuple' ts) [cht, Var f']
-- step' = fun (cht, f) i ->
-- let f' = setat f index(i) (min cht f[i + k] + c(i))
-- let f' = setat f index(i) value(..)
-- in let cht' = update cht a(i) b(i)
-- in (cht', f')
let step' =
Lam2 x (TupleTy ts) i IntTy $
Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $
step' <- do
x <- lift genVarName'
-- value(..) = (min e (min cht f[i + k] + c(i)))
let value = Plus' (Mult' sign (maybe id (\e -> Min2' IntTy e) e (ConvexHullTrickGetMin' (Proj' ts 0 (Var x)) c))) d
y <- lift genVarName'
f' <- lift $ genVarName f
a <- lift $ substitute f (Var f') a
b <- lift $ substitute f (Var f') b
a <- lift $ substitute i (Plus' (Var i) (LitInt' k)) a
b <- lift $ substitute i (Plus' (Var i) (LitInt' k)) b
return $
Lam2 x (TupleTy ts) i IntTy $
Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $
Let f' (ListTy IntTy) (Snoc' IntTy (Var f) value) $
Let y ConvexHullTrickTy (ConvexHullTrickInsert' (Proj' ts 0 (Var x)) a b) $
Let f' (ListTy IntTy) (Snoc' IntTy (Var f) (Plus' (Mult' sign (ConvexHullTrickGetMin' (Var y) c)) d)) $
uncurryApp (Tuple' ts) [Var y, Var f']
uncurryApp (Tuple' ts) [Var y, Var f']
-- proj 1 (foldl step' base' (range (n - 1)))
return $ Proj' ts 1 (Foldl' IntTy (TupleTy ts) step' base' (Range1' n))
_ -> return Nothing
Expand Down
Loading

0 comments on commit 4860fc7

Please sign in to comment.