From d919c0d52c1415ef22a1a9858a77892cdd6c8716 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sat, 10 Jul 2021 07:02:30 +0900 Subject: [PATCH 1/9] feat(core): Update Jikka.Core.Convert.Close{All,Min,Sum} --- src/Jikka/Core/Convert/CloseAll.hs | 73 ++++++++++++++++---- src/Jikka/Core/Convert/CloseMin.hs | 105 ++++++++++++++++++++++------- src/Jikka/Core/Convert/CloseSum.hs | 1 - 3 files changed, 142 insertions(+), 37 deletions(-) diff --git a/src/Jikka/Core/Convert/CloseAll.hs b/src/Jikka/Core/Convert/CloseAll.hs index cb7ee00f..dbb7c2c1 100644 --- a/src/Jikka/Core/Convert/CloseAll.hs +++ b/src/Jikka/Core/Convert/CloseAll.hs @@ -14,6 +14,8 @@ module Jikka.Core.Convert.CloseAll -- * internal rules rule, + reduceAll, + reduceAny, ) where @@ -23,17 +25,66 @@ import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Jikka.Core.Language.Lint import Jikka.Core.Language.RewriteRules +import Jikka.Core.Language.Util -rule :: Monad m => RewriteRule m -rule = simpleRewriteRule $ \case - -- reduce `Reversed` - All' (Reversed' _ xs) -> Just $ All' xs - Any' (Reversed' _ xs) -> Just $ Any' xs - -- reduce `Sorted` - All' (Sorted' _ xs) -> Just $ All' xs - Any' (Sorted' _ xs) -> Just $ Any' xs - -- reduce `Map` - _ -> Nothing +reduceAll :: MonadAlpha m => RewriteRule m +reduceAll = + let return' = return . Just + in RewriteRule $ \_ -> \case + -- list build functions + All' (Nil' _) -> return' LitTrue + All' (Cons' _ x xs) -> return' $ And' x (All' xs) + -- list map functions + All' (Reversed' _ xs) -> return' $ All' xs + All' (Sorted' _ xs) -> return' $ All' xs + All' (Filter' _ f xs) -> do + x <- genVarName' + return' $ All' (Map' BoolTy BoolTy (Lam x BoolTy (Implies' (App f (Var x)) (Var x))) xs) + All' (Map' _ _ f xs) -> case f of + Lam x _ (Not' e) -> do + return' $ Not' (Any' (Map' BoolTy BoolTy (Lam x BoolTy e) xs)) + Lam x _ (And' e1 e2) -> do + x1 <- genVarName x + x2 <- genVarName x + return' $ And' (All' (Map' BoolTy BoolTy (Lam x1 BoolTy e1) xs)) (All' (Map' BoolTy BoolTy (Lam x2 BoolTy e2) xs)) + _ -> return Nothing + -- others + _ -> return Nothing + +reduceAny :: MonadAlpha m => RewriteRule m +reduceAny = + let return' = return . Just + in RewriteRule $ \_ -> \case + -- list build functions + Any' (Nil' _) -> return' LitFalse + Any' (Cons' _ x xs) -> return' $ Or' x (Any' xs) + -- list map functions + Any' (Reversed' _ xs) -> return' $ Any' xs + Any' (Sorted' _ xs) -> return' $ Any' xs + Any' (Filter' _ f xs) -> do + x <- genVarName' + return' $ Any' (Map' BoolTy BoolTy (Lam x BoolTy (And' (App f (Var x)) (Var x))) xs) + Any' (Map' _ _ f xs) -> case f of + Lam x _ (Not' e) -> do + return' $ Not' (All' (Map' BoolTy BoolTy (Lam x BoolTy e) xs)) + Lam x _ (Or' e1 e2) -> do + x1 <- genVarName x + x2 <- genVarName x + return' $ Or' (Any' (Map' BoolTy BoolTy (Lam x1 BoolTy e1) xs)) (Any' (Map' BoolTy BoolTy (Lam x2 BoolTy e2) xs)) + Lam x _ (Implies' e1 e2) -> do + x1 <- genVarName x + x2 <- genVarName x + return' $ Or' (Any' (Map' BoolTy BoolTy (Lam x1 BoolTy (Negate' e1)) xs)) (Any' (Map' BoolTy BoolTy (Lam x2 BoolTy e2) xs)) + _ -> return Nothing + -- others + _ -> return Nothing + +rule :: MonadAlpha m => RewriteRule m +rule = + mconcat + [ reduceAll, + reduceAny + ] runProgram :: MonadAlpha m => Program -> m Program runProgram = applyRewriteRuleProgram' rule @@ -64,11 +115,9 @@ runProgram = applyRewriteRuleProgram' rule -- -- * `Nil` \(: \forall \alpha. \list(\alpha)\) -- * `Cons` \(: \forall \alpha. \alpha \to \list(\alpha) \to \list(\alpha)\) --- * `Range1` \(: \int \to \list(\int)\) -- -- === List Map functions -- --- * `Scanl` \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \list(\beta)\) -- * `Map` \(: \forall \alpha \beta. (\alpha \to \beta) \to \list(\alpha) \to \list(\beta)\) -- * `Filter` \(: \forall \alpha \beta. (\alpha \to \bool) \to \list(\alpha) \to \list(\beta)\) -- * `Reversed` \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\) diff --git a/src/Jikka/Core/Convert/CloseMin.hs b/src/Jikka/Core/Convert/CloseMin.hs index 12ed9ac0..ccd1b05d 100644 --- a/src/Jikka/Core/Convert/CloseMin.hs +++ b/src/Jikka/Core/Convert/CloseMin.hs @@ -14,6 +14,10 @@ module Jikka.Core.Convert.CloseMin -- * internal rules rule, + reduceMin, + reduceMax, + reduceArgMin, + reduceArgMax, ) where @@ -25,35 +29,89 @@ import Jikka.Core.Language.FreeVars import Jikka.Core.Language.Lint import Jikka.Core.Language.RewriteRules -rule :: Monad m => RewriteRule m -rule = simpleRewriteRule $ \case - -- reduce `Reversed` - Max1' t (Reversed' _ xs) -> Just $ Max1' t xs - Min1' t (Reversed' _ xs) -> Just $ Min1' t xs - ArgMax' t (Reversed' _ xs) -> Just $ Minus' (Minus' (Len' t xs) (ArgMax' t xs)) Lit1 +reduceMin :: Monad m => RewriteRule m +reduceMin = simpleRewriteRule $ \case + -- list build functions + Min1' t (Nil' _) -> Just $ Bottom' t "no minimum in empty list" + Min1' _ (Cons' _ e (Nil' _)) -> Just e + Min1' t (Cons' _ e (Cons' _ e' es)) -> Just $ Min2' t e (Min1' t (Cons' t e' es)) + -- list map functions + Min1' t (Reversed' _ es) -> Just $ Min1' t es + Min1' t (Cons' _ e (Reversed' _ es)) -> Just $ Min1' t (Cons' t e es) + Min1' t (Sorted' _ es) -> Just $ Min1' t es + Min1' t (Cons' _ e (Sorted' _ es)) -> Just $ Min1' t (Cons' t e es) + Min1' t (Map' t1 t2 f es) -> case f of + Lam x _ e | x `isUnusedVar` e -> Just e + Lam x _ (Min2' _ e1 e2) -> Just $ Min2' t (Min1' t (Map' t1 t2 (Lam x t e1) es)) (Min1' t (Map' t1 t2 (Lam x t e2) es)) + Lam x _ (Negate' e) -> Just $ Negate' (Max1' t (Map' t1 t2 (Lam x IntTy e) es)) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Min1' t (Map' t1 t2 (Lam x IntTy e2) es)) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Map' t1 t2 (Lam x IntTy e1) es)) e1 + _ -> Nothing + Min1' t (Cons' _ e0 (Map' t1 t2 f xs)) -> case f of + Lam x _ e | x `isUnusedVar` e -> Just $ If' t (Equal' IntTy (Len' t xs) Lit0) e0 (Min2' t e0 e) + Lam x _ (Min2' _ e1 e2) -> Just $ Min2' t (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Min1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs))) + Lam x _ (Negate' e) -> Just $ Negate' (Max1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs))) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs))) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e1 + _ -> Nothing + _ -> Nothing + +reduceMax :: Monad m => RewriteRule m +reduceMax = simpleRewriteRule $ \case + -- list build functions + Max1' t (Nil' _) -> Just $ Bottom' t "no maximum in empty list" + Max1' _ (Cons' _ e (Nil' _)) -> Just e + Max1' t (Cons' _ e (Cons' _ e' es)) -> Just $ Max2' t e (Max1' t (Cons' t e' es)) + -- list map functions + Max1' t (Reversed' _ es) -> Just $ Max1' t es + Max1' t (Cons' _ e (Reversed' _ es)) -> Just $ Max1' t (Cons' t e es) + Max1' t (Sorted' _ es) -> Just $ Max1' t es + Max1' t (Cons' _ e (Sorted' _ es)) -> Just $ Max1' t (Cons' t e es) + Max1' t (Map' t1 t2 f es) -> case f of + Lam x _ e | x `isUnusedVar` e -> Just e + Lam x _ (Max2' _ e1 e2) -> Just $ Max2' t (Map' t1 t2 (Lam x t e1) es) (Map' t1 t2 (Lam x t e2) es) + Lam x _ (Negate' e) -> Just $ Negate' (Min1' t2 (Map' t1 t2 (Lam x IntTy e) es)) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Max1' t2 (Map' t1 t2 (Lam x IntTy e2) es)) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t2 (Map' t1 t2 (Lam x IntTy e1) es)) e1 + _ -> Nothing + Max1' t (Cons' _ e0 (Map' t1 t2 f xs)) -> case f of + Lam x _ e | x `isUnusedVar` e -> Just $ If' t (Equal' IntTy (Len' t xs) Lit0) e0 (Max2' t e0 e) + Lam x _ (Max2' _ e1 e2) -> Just $ Max2' t (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e1) xs))) (Max1' t (Cons' t e0 (Map' t1 t2 (Lam x t e2) xs))) + Lam x _ (Negate' e) -> Just $ Negate' (Min1' t (Cons' t (Negate' e0) (Map' t1 t2 (Lam x IntTy e) xs))) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e2) xs))) + Lam x _ (Plus' e1 e2) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t (Cons' t (Minus' e0 e1) (Map' t1 t2 (Lam x IntTy e1) xs))) e1 + _ -> Nothing + _ -> Nothing + +-- | TODO: implement this +reduceArgMin :: Monad m => RewriteRule m +reduceArgMin = simpleRewriteRule $ \case + -- list map functions ArgMin' t (Reversed' _ xs) -> Just $ Minus' (Minus' (Len' t xs) (ArgMin' t xs)) Lit1 - -- reduce `Sorted` - Max1' t (Sorted' _ xs) -> Just $ Max1' t xs - Min1' t (Sorted' _ xs) -> Just $ Min1' t xs - -- reduce `Map` - Max1' _ (Map' _ _ (Lam x _ e) _) | x `isUnusedVar` e -> Just e - Max1' _ (Map' t1 t2 (Lam x t (Max2' t' e1 e2)) xs) -> Just $ Max2' t' (Map' t1 t2 (Lam x t e1) xs) (Map' t1 t2 (Lam x t e2) xs) - Max1' _ (Map' t1 t2 (Lam x t (Negate' e)) xs) -> Just $ Negate' (Min1' t2 (Map' t1 t2 (Lam x t e) xs)) - Max1' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Max1' t2 (Map' t1 t2 (Lam x t e2) xs)) - Max1' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e2 -> Just $ Plus' (Max1' t2 (Map' t1 t2 (Lam x t e1) xs)) e1 - Min1' _ (Map' _ _ (Lam x _ e) _) | x `isUnusedVar` e -> Just e - Min1' _ (Map' t1 t2 (Lam x t (Min2' t' e1 e2)) xs) -> Just $ Min2' t' (Map' t1 t2 (Lam x t e1) xs) (Map' t1 t2 (Lam x t e2) xs) - Min1' _ (Map' t1 t2 (Lam x t (Negate' e)) xs) -> Just $ Negate' (Max1' t2 (Map' t1 t2 (Lam x t e) xs)) - Min1' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e1 -> Just $ Plus' e1 (Min1' t2 (Map' t1 t2 (Lam x t e2) xs)) - Min1' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e2 -> Just $ Plus' (Min1' t2 (Map' t1 t2 (Lam x t e1) xs)) e1 - ArgMax' _ (Map' _ _ (Lam x t e) xs) | x `isUnusedVar` e -> Just $ Minus' (Len' t xs) Lit1 - ArgMax' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e1 -> Just $ ArgMax' t2 (Map' t1 t2 (Lam x t e2) xs) - ArgMax' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e2 -> Just $ ArgMax' t2 (Map' t1 t2 (Lam x t e1) xs) ArgMin' _ (Map' _ _ (Lam x _ e) _) | x `isUnusedVar` e -> Just Lit0 ArgMin' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e1 -> Just $ ArgMin' t2 (Map' t1 t2 (Lam x t e2) xs) ArgMin' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e2 -> Just $ ArgMin' t2 (Map' t1 t2 (Lam x t e1) xs) _ -> Nothing +-- | TODO: implement this +reduceArgMax :: Monad m => RewriteRule m +reduceArgMax = simpleRewriteRule $ \case + -- list map functions + ArgMax' t (Reversed' _ xs) -> Just $ Minus' (Minus' (Len' t xs) (ArgMax' t xs)) Lit1 + ArgMax' _ (Map' _ _ (Lam x t e) xs) | x `isUnusedVar` e -> Just $ Minus' (Len' t xs) Lit1 + ArgMax' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e1 -> Just $ ArgMax' t2 (Map' t1 t2 (Lam x t e2) xs) + ArgMax' _ (Map' t1 t2 (Lam x t (Plus' e1 e2)) xs) | x `isUnusedVar` e2 -> Just $ ArgMax' t2 (Map' t1 t2 (Lam x t e1) xs) + _ -> Nothing + +rule :: Monad m => RewriteRule m +rule = + mconcat + [ reduceMin, + reduceMax, + reduceArgMin, + reduceArgMax + ] + runProgram :: MonadAlpha m => Program -> m Program runProgram = applyRewriteRuleProgram' rule @@ -87,7 +145,6 @@ runProgram = applyRewriteRuleProgram' rule -- -- === List Map functions -- --- * `Scanl` \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \list(\beta)\) -- * `Map` \(: \forall \alpha \beta. (\alpha \to \beta) \to \list(\alpha) \to \list(\beta)\) -- * `Filter` \(: \forall \alpha \beta. (\alpha \to \bool) \to \list(\alpha) \to \list(\beta)\) -- * `Reversed` \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\) diff --git a/src/Jikka/Core/Convert/CloseSum.hs b/src/Jikka/Core/Convert/CloseSum.hs index 41419176..027714ae 100644 --- a/src/Jikka/Core/Convert/CloseSum.hs +++ b/src/Jikka/Core/Convert/CloseSum.hs @@ -188,7 +188,6 @@ runProgram = applyRewriteRuleProgram' rule -- -- === List Map functions -- --- * `Scanl` \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \list(\beta)\) -- * `Map` \(: \forall \alpha \beta. (\alpha \to \beta) \to \list(\alpha) \to \list(\beta)\) -- * `Filter` \(: \forall \alpha \beta. (\alpha \to \bool) \to \list(\alpha) \to \list(\beta)\) -- * `Reversed` \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\) From 36f65c0dbb4eaac94282fee78e0409d4a51c24fa Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 17:58:32 +0900 Subject: [PATCH 2/9] fix(core): Fix isOneArithmeticalExpr --- src/Jikka/Core/Language/ArithmeticalExpr.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Jikka/Core/Language/ArithmeticalExpr.hs b/src/Jikka/Core/Language/ArithmeticalExpr.hs index 17ad8b28..9fd4e22a 100644 --- a/src/Jikka/Core/Language/ArithmeticalExpr.hs +++ b/src/Jikka/Core/Language/ArithmeticalExpr.hs @@ -187,7 +187,7 @@ isZeroArithmeticalExpr :: ArithmeticalExpr -> Bool isZeroArithmeticalExpr e = normalizeArithmeticalExpr e == sumExprFromInteger 0 isOneArithmeticalExpr :: ArithmeticalExpr -> Bool -isOneArithmeticalExpr e = normalizeArithmeticalExpr e == sumExprFromInteger 0 +isOneArithmeticalExpr e = normalizeArithmeticalExpr e == sumExprFromInteger 1 unNPlusKPattern :: ArithmeticalExpr -> Maybe (VarName, Integer) unNPlusKPattern e = case normalizeArithmeticalExpr e of From d33323cd4e80ea8f879b955d1e939c0aa30c7095 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 18:34:09 +0900 Subject: [PATCH 3/9] feat(core): Add a module for eta-reduction --- src/Jikka/Core/Convert.hs | 4 +- src/Jikka/Core/Convert/Eta.hs | 72 ++++++++++++++++++++++++++++++ test/Jikka/Core/Convert/EtaSpec.hs | 36 +++++++++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 src/Jikka/Core/Convert/Eta.hs create mode 100644 test/Jikka/Core/Convert/EtaSpec.hs diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs index e6082df7..1c8bd8ca 100644 --- a/src/Jikka/Core/Convert.hs +++ b/src/Jikka/Core/Convert.hs @@ -24,6 +24,7 @@ import qualified Jikka.Core.Convert.CloseMin as CloseMin import qualified Jikka.Core.Convert.CloseSum as CloseSum import qualified Jikka.Core.Convert.ConstantFolding as ConstantFolding import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation +import qualified Jikka.Core.Convert.Eta as Eta import qualified Jikka.Core.Convert.MakeScanl as MakeScanl import qualified Jikka.Core.Convert.MatrixExponentiation as MatrixExponentiation import qualified Jikka.Core.Convert.PropagateMod as PropagateMod @@ -54,7 +55,8 @@ run' prog = do prog <- CloseSum.run prog prog <- CloseAll.run prog prog <- CloseMin.run prog - StrengthReduction.run prog + prog <- StrengthReduction.run prog + Eta.run prog run :: (MonadAlpha m, MonadError Error m) => Program -> m Program run prog = diff --git a/src/Jikka/Core/Convert/Eta.hs b/src/Jikka/Core/Convert/Eta.hs new file mode 100644 index 00000000..a91ae3ba --- /dev/null +++ b/src/Jikka/Core/Convert/Eta.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module : Jikka.Core.Convert.Eta +-- Description : does eta-reductions and makes exprs pointful. / eta 簡約を行って式を pointful にします。 +-- Copyright : (c) Kimiyuki Onaka, 2021 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Convert.Eta + ( run, + + -- * internal rules + rule, + ) +where + +import Data.Maybe +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.RewriteRules +import Jikka.Core.Language.Util + +expandExpr :: MonadAlpha m => Type -> Expr -> m (Maybe Expr) +expandExpr t e = case (t, e) of + (FunTy t1 t2, Lam x _ body) -> do + body <- expandExpr t2 body + return $ Lam x t1 <$> body + (FunTy t1 t2, e) -> do + x <- genVarName' + let e' = App e (Var x) + e'' <- expandExpr t2 e' + return . Just $ Lam x t1 (fromMaybe e' e'') + _ -> return Nothing + +rule :: MonadAlpha m => RewriteRule m +rule = + let go :: MonadAlpha m => Expr -> Type -> (Expr -> Expr) -> m (Maybe Expr) + go e t f = (f <$>) <$> expandExpr t e + in RewriteRule $ \_ -> \case + Let x t e1 e2 -> go e1 t (\e1 -> Let x t e1 e2) + Iterate' t k f x -> go f (FunTy t t) (\f -> Iterate' t k f x) + Foldl' t1 t2 f init xs -> go f (FunTy t2 (FunTy t1 t1)) (\f -> Foldl' t1 t2 f init xs) + Scanl' t1 t2 f init xs -> go f (FunTy t2 (FunTy t1 t1)) (\f -> Scanl' t1 t2 f init xs) + Map' t1 t2 f xs -> go f (FunTy t1 t2) (\f -> Map' t1 t2 f xs) + Filter' t f xs -> go f (FunTy t BoolTy) (\f -> Filter' t f xs) + _ -> return Nothing + +runProgram :: MonadAlpha m => Program -> m Program +runProgram = applyRewriteRuleProgram' rule + +-- `run` does eta-reductions in some locations. +-- This aims to: + +-- * simplify other rewrite-rules + +-- * convert to C++ + +-- TODO: expand in toplevel-let too. +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.Eta" $ do + precondition $ do + ensureWellTyped prog + prog <- runProgram prog + postcondition $ do + ensureWellTyped prog + return prog diff --git a/test/Jikka/Core/Convert/EtaSpec.hs b/test/Jikka/Core/Convert/EtaSpec.hs new file mode 100644 index 00000000..a2e0b485 --- /dev/null +++ b/test/Jikka/Core/Convert/EtaSpec.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.EtaSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.Eta (run) +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + ResultExpr + ( Let + "plus" + (FunTy IntTy (FunTy IntTy IntTy)) + (Lit (LitBuiltin Plus)) + (Var "plus") + ) + let expected = + ResultExpr + ( Let + "plus" + (FunTy IntTy (FunTy IntTy IntTy)) + (Lam "$0" IntTy (Lam "$1" IntTy (App2 (Lit (LitBuiltin Plus)) (Var "$0") (Var "$1")))) + (Var "plus") + ) + run' prog `shouldBe` Right expected From 04778a025d0af11dc77c7bf965a260b89be107ca Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 18:35:06 +0900 Subject: [PATCH 4/9] feat(core): Add a module to use cumulative sums --- src/Jikka/Core/Convert.hs | 2 + src/Jikka/Core/Convert/CumulativeSum.hs | 56 +++++++++++++++++++++ src/Jikka/Core/Language/ArithmeticalExpr.hs | 7 +++ 3 files changed, 65 insertions(+) create mode 100644 src/Jikka/Core/Convert/CumulativeSum.hs diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs index 1c8bd8ca..038372fe 100644 --- a/src/Jikka/Core/Convert.hs +++ b/src/Jikka/Core/Convert.hs @@ -24,6 +24,7 @@ import qualified Jikka.Core.Convert.CloseMin as CloseMin import qualified Jikka.Core.Convert.CloseSum as CloseSum import qualified Jikka.Core.Convert.ConstantFolding as ConstantFolding import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation +import qualified Jikka.Core.Convert.CumulativeSum as CumulativeSum import qualified Jikka.Core.Convert.Eta as Eta import qualified Jikka.Core.Convert.MakeScanl as MakeScanl import qualified Jikka.Core.Convert.MatrixExponentiation as MatrixExponentiation @@ -55,6 +56,7 @@ run' prog = do prog <- CloseSum.run prog prog <- CloseAll.run prog prog <- CloseMin.run prog + prog <- CumulativeSum.run prog prog <- StrengthReduction.run prog Eta.run prog diff --git a/src/Jikka/Core/Convert/CumulativeSum.hs b/src/Jikka/Core/Convert/CumulativeSum.hs new file mode 100644 index 00000000..a0b6a4f4 --- /dev/null +++ b/src/Jikka/Core/Convert/CumulativeSum.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module : Jikka.Core.Convert.CumulativeSum +-- Description : processes queries like range sum query using cumulative sums. / 累積和を用いて range sum query のようなクエリを処理します。 +-- Copyright : (c) Kimiyuki Onaka, 2021 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Convert.CumulativeSum + ( run, + + -- * internal rules + rule, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import qualified Jikka.Core.Convert.Alpha as Alpha +import Jikka.Core.Language.ArithmeticalExpr +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.RewriteRules +import Jikka.Core.Language.Util + +rule :: MonadAlpha m => RewriteRule m +rule = RewriteRule $ \_ -> \case + Sum' (Map' _ _ (Lam x _ (At' _ a x')) (Range1' n)) -> do + case makeAffineFunctionFromArithmeticalExpr x (parseArithmeticalExpr x') of + Just (coeff, shift) | isOneArithmeticalExpr coeff -> do + y <- genVarName' + let e = + if isZeroArithmeticalExpr shift + then At' IntTy (Var y) n + else Minus' (At' IntTy (Var y) (Plus' n (formatArithmeticalExpr shift))) (At' IntTy (Var y) (formatArithmeticalExpr shift)) + return . Just $ + Let y (ListTy IntTy) (Scanl' IntTy IntTy (Lit (LitBuiltin Plus)) Lit0 a) e + _ -> return Nothing + _ -> return Nothing + +runProgram :: MonadAlpha m => Program -> m Program +runProgram = applyRewriteRuleProgram' rule + +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.CumulativeSum" $ do + precondition $ do + ensureWellTyped prog + prog <- Alpha.run prog + prog <- runProgram prog + postcondition $ do + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Language/ArithmeticalExpr.hs b/src/Jikka/Core/Language/ArithmeticalExpr.hs index 9fd4e22a..dcb6d216 100644 --- a/src/Jikka/Core/Language/ArithmeticalExpr.hs +++ b/src/Jikka/Core/Language/ArithmeticalExpr.hs @@ -4,6 +4,7 @@ module Jikka.Core.Language.ArithmeticalExpr where +import Control.Arrow import Control.Monad import Control.Monad.ST import Control.Monad.Trans @@ -189,6 +190,7 @@ isZeroArithmeticalExpr e = normalizeArithmeticalExpr e == sumExprFromInteger 0 isOneArithmeticalExpr :: ArithmeticalExpr -> Bool isOneArithmeticalExpr e = normalizeArithmeticalExpr e == sumExprFromInteger 1 +-- | `unNPlusKPattern` recognizes a pattern of \(x + k\) for a variable \(x\) and an integer constant \(k \in \mathbb{Z}\). unNPlusKPattern :: ArithmeticalExpr -> Maybe (VarName, Integer) unNPlusKPattern e = case normalizeArithmeticalExpr e of SumExpr @@ -201,3 +203,8 @@ unNPlusKPattern e = case normalizeArithmeticalExpr e of sumExprConst = k } -> Just (x, k) _ -> Nothing + +-- | `makeAffineFunctionFromArithmeticalExpr` is a specialized version of `makeVectorFromArithmeticalExpr`. +-- This function returns \(a, b\) for a given variable \(x\) and a given expr \(e = a x + b\) where \(a, b\) which doesn't use \(x\) free. +makeAffineFunctionFromArithmeticalExpr :: VarName -> ArithmeticalExpr -> Maybe (ArithmeticalExpr, ArithmeticalExpr) +makeAffineFunctionFromArithmeticalExpr x es = first V.head <$> makeVectorFromArithmeticalExpr (V.singleton x) es From 5ae541018129b8cb1fddcb337f66d73afb0cc760 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 18:49:13 +0900 Subject: [PATCH 5/9] feat(core): Create Jikka.Core.Convert.BubbleLet --- src/Jikka/Core/Convert.hs | 2 ++ src/Jikka/Core/Convert/BubbleLet.hs | 51 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 src/Jikka/Core/Convert/BubbleLet.hs diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs index 038372fe..47d9c773 100644 --- a/src/Jikka/Core/Convert.hs +++ b/src/Jikka/Core/Convert.hs @@ -19,6 +19,7 @@ import Jikka.Common.Alpha import Jikka.Common.Error import qualified Jikka.Core.Convert.Alpha as Alpha import qualified Jikka.Core.Convert.Beta as Beta +import qualified Jikka.Core.Convert.BubbleLet as BubbleLet import qualified Jikka.Core.Convert.CloseAll as CloseAll import qualified Jikka.Core.Convert.CloseMin as CloseMin import qualified Jikka.Core.Convert.CloseSum as CloseSum @@ -57,6 +58,7 @@ run' prog = do prog <- CloseAll.run prog prog <- CloseMin.run prog prog <- CumulativeSum.run prog + prog <- BubbleLet.run prog prog <- StrengthReduction.run prog Eta.run prog diff --git a/src/Jikka/Core/Convert/BubbleLet.hs b/src/Jikka/Core/Convert/BubbleLet.hs new file mode 100644 index 00000000..e9678c10 --- /dev/null +++ b/src/Jikka/Core/Convert/BubbleLet.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module : Jikka.Core.Convert.BubbleLet +-- Description : bubbles let-exprs in higher-order functions. / 高階関数中の let 式を浮き上がらせます。 +-- Copyright : (c) Kimiyuki Onaka, 2021 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Convert.BubbleLet + ( run, + + -- * internal rules + rule, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.FreeVars +import Jikka.Core.Language.Lint +import Jikka.Core.Language.RewriteRules + +rule :: MonadAlpha m => RewriteRule m +rule = + let go f cont = case f of + Lam x t (Let y t' e body) | x `isUnusedVar` e -> return . Just $ Let y t' e (cont (Lam x t body)) + _ -> return Nothing + in RewriteRule $ \_ -> \case + Iterate' t k f x -> go f (\f -> Iterate' t k f x) + Foldl' t1 t2 f init xs -> go f (\f -> Foldl' t1 t2 f init xs) + Scanl' t1 t2 f init xs -> go f (\f -> Scanl' t1 t2 f init xs) + Map' t1 t2 f xs -> go f (\f -> Map' t1 t2 f xs) + Filter' t f xs -> go f (\f -> Filter' t f xs) + _ -> return Nothing + +runProgram :: MonadAlpha m => Program -> m Program +runProgram = applyRewriteRuleProgram' rule + +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.BubbleLet" $ do + precondition $ do + ensureWellTyped prog + prog <- runProgram prog + postcondition $ do + ensureWellTyped prog + return prog From 064cb5ff609a684c6e7e15ac3c5576bb38b83155 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 19:51:38 +0900 Subject: [PATCH 6/9] feat(cxx): Fix a performance issue of generated C++ --- src/Jikka/CPlusPlus/Convert/FromCore.hs | 90 +++++++++++--------- src/Jikka/Core/Language/TypeCheck.hs | 12 +++ test/Jikka/CPlusPlus/Convert/FromCoreSpec.hs | 24 +++--- 3 files changed, 72 insertions(+), 54 deletions(-) diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index 37cda4b3..fd095b2c 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -24,7 +24,6 @@ import qualified Jikka.CPlusPlus.Language.Util as Y import Jikka.Common.Alpha import Jikka.Common.Error import qualified Jikka.Core.Format as X (formatBuiltinIsolated, formatType) -import qualified Jikka.Core.Language.Beta as X import qualified Jikka.Core.Language.BuiltinPatterns as X import qualified Jikka.Core.Language.Expr as X import qualified Jikka.Core.Language.TypeCheck as X @@ -262,59 +261,66 @@ runAppBuiltin f args = wrapError' ("converting builtin " ++ X.formatBuiltinIsola X.Permute -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::permute" []) [e1, e2] X.MultiChoose -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::multichoose" []) [e1, e2] -runExpr :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m Y.Expr +runExpr :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m ([Y.Statement], Y.Expr) runExpr env = \case - X.Var x -> Y.Var <$> lookupVarName env x - X.Lit lit -> runLiteral lit + X.Var x -> do + y <- lookupVarName env x + return ([], Y.Var y) + X.Lit lit -> do + lit <- runLiteral lit + return ([], lit) + X.If' _ e1 e2 e3 -> do + (stmts1, e1) <- runExpr env e1 + (stmts2, e2) <- runExpr env e2 + (stmts3, e3) <- runExpr env e3 + case (stmts2, stmts3) of + ([], []) -> + return (stmts1, Y.Cond e1 e2 e3) + _ -> do + phi <- newFreshName LocalNameKind "" + let assign = Y.Assign . Y.AssignExpr Y.SimpleAssign (Y.LeftVar phi) + return (stmts1 ++ [Y.If e1 (stmts2 ++ [assign e2]) (Just (stmts3 ++ [assign e3]))], Y.Var phi) e@(X.App _ _) -> do let (f, args) = X.curryApp e args <- mapM (runExpr env) args - (e, args) <- case f of + case f of X.Lit (X.LitBuiltin builtin) -> do let arity = arityOfBuiltin builtin if length args < arity then do - e <- runExpr env e let (ts, ret) = X.uncurryFunTy (X.builtinToType builtin) ts <- mapM runType ts ret <- runType ret - xs <- replicateM (arity - length args) (renameVarName LocalArgumentNameKind "_") + xs <- replicateM (arity - length args) (newFreshName LocalArgumentNameKind "") + e <- runAppBuiltin builtin (map snd args ++ map Y.Var xs) let (_, e') = foldr (\(t, x) (ret, e) -> (Y.TyFunction ret [t], Y.Lam [(t, x)] ret [Y.Return e])) (ret, e) (zip (drop (length args) ts) xs) - return (e', []) - else do - e <- runAppBuiltin builtin (take arity args) - return (e, drop arity args) - e -> do - e <- runExpr env e - return (e, args) - return $ foldl (\e arg -> Y.Call (Y.Callable e) [arg]) e args + return (concatMap fst args, e') + else + if length args == arity + then do + e <- runAppBuiltin builtin (map snd args) + return (concatMap fst args, e) + else do + e <- runAppBuiltin builtin (take arity (map snd args)) + return (concatMap fst args, Y.Call (Y.Callable e) (drop arity (map snd args))) + _ -> do + (stmts, f) <- runExpr env f + return (stmts ++ concatMap fst args, Y.Call (Y.Callable f) (map snd args)) e@(X.Lam _ _ _) -> do let (args, body) = X.uncurryLam e ys <- mapM (renameVarName LocalArgumentNameKind . fst) args let env' = reverse (zipWith (\(x, t) y -> (x, t, y)) args ys) ++ env ret <- runType =<< typecheckExpr env' body - body <- runExprToStatements env' body + (stmts, body) <- runExpr env' body ts <- mapM (runType . snd) args - let (_, [Y.Return e]) = foldr (\(t, y) (ret, body) -> (Y.TyFunction ret [t], [Y.Return (Y.Lam [(t, y)] ret body)])) (ret, body) (zip ts ys) - return e - X.Let x _ e1 e2 -> runExpr env =<< X.substitute x e1 e2 - -runExprToStatements :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> m [Y.Statement] -runExprToStatements env = \case + let (_, [Y.Return e]) = foldr (\(t, y) (ret, body) -> (Y.TyFunction ret [t], [Y.Return (Y.Lam [(t, y)] ret body)])) (ret, stmts ++ [Y.Return body]) (zip ts ys) + return ([], e) X.Let x t e1 e2 -> do y <- renameVarName LocalNameKind x t' <- runType t - e1 <- runExpr env e1 - e2 <- runExprToStatements ((x, t, y) : env) e2 - return $ Y.Declare t' y (Just e1) : e2 - X.If' _ e1 e2 e3 -> do - e1 <- runExpr env e1 - e2 <- runExprToStatements env e2 - e3 <- runExprToStatements env e3 - return [Y.If e1 e2 (Just e3)] - e -> do - e <- runExpr env e - return [Y.Return e] + (stmts1, e1) <- runExpr env e1 + (stmts2, e2) <- runExpr ((x, t, y) : env) e2 + return (stmts1 ++ Y.Declare t' y (Just e1) : stmts2, e2) runToplevelFunDef :: (MonadAlpha m, MonadError Error m) => Env -> Y.VarName -> [(X.VarName, X.Type)] -> X.Type -> X.Expr -> m [Y.ToplevelStatement] runToplevelFunDef env f args ret body = do @@ -322,17 +328,19 @@ runToplevelFunDef env f args ret body = do args <- forM args $ \(x, t) -> do y <- renameVarName ArgumentNameKind x return (x, t, y) - body <- runExprToStatements (reverse args ++ env) body + (stmts, result) <- runExpr (reverse args ++ env) body args <- forM args $ \(_, t, y) -> do t <- runType t return (t, y) - return [Y.FunDef ret f args body] + return [Y.FunDef ret f args (stmts ++ [Y.Return result])] runToplevelVarDef :: (MonadAlpha m, MonadError Error m) => Env -> Y.VarName -> X.Type -> X.Expr -> m [Y.ToplevelStatement] runToplevelVarDef env x t e = do t <- runType t - e <- runExpr env e - return [Y.VarDef t x e] + (stmts, e) <- runExpr env e + case stmts of + [] -> return [Y.VarDef t x e] + _ -> return [Y.VarDef t x (Y.Call (Y.Callable (Y.Lam [] t (stmts ++ [Y.Return e]))) [])] runMainRead :: (MonadAlpha m, MonadError Error m) => Y.VarName -> Y.Type -> m [Y.Statement] runMainRead x t = do @@ -431,8 +439,8 @@ runToplevelExpr env = \case args <- forM args $ \(x, t) -> do y <- renameVarName ArgumentNameKind x return (x, t, y) - e <- runExpr (reverse args ++ env) body - let body = [Y.Return e] + (stmts, e) <- runExpr (reverse args ++ env) body + let body = stmts ++ [Y.Return e] args' <- forM args $ \(_, t, y) -> do t <- runType t return (t, y) @@ -442,8 +450,8 @@ runToplevelExpr env = \case t <- runType t y <- newFreshName ArgumentNameKind "" return (t, y) - e <- runExpr env e - let body = [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))] + (stmts, e) <- runExpr env e + let body = stmts ++ [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))] return (args, body) ret <- runType ret let solve = [Y.FunDef ret f args body] diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index 7d3324d4..696dc470 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -117,6 +117,18 @@ literalToType = \case LitNil t -> ListTy t LitBottom t _ -> t +arityOfBuiltin :: Builtin -> Int +arityOfBuiltin = \case + Min2 _ -> 2 + Max2 _ -> 2 + Foldl _ _ -> 3 + Iterate _ -> 3 + At _ -> 2 + Min1 _ -> 1 + Max1 _ -> 1 + Proj _ _ -> 1 + builtin -> length (fst (uncurryFunTy (builtinToType builtin))) + type TypeEnv = [(VarName, Type)] -- | `typecheckExpr` checks that the given `Expr` has the correct types. diff --git a/test/Jikka/CPlusPlus/Convert/FromCoreSpec.hs b/test/Jikka/CPlusPlus/Convert/FromCoreSpec.hs index d3ebdb4c..d6be8032 100644 --- a/test/Jikka/CPlusPlus/Convert/FromCoreSpec.hs +++ b/test/Jikka/CPlusPlus/Convert/FromCoreSpec.hs @@ -36,20 +36,18 @@ spec = describe "run" $ do Y.TyInt64 "f_0" [(Y.TyInt64, "n_1")] - [ Y.If - (Y.BinOp Y.Equal (Y.Var "n_1") (Y.Lit (Y.LitInt64 0))) - [Y.Return (Y.Lit (Y.LitInt64 1))] - ( Just - [ Y.Return - ( Y.BinOp - Y.Mul - (Y.Var "n_1") - ( Y.Call - (Y.Callable (Y.Var "f_0")) - [Y.BinOp Y.Sub (Y.Var "n_1") (Y.Lit (Y.LitInt64 1))] - ) + [ Y.Return + ( Y.Cond + (Y.BinOp Y.Equal (Y.Var "n_1") (Y.Lit (Y.LitInt64 0))) + (Y.Lit (Y.LitInt64 1)) + ( Y.BinOp + Y.Mul + (Y.Var "n_1") + ( Y.Call + (Y.Callable (Y.Var "f_0")) + [Y.BinOp Y.Sub (Y.Var "n_1") (Y.Lit (Y.LitInt64 1))] ) - ] + ) ) ] let expectedSolve = From 3199c7d1cf2c8062ba833e29af4fdfd19c74682b Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 20:13:29 +0900 Subject: [PATCH 7/9] test: Create examples/data/static_range_sum.{generator,solver}.py --- examples/data/static_range_sum.generator.py | 17 ++++++++++++++++ examples/data/static_range_sum.solver.py | 22 +++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 examples/data/static_range_sum.generator.py create mode 100644 examples/data/static_range_sum.solver.py diff --git a/examples/data/static_range_sum.generator.py b/examples/data/static_range_sum.generator.py new file mode 100644 index 00000000..0202cd02 --- /dev/null +++ b/examples/data/static_range_sum.generator.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +import random + +def main() -> None: + n = random.randint(1, 5 * 10 ** 3) + q = random.randint(1, 5 * 10 ** 3) + a = [random.randint(1, 10 ** 9) for _ in range(n)] + l = [random.randint(0, n - 1) for _ in range(q)] + r = [random.randint(l[i] + 1, n) for i in range(q)] + print(n) + print(q) + print(len(a), *a) + print(len(l), *l) + print(len(r), *r) + +if __name__ == '__main__': + main() diff --git a/examples/data/static_range_sum.solver.py b/examples/data/static_range_sum.solver.py new file mode 100644 index 00000000..72d81098 --- /dev/null +++ b/examples/data/static_range_sum.solver.py @@ -0,0 +1,22 @@ +from typing import * + +def solve(n: int, q: int, a: List[int], l: List[int], r: List[int]) -> None: + b = [0] * (n + 1) + for i in range(n): + b[i + 1] = b[i] + a[i] + ans = [-1] * q + for j in range(q): + ans[j] = b[r[j]] - b[l[j]] + return ans + +def main() -> None: + n = int(input()) + q = int(input()) + _, *a = map(int, input().split()) + _, *l = map(int, input().split()) + _, *r = map(int, input().split()) + ans = solve(n, q, a, l, r) + print(len(ans), *ans, sep='\n') + +if __name__ == '__main__': + main() From 15b393eba18ab8a3caedfb627844e733c1e82e50 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 20:39:40 +0900 Subject: [PATCH 8/9] test: Update examples/ --- examples/README.md | 4 +-- ..._c.gen.py => abc134_c.medium.generator.py} | 3 ++- examples/data/abc134_c.medium.solver.py | 23 +++++++++++++++++ ...a.generator.py => dp_a.large.generator.py} | 0 .../{dp_a.solver.py => dp_a.large.solver.py} | 0 ...1000000000.in => fib.medium-1000000000.in} | 0 ...00000000.out => fib.medium-1000000000.out} | 0 ...py => static_range_sum.large.generator.py} | 4 +-- ...er.py => static_range_sum.large.solver.py} | 2 +- scripts/integration_tests.py | 25 ++++++++++++------- 10 files changed, 46 insertions(+), 15 deletions(-) rename examples/data/{abc134_c.gen.py => abc134_c.medium.generator.py} (77%) create mode 100644 examples/data/abc134_c.medium.solver.py rename examples/data/{dp_a.generator.py => dp_a.large.generator.py} (100%) rename examples/data/{dp_a.solver.py => dp_a.large.solver.py} (100%) rename examples/data/{fib.large-1000000000.in => fib.medium-1000000000.in} (100%) rename examples/data/{fib.large-1000000000.out => fib.medium-1000000000.out} (100%) rename examples/data/{static_range_sum.generator.py => static_range_sum.large.generator.py} (82%) rename examples/data/{static_range_sum.solver.py => static_range_sum.large.solver.py} (96%) diff --git a/examples/README.md b/examples/README.md index bce6c765..49f81342 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,4 +1,4 @@ -# [examples/](https://github.com/kmyk/Jikka/tree/master/examples) +# examples/ ## Scripts @@ -10,7 +10,7 @@ - [ ] `dp_a.py` - [ ] `dp_b.py` - [ ] `m_solutions2019_e.py` -- [ ] `static_range_sum.py` +- [x] `static_range_sum.py` ### Toy scripts diff --git a/examples/data/abc134_c.gen.py b/examples/data/abc134_c.medium.generator.py similarity index 77% rename from examples/data/abc134_c.gen.py rename to examples/data/abc134_c.medium.generator.py index 7a88d4ae..b73fc671 100644 --- a/examples/data/abc134_c.gen.py +++ b/examples/data/abc134_c.medium.generator.py @@ -4,13 +4,14 @@ def main() -> None: parser = argparse.ArgumentParser() + # parser.add_argument('-n', type=int, default=random.randint(2, 200000)) parser.add_argument('-n', type=int, default=random.randint(2, 1000)) args = parser.parse_args() n = args.n a = [random.randint(1, 200000) for _ in range(n)] print(n) - print(len(n), *a, sep='\n') + print(len(a), *a) if __name__ == '__main__': main() diff --git a/examples/data/abc134_c.medium.solver.py b/examples/data/abc134_c.medium.solver.py new file mode 100644 index 00000000..f2685466 --- /dev/null +++ b/examples/data/abc134_c.medium.solver.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +from typing import * + +def solve(n: int, a: List[int]) -> List[int]: + l = [-1] * (n + 1) + r = [-1] * (n + 1) + for i in range(n): + l[i + 1] = max(l[i], a[i]) + for i in reversed(range(n)): + r[i] = max(r[i + 1], a[i]) + ans = [-1] * n + for i in range(n): + ans[i] = max(l[i], r[i + 1]) + return ans + +def main() -> None: + n = int(input()) + _, *a = map(int, input().split()) + ans = solve(n, a) + print(len(ans), *ans) + +if __name__ == '__main__': + main() diff --git a/examples/data/dp_a.generator.py b/examples/data/dp_a.large.generator.py similarity index 100% rename from examples/data/dp_a.generator.py rename to examples/data/dp_a.large.generator.py diff --git a/examples/data/dp_a.solver.py b/examples/data/dp_a.large.solver.py similarity index 100% rename from examples/data/dp_a.solver.py rename to examples/data/dp_a.large.solver.py diff --git a/examples/data/fib.large-1000000000.in b/examples/data/fib.medium-1000000000.in similarity index 100% rename from examples/data/fib.large-1000000000.in rename to examples/data/fib.medium-1000000000.in diff --git a/examples/data/fib.large-1000000000.out b/examples/data/fib.medium-1000000000.out similarity index 100% rename from examples/data/fib.large-1000000000.out rename to examples/data/fib.medium-1000000000.out diff --git a/examples/data/static_range_sum.generator.py b/examples/data/static_range_sum.large.generator.py similarity index 82% rename from examples/data/static_range_sum.generator.py rename to examples/data/static_range_sum.large.generator.py index 0202cd02..07da32dd 100644 --- a/examples/data/static_range_sum.generator.py +++ b/examples/data/static_range_sum.large.generator.py @@ -2,8 +2,8 @@ import random def main() -> None: - n = random.randint(1, 5 * 10 ** 3) - q = random.randint(1, 5 * 10 ** 3) + n = random.randint(1, 5 * 10 ** 5) + q = random.randint(1, 5 * 10 ** 5) a = [random.randint(1, 10 ** 9) for _ in range(n)] l = [random.randint(0, n - 1) for _ in range(q)] r = [random.randint(l[i] + 1, n) for i in range(q)] diff --git a/examples/data/static_range_sum.solver.py b/examples/data/static_range_sum.large.solver.py similarity index 96% rename from examples/data/static_range_sum.solver.py rename to examples/data/static_range_sum.large.solver.py index 72d81098..63a12d98 100644 --- a/examples/data/static_range_sum.solver.py +++ b/examples/data/static_range_sum.large.solver.py @@ -1,6 +1,6 @@ from typing import * -def solve(n: int, q: int, a: List[int], l: List[int], r: List[int]) -> None: +def solve(n: int, q: int, a: List[int], l: List[int], r: List[int]) -> List[int]: b = [0] * (n + 1) for i in range(n): b[i + 1] = b[i] + a[i] diff --git a/scripts/integration_tests.py b/scripts/integration_tests.py index dda97f75..722560d7 100644 --- a/scripts/integration_tests.py +++ b/scripts/integration_tests.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse import concurrent.futures +import glob import os import pathlib import subprocess @@ -11,6 +12,9 @@ logger = getLogger(__name__) +CXX_ONLY = 'large' +NO_RPYTHON = 'medium' + def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[pathlib.Path]: inputcases: List[pathlib.Path] = [] @@ -24,28 +28,29 @@ def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[ inputcases.append(path) # using generators - generator_path = pathlib.Path('examples', 'data', script.stem + '.generator.py') - solver_path = pathlib.Path('examples', 'data', script.stem + '.solver.py') - if generator_path.exists(): + for generator_path in pathlib.Path('examples', 'data').glob(glob.escape(script.stem) + '*.generator.py'): + _, testset_name, _, _ = generator_path.name.split('.') + + solver_path = pathlib.Path('examples', 'data', script.stem + '.' + testset_name + '.solver.py') if not solver_path.exists(): logger.error('%s: failed to find the solver', str(script)) return [] - for i in range(10): - inputcase = tempdir / "{}.random-large-{}.in".format(script.stem, i) - outputcase = tempdir / "{}.random-large-{}.out".format(script.stem, i) + for i in range(20): + inputcase = tempdir / "{}.{}-{}.in".format(script.stem, testset_name, i) + outputcase = tempdir / "{}.{}-{}.out".format(script.stem, testset_name, i) with open(inputcase, 'wb') as fh: try: subprocess.check_call([sys.executable, str(generator_path)], stdout=fh, timeout=2) except subprocess.SubprocessError as e: - logger.error('%s: failed to generate an input of a random case: %s', str(script), e) + logger.error('%s: %s: failed to generate an input of a random case: %s', str(script), str(inputcase), e) return [] with open(inputcase, 'rb') as fh1: with open(outputcase, 'wb') as fh2: try: subprocess.check_call([sys.executable, str(solver_path)], stdin=fh1, stdout=fh2, timeout=2) except subprocess.SubprocessError as e: - logger.error('%s: failed to generate an output of a random case: %s', str(script), e) + logger.error('%s: %s: failed to generate an output of a random case: %s', str(script), str(inputcase), e) return [] inputcases.append(inputcase) @@ -83,7 +88,9 @@ def run_integration_test(script: pathlib.Path, *, executable: pathlib.Path) -> b matrix.append(('core', [str(executable), 'execute', '--target', 'core', str(script)])) matrix.append(('C++', [str(tempdir / 'a.exe')])) for title, command in matrix: - if title == 'restricted Python' and 'large' in inputcase.name: + if title == 'restricted Python' and (NO_RPYTHON in inputcase.name or CXX_ONLY in inputcase.name): + continue + if title == 'core' and CXX_ONLY in inputcase.name: continue if 'wip' in inputcase.name: continue From 55bc63a0cab2ba0f19ab266017273959119fd022 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sun, 11 Jul 2021 21:04:27 +0900 Subject: [PATCH 9/9] test: Extend timeouts --- scripts/integration_tests.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/integration_tests.py b/scripts/integration_tests.py index 722560d7..6d4b7916 100644 --- a/scripts/integration_tests.py +++ b/scripts/integration_tests.py @@ -41,14 +41,14 @@ def collect_input_cases(script: pathlib.Path, *, tempdir: pathlib.Path) -> List[ outputcase = tempdir / "{}.{}-{}.out".format(script.stem, testset_name, i) with open(inputcase, 'wb') as fh: try: - subprocess.check_call([sys.executable, str(generator_path)], stdout=fh, timeout=2) + subprocess.check_call([sys.executable, str(generator_path)], stdout=fh, timeout=5) except subprocess.SubprocessError as e: logger.error('%s: %s: failed to generate an input of a random case: %s', str(script), str(inputcase), e) return [] with open(inputcase, 'rb') as fh1: with open(outputcase, 'wb') as fh2: try: - subprocess.check_call([sys.executable, str(solver_path)], stdin=fh1, stdout=fh2, timeout=2) + subprocess.check_call([sys.executable, str(solver_path)], stdin=fh1, stdout=fh2, timeout=5) except subprocess.SubprocessError as e: logger.error('%s: %s: failed to generate an output of a random case: %s', str(script), str(inputcase), e) return [] @@ -64,12 +64,12 @@ def run_integration_test(script: pathlib.Path, *, executable: pathlib.Path) -> b logger.info('%s: compiling...', str(script)) with open(tempdir / 'main.cpp', 'wb') as fh: try: - subprocess.check_call([str(executable), 'convert', str(script)], stdout=fh, timeout=10) + subprocess.check_call([str(executable), 'convert', str(script)], stdout=fh, timeout=20) except subprocess.SubprocessError as e: logger.error('%s: failed to compile from Python to C++: %s', str(script), e) return False try: - subprocess.check_call(['g++', '-std=c++17', '-Wall', '-O2', '-I', str(pathlib.Path('runtime', 'include')), '-o', str(tempdir / 'a.exe'), str(tempdir / 'main.cpp')], timeout=10) + subprocess.check_call(['g++', '-std=c++17', '-Wall', '-O2', '-I', str(pathlib.Path('runtime', 'include')), '-o', str(tempdir / 'a.exe'), str(tempdir / 'main.cpp')], timeout=20) except subprocess.SubprocessError as e: logger.error('%s: failed to compile from C++ to executable: %s', str(script), e) return False @@ -98,7 +98,7 @@ def run_integration_test(script: pathlib.Path, *, executable: pathlib.Path) -> b logger.info('%s: %s: running as %s...', str(script), str(inputcase), title) with open(inputcase, 'rb') as fh: try: - actual = subprocess.check_output(command, stdin=fh, timeout=10) + actual = subprocess.check_output(command, stdin=fh, timeout=20) except subprocess.SubprocessError as e: logger.error('%s: %s: failed to run as %s: %s', str(script), str(inputcase), title, e) return False