Skip to content

Commit

Permalink
Merge pull request #46 from kmyk/cumulative-sum
Browse files Browse the repository at this point in the history
Use cumulative sums
  • Loading branch information
kmyk authored Jul 11, 2021
2 parents 84d8c4f + 55bc63a commit 33f5ed1
Show file tree
Hide file tree
Showing 22 changed files with 531 additions and 110 deletions.
4 changes: 2 additions & 2 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# [examples/](https://github.com/kmyk/Jikka/tree/master/examples)
# examples/

## Scripts

Expand All @@ -10,7 +10,7 @@
- [ ] `dp_a.py`
- [ ] `dp_b.py`
- [ ] `m_solutions2019_e.py`
- [ ] `static_range_sum.py`
- [x] `static_range_sum.py`

### Toy scripts

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

def main() -> None:
parser = argparse.ArgumentParser()
# parser.add_argument('-n', type=int, default=random.randint(2, 200000))
parser.add_argument('-n', type=int, default=random.randint(2, 1000))
args = parser.parse_args()

n = args.n
a = [random.randint(1, 200000) for _ in range(n)]
print(n)
print(len(n), *a, sep='\n')
print(len(a), *a)

if __name__ == '__main__':
main()
23 changes: 23 additions & 0 deletions examples/data/abc134_c.medium.solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
from typing import *

def solve(n: int, a: List[int]) -> List[int]:
l = [-1] * (n + 1)
r = [-1] * (n + 1)
for i in range(n):
l[i + 1] = max(l[i], a[i])
for i in reversed(range(n)):
r[i] = max(r[i + 1], a[i])
ans = [-1] * n
for i in range(n):
ans[i] = max(l[i], r[i + 1])
return ans

def main() -> None:
n = int(input())
_, *a = map(int, input().split())
ans = solve(n, a)
print(len(ans), *ans)

if __name__ == '__main__':
main()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
17 changes: 17 additions & 0 deletions examples/data/static_range_sum.large.generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python3
import random

def main() -> None:
n = random.randint(1, 5 * 10 ** 5)
q = random.randint(1, 5 * 10 ** 5)
a = [random.randint(1, 10 ** 9) for _ in range(n)]
l = [random.randint(0, n - 1) for _ in range(q)]
r = [random.randint(l[i] + 1, n) for i in range(q)]
print(n)
print(q)
print(len(a), *a)
print(len(l), *l)
print(len(r), *r)

if __name__ == '__main__':
main()
22 changes: 22 additions & 0 deletions examples/data/static_range_sum.large.solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import *

def solve(n: int, q: int, a: List[int], l: List[int], r: List[int]) -> List[int]:
b = [0] * (n + 1)
for i in range(n):
b[i + 1] = b[i] + a[i]
ans = [-1] * q
for j in range(q):
ans[j] = b[r[j]] - b[l[j]]
return ans

def main() -> None:
n = int(input())
q = int(input())
_, *a = map(int, input().split())
_, *l = map(int, input().split())
_, *r = map(int, input().split())
ans = solve(n, q, a, l, r)
print(len(ans), *ans, sep='\n')

if __name__ == '__main__':
main()
35 changes: 21 additions & 14 deletions scripts/integration_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import argparse
import concurrent.futures
import glob
import os
import pathlib
import subprocess
Expand All @@ -11,6 +12,9 @@

logger = getLogger(__name__)

CXX_ONLY = 'large'
NO_RPYTHON = 'medium'


def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[pathlib.Path]:
inputcases: List[pathlib.Path] = []
Expand All @@ -24,28 +28,29 @@ def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[
inputcases.append(path)

# using generators
generator_path = pathlib.Path('examples', 'data', script.stem + '.generator.py')
solver_path = pathlib.Path('examples', 'data', script.stem + '.solver.py')
if generator_path.exists():
for generator_path in pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py'):
_, testset_name, _, _ = generator_path.name.split('.')

solver_path = pathlib.Path('examples', 'data', script.stem + '.' + testset_name + '.solver.py')
if not solver_path.exists():
logger.error('%s: failed to find the solver', str(script))
return []

for i in range(10):
inputcase = tempdir / "{}.random-large-{}.in".format(script.stem, i)
outputcase = tempdir / "{}.random-large-{}.out".format(script.stem, i)
for i in range(20):
inputcase = tempdir / "{}.{}-{}.in".format(script.stem, testset_name, i)
outputcase = tempdir / "{}.{}-{}.out".format(script.stem, testset_name, i)
with open(inputcase, 'wb') as fh:
try:
subprocess.check_call([sys.executable, str(generator_path)], stdout=fh, timeout=2)
subprocess.check_call([sys.executable, str(generator_path)], stdout=fh, timeout=5)
except subprocess.SubprocessError as e:
logger.error('%s: failed to generate an input of a random case: %s', str(script), e)
logger.error('%s: %s: failed to generate an input of a random case: %s', str(script), str(inputcase), e)
return []
with open(inputcase, 'rb') as fh1:
with open(outputcase, 'wb') as fh2:
try:
subprocess.check_call([sys.executable, str(solver_path)], stdin=fh1, stdout=fh2, timeout=2)
subprocess.check_call([sys.executable, str(solver_path)], stdin=fh1, stdout=fh2, timeout=5)
except subprocess.SubprocessError as e:
logger.error('%s: failed to generate an output of a random case: %s', str(script), e)
logger.error('%s: %s: failed to generate an output of a random case: %s', str(script), str(inputcase), e)
return []
inputcases.append(inputcase)

Expand All @@ -59,12 +64,12 @@ def run_integration_test(script: pathlib.Path, *, executable: pathlib.Path) -> b
logger.info('%s: compiling...', str(script))
with open(tempdir / 'main.cpp', 'wb') as fh:
try:
subprocess.check_call([str(executable), 'convert', str(script)], stdout=fh, timeout=10)
subprocess.check_call([str(executable), 'convert', str(script)], stdout=fh, timeout=20)
except subprocess.SubprocessError as e:
logger.error('%s: failed to compile from Python to C++: %s', str(script), e)
return False
try:
subprocess.check_call(['g++', '-std=c++17', '-Wall', '-O2', '-I', str(pathlib.Path('runtime', 'include')), '-o', str(tempdir / 'a.exe'), str(tempdir / 'main.cpp')], timeout=10)
subprocess.check_call(['g++', '-std=c++17', '-Wall', '-O2', '-I', str(pathlib.Path('runtime', 'include')), '-o', str(tempdir / 'a.exe'), str(tempdir / 'main.cpp')], timeout=20)
except subprocess.SubprocessError as e:
logger.error('%s: failed to compile from C++ to executable: %s', str(script), e)
return False
Expand All @@ -83,15 +88,17 @@ def run_integration_test(script: pathlib.Path, *, executable: pathlib.Path) -> b
matrix.append(('core', [str(executable), 'execute', '--target', 'core', str(script)]))
matrix.append(('C++', [str(tempdir / 'a.exe')]))
for title, command in matrix:
if title == 'restricted Python' and 'large' in inputcase.name:
if title == 'restricted Python' and (NO_RPYTHON in inputcase.name or CXX_ONLY in inputcase.name):
continue
if title == 'core' and CXX_ONLY in inputcase.name:
continue
if 'wip' in inputcase.name:
continue

logger.info('%s: %s: running as %s...', str(script), str(inputcase), title)
with open(inputcase, 'rb') as fh:
try:
actual = subprocess.check_output(command, stdin=fh, timeout=10)
actual = subprocess.check_output(command, stdin=fh, timeout=20)
except subprocess.SubprocessError as e:
logger.error('%s: %s: failed to run as %s: %s', str(script), str(inputcase), title, e)
return False
Expand Down
90 changes: 49 additions & 41 deletions src/Jikka/CPlusPlus/Convert/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import qualified Jikka.CPlusPlus.Language.Util as Y
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Format as X (formatBuiltinIsolated, formatType)
import qualified Jikka.Core.Language.Beta as X
import qualified Jikka.Core.Language.BuiltinPatterns as X
import qualified Jikka.Core.Language.Expr as X
import qualified Jikka.Core.Language.TypeCheck as X
Expand Down Expand Up @@ -262,77 +261,86 @@ runAppBuiltin f args = wrapError' ("converting builtin " ++ X.formatBuiltinIsola
X.Permute -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::permute" []) [e1, e2]
X.MultiChoose -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::multichoose" []) [e1, e2]

runExpr :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m Y.Expr
runExpr :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m ([Y.Statement], Y.Expr)
runExpr env = \case
X.Var x -> Y.Var <$> lookupVarName env x
X.Lit lit -> runLiteral lit
X.Var x -> do
y <- lookupVarName env x
return ([], Y.Var y)
X.Lit lit -> do
lit <- runLiteral lit
return ([], lit)
X.If' _ e1 e2 e3 -> do
(stmts1, e1) <- runExpr env e1
(stmts2, e2) <- runExpr env e2
(stmts3, e3) <- runExpr env e3
case (stmts2, stmts3) of
([], []) ->
return (stmts1, Y.Cond e1 e2 e3)
_ -> do
phi <- newFreshName LocalNameKind ""
let assign = Y.Assign . Y.AssignExpr Y.SimpleAssign (Y.LeftVar phi)
return (stmts1 ++ [Y.If e1 (stmts2 ++ [assign e2]) (Just (stmts3 ++ [assign e3]))], Y.Var phi)
e@(X.App _ _) -> do
let (f, args) = X.curryApp e
args <- mapM (runExpr env) args
(e, args) <- case f of
case f of
X.Lit (X.LitBuiltin builtin) -> do
let arity = arityOfBuiltin builtin
if length args < arity
then do
e <- runExpr env e
let (ts, ret) = X.uncurryFunTy (X.builtinToType builtin)
ts <- mapM runType ts
ret <- runType ret
xs <- replicateM (arity - length args) (renameVarName LocalArgumentNameKind "_")
xs <- replicateM (arity - length args) (newFreshName LocalArgumentNameKind "")
e <- runAppBuiltin builtin (map snd args ++ map Y.Var xs)
let (_, e') = foldr (\(t, x) (ret, e) -> (Y.TyFunction ret [t], Y.Lam [(t, x)] ret [Y.Return e])) (ret, e) (zip (drop (length args) ts) xs)
return (e', [])
else do
e <- runAppBuiltin builtin (take arity args)
return (e, drop arity args)
e -> do
e <- runExpr env e
return (e, args)
return $ foldl (\e arg -> Y.Call (Y.Callable e) [arg]) e args
return (concatMap fst args, e')
else
if length args == arity
then do
e <- runAppBuiltin builtin (map snd args)
return (concatMap fst args, e)
else do
e <- runAppBuiltin builtin (take arity (map snd args))
return (concatMap fst args, Y.Call (Y.Callable e) (drop arity (map snd args)))
_ -> do
(stmts, f) <- runExpr env f
return (stmts ++ concatMap fst args, Y.Call (Y.Callable f) (map snd args))
e@(X.Lam _ _ _) -> do
let (args, body) = X.uncurryLam e
ys <- mapM (renameVarName LocalArgumentNameKind . fst) args
let env' = reverse (zipWith (\(x, t) y -> (x, t, y)) args ys) ++ env
ret <- runType =<< typecheckExpr env' body
body <- runExprToStatements env' body
(stmts, body) <- runExpr env' body
ts <- mapM (runType . snd) args
let (_, [Y.Return e]) = foldr (\(t, y) (ret, body) -> (Y.TyFunction ret [t], [Y.Return (Y.Lam [(t, y)] ret body)])) (ret, body) (zip ts ys)
return e
X.Let x _ e1 e2 -> runExpr env =<< X.substitute x e1 e2

runExprToStatements :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m [Y.Statement]
runExprToStatements env = \case
let (_, [Y.Return e]) = foldr (\(t, y) (ret, body) -> (Y.TyFunction ret [t], [Y.Return (Y.Lam [(t, y)] ret body)])) (ret, stmts ++ [Y.Return body]) (zip ts ys)
return ([], e)
X.Let x t e1 e2 -> do
y <- renameVarName LocalNameKind x
t' <- runType t
e1 <- runExpr env e1
e2 <- runExprToStatements ((x, t, y) : env) e2
return $ Y.Declare t' y (Just e1) : e2
X.If' _ e1 e2 e3 -> do
e1 <- runExpr env e1
e2 <- runExprToStatements env e2
e3 <- runExprToStatements env e3
return [Y.If e1 e2 (Just e3)]
e -> do
e <- runExpr env e
return [Y.Return e]
(stmts1, e1) <- runExpr env e1
(stmts2, e2) <- runExpr ((x, t, y) : env) e2
return (stmts1 ++ Y.Declare t' y (Just e1) : stmts2, e2)

runToplevelFunDef :: (MonadAlpha m, MonadError Error m) => Env -> Y.VarName -> [(X.VarName, X.Type)] -> X.Type -> X.Expr -> m [Y.ToplevelStatement]
runToplevelFunDef env f args ret body = do
ret <- runType ret
args <- forM args $ \(x, t) -> do
y <- renameVarName ArgumentNameKind x
return (x, t, y)
body <- runExprToStatements (reverse args ++ env) body
(stmts, result) <- runExpr (reverse args ++ env) body
args <- forM args $ \(_, t, y) -> do
t <- runType t
return (t, y)
return [Y.FunDef ret f args body]
return [Y.FunDef ret f args (stmts ++ [Y.Return result])]

runToplevelVarDef :: (MonadAlpha m, MonadError Error m) => Env -> Y.VarName -> X.Type -> X.Expr -> m [Y.ToplevelStatement]
runToplevelVarDef env x t e = do
t <- runType t
e <- runExpr env e
return [Y.VarDef t x e]
(stmts, e) <- runExpr env e
case stmts of
[] -> return [Y.VarDef t x e]
_ -> return [Y.VarDef t x (Y.Call (Y.Callable (Y.Lam [] t (stmts ++ [Y.Return e]))) [])]

runMainRead :: (MonadAlpha m, MonadError Error m) => Y.VarName -> Y.Type -> m [Y.Statement]
runMainRead x t = do
Expand Down Expand Up @@ -431,8 +439,8 @@ runToplevelExpr env = \case
args <- forM args $ \(x, t) -> do
y <- renameVarName ArgumentNameKind x
return (x, t, y)
e <- runExpr (reverse args ++ env) body
let body = [Y.Return e]
(stmts, e) <- runExpr (reverse args ++ env) body
let body = stmts ++ [Y.Return e]
args' <- forM args $ \(_, t, y) -> do
t <- runType t
return (t, y)
Expand All @@ -442,8 +450,8 @@ runToplevelExpr env = \case
t <- runType t
y <- newFreshName ArgumentNameKind ""
return (t, y)
e <- runExpr env e
let body = [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))]
(stmts, e) <- runExpr env e
let body = stmts ++ [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))]
return (args, body)
ret <- runType ret
let solve = [Y.FunDef ret f args body]
Expand Down
8 changes: 7 additions & 1 deletion src/Jikka/Core/Convert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import qualified Jikka.Core.Convert.Beta as Beta
import qualified Jikka.Core.Convert.BubbleLet as BubbleLet
import qualified Jikka.Core.Convert.CloseAll as CloseAll
import qualified Jikka.Core.Convert.CloseMin as CloseMin
import qualified Jikka.Core.Convert.CloseSum as CloseSum
import qualified Jikka.Core.Convert.ConstantFolding as ConstantFolding
import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation
import qualified Jikka.Core.Convert.CumulativeSum as CumulativeSum
import qualified Jikka.Core.Convert.Eta as Eta
import qualified Jikka.Core.Convert.MakeScanl as MakeScanl
import qualified Jikka.Core.Convert.MatrixExponentiation as MatrixExponentiation
import qualified Jikka.Core.Convert.PropagateMod as PropagateMod
Expand Down Expand Up @@ -54,7 +57,10 @@ run' prog = do
prog <- CloseSum.run prog
prog <- CloseAll.run prog
prog <- CloseMin.run prog
StrengthReduction.run prog
prog <- CumulativeSum.run prog
prog <- BubbleLet.run prog
prog <- StrengthReduction.run prog
Eta.run prog

run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog =
Expand Down
Loading

0 comments on commit 33f5ed1

Please sign in to comment.