diff --git a/fusion-test/Canary.hs b/fusion-test/Canary.hs new file mode 100644 index 00000000..33417ee1 --- /dev/null +++ b/fusion-test/Canary.hs @@ -0,0 +1,120 @@ +{- | + +This module provide the function 'fuseHere', which can be inserted into a +pipeline of vector-processing functions. It also contains copies of all the +fusion-related rewrite rules from "Data.Vector.Generic", with 'fuseHere' +inserted. This way, if fusion happens at this point, the 'fuseHere' function +disappears. + +Having to maintain a complete copy of all the rewrite rules is a big downsid of +this approach, and a better way would be appreciated. + +-} +module Canary (fuseHere) where + +import qualified Data.Vector.Generic as V +import qualified Data.Vector.Generic.New as New +import Data.Vector.Fusion.Stream.Monadic ( Stream ) +import qualified Data.Vector.Fusion.Bundle as Bundle +import Data.Vector.Fusion.Bundle ( Bundle, MBundle, lift, inplace ) +import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle + +-- | Put this function into vector pipelines where you want them to fuse +fuseHere :: a -> a +fuseHere = id +{-# NOINLINE fuseHere #-} + +-- | We need to copy all fusion rules here, with fuseHere inserted in the right +-- spot. + +{-# RULES + +"(!)/fuseHere/unstream [Vector]" forall i s. + fuseHere (V.new (New.unstream s)) V.! i = s Bundle.!! i + +"(!?)/fuseHere/unstream [Vector]" forall i s. + fuseHere (V.new (New.unstream s)) V.!? i = s Bundle.!? i + +"head/fuseHere/unstream [Vector]" forall s. + V.head (fuseHere (V.new (New.unstream s))) = Bundle.head s + +"last/fuseHere/unstream [Vector]" forall s. + V.last (fuseHere (V.new (New.unstream s))) = Bundle.last s + +"unsafeIndex/fuseHere/unstream [Vector]" forall i s. + V.unsafeIndex (fuseHere (V.new (New.unstream s))) i = s Bundle.!! i + +"unsafeHead/fuseHere/unstream [Vector]" forall s. + V.unsafeHead (fuseHere (V.new (New.unstream s))) = Bundle.head s + +"unsafeLast/fuseHere/unstream [Vector]" forall s. + V.unsafeLast (fuseHere (V.new (New.unstream s))) = Bundle.last s #-} + +{-# RULES + +"indexM/fuseHere/unstream [Vector]" forall s i. + V.indexM (fuseHere (V.new (New.unstream s))) i = lift s MBundle.!! i + +"headM/fuseHere/unstream [Vector]" forall s. + V.headM (fuseHere (V.new (New.unstream s))) = MBundle.head (lift s) + +"lastM/fuseHere/unstream [Vector]" forall s. + V.lastM (fuseHere (V.new (New.unstream s))) = MBundle.last (lift s) + +"unsafeIndexM/fuseHere/unstream [Vector]" forall s i. + V.unsafeIndexM (fuseHere (V.new (New.unstream s))) i = lift s MBundle.!! i + +"unsafeHeadM/fuseHere/unstream [Vector]" forall s. + V.unsafeHeadM (fuseHere (V.new (New.unstream s))) = MBundle.head (lift s) + +"unsafeLastM/fuseHere/unstream [Vector]" forall s. + V.unsafeLastM (fuseHere (V.new (New.unstream s))) = MBundle.last (lift s) #-} + +{-# RULES + +"slice/fuseHere/new [Vector]" forall i n p. + V.slice i n (fuseHere (V.new p)) = V.new (New.slice i n p) + +"init/fuseHere/new [Vector]" forall p. + V.init (fuseHere (V.new p)) = V.new (New.init p) + +"tail/fuseHere/new [Vector]" forall p. + V.tail (fuseHere (V.new p)) = V.new (New.tail p) + +"take/fuseHere/new [Vector]" forall n p. + V.take n (fuseHere (V.new p)) = V.new (New.take n p) + +"drop/fuseHere/new [Vector]" forall n p. + V.drop n (fuseHere (V.new p)) = V.new (New.drop n p) + +"unsafeSlice/fuseHere/new [Vector]" forall i n p. + V.unsafeSlice i n (fuseHere (V.new p)) = V.new (New.unsafeSlice i n p) + +"unsafeInit/fuseHere/new [Vector]" forall p. + V.unsafeInit (fuseHere (V.new p)) = V.new (New.unsafeInit p) + +"unsafeTail/fuseHere/new [Vector]" forall p. + V.unsafeTail (fuseHere (V.new p)) = V.new (New.unsafeTail p) #-} + + +{-# RULES + +"stream/fuseHere/unstream [Vector]" forall s. + V.stream (fuseHere (V.new (New.unstream s))) = s + +"New.unstream/fuseHere/stream [Vector]" forall v. + New.unstream (fuseHere (V.stream v)) = V.clone v + +"clone/fuseHere/new [Vector]" forall p. + V.clone (fuseHere (V.new p)) = p + +"inplace [Vector]" + forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m. + New.unstream (inplace f g (V.stream (V.new m))) = New.transform f g m + +"uninplace [Vector]" + forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m. + V.stream (V.new (New.transform f g m)) = inplace f g (V.stream (V.new m)) +#-} + + diff --git a/fusion-test/Main.hs b/fusion-test/Main.hs new file mode 100644 index 00000000..84224bba --- /dev/null +++ b/fusion-test/Main.hs @@ -0,0 +1,258 @@ +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fplugin=Test.Inspection.Plugin #-} + +{- + +Fusion tests for vector +======================= + +This file tests whether fusion happens the way we want it. + +It does so by creating a typical pipeline with a function. For example, to test +`map`, we write + + test_map f = V.toList . V.map f . V.fromList + +Actually, we insert the `fuseHere` function, from the Canary module, in the +spots where we expect fusion to happen: + + test_map f = V.toList . fuseHere . V.map f . fuseHere . V.fromList + +The `fuseHere` function has rules included that make it disappear if fusion +happens at this spot. See Canary.hs for more details. +So if `fuseHere` remains in the code, fusion did not happen as expected. We +check this using the `inspection-testing` library. + +Even if `fuseHere` disappeared, which means that for example the +`unstream/stream` rule fired, we want to check that all of the constructors +of the `Step` data type (`Skip`, `Done`, `Yield`) have been eliminiated. We +test this again using `inspection-testing`. + +-} + +module Main where + +import qualified Data.Vector as V +import Canary +import Test.Inspection +import Data.Vector.Fusion.Stream.Monadic (Step(..)) +import Control.Monad +import qualified Language.Haskell.TH as TH + +main :: IO () +main = return () + +-- Testing pipelines for producers, transformers, consumers +p f = V.toList . fuseHere . f +t f = V.toList . fuseHere . f . fuseHere . V.fromList +c f = f . fuseHere . V.fromList +{-# INLINE t #-} +{-# INLINE c #-} +{-# INLINE p #-} + +-- To help the type checker, avoid ambiguous Monad ctraints +inIO :: IO a -> IO a +inIO = id + +-- Length information + +test_length = c V.length +test_null = c V.null + +-- Indexing + +test_bang = c (V.! 42) +test_safe_bang = c (V.!? 42) +test_head = c V.head +test_last = c V.last +test_unsafeIndex = c (`V.unsafeIndex` 42) +test_unsafeHead = c V.unsafeHead +test_unsafeLast = c V.unsafeLast + +-- Monadic Indexing + +test_indexM = inIO . c (`V.indexM` 42) +test_headM = inIO . c V.headM +test_lastM = inIO . c V.lastM +test_unsafeIndexM = inIO . c (`V.unsafeIndexM` 42) +test_unsafeHeadM = inIO . c V.unsafeHeadM +test_unsafeLastM = inIO . c V.unsafeLastM + +-- Extracting subvectors (slicing) +test_slice = t (V.slice 23 42) +test_init = t V.init +test_tail = t V.tail +test_take = t (V.take 42) +test_drop = t (V.drop 42) +-- splitAt: hard to test +test_unsafeSlice = t (V.unsafeSlice 23 42) +test_unsafeInit = t V.unsafeInit +test_unsafeTail = t V.unsafeTail +test_unsafeTake = t (V.unsafeTake 42) +-- Does not actually fuse +-- test_unsafeDrop = t (V.unsafeDrop 42) + +-- Initialisation + +-- Does not fuse, as the ctant expression floats out +--test_empty = p (\(_::()) -> V.empty) +test_singleton = p V.singleton +test_replicate = p (V.replicate 5) +test_generate = p (V.generate 5) +test_iterateN = p (V.iterateN 5 (+1)) + +-- Monadic Initialisation +-- +-- These don't fuse (no rules for unstreamM) + +-- Unfolding + +test_unfoldr x = p (V.unfoldr x) +test_unfoldrN x = p (V.unfoldrN 42 x) +-- ctructN and ctructrN cannot fuse + +-- Enumeration + +test_enumFromN (x::Integer) = p (V.enumFromN x) +test_enumFromStepN (x::Integer) y = p (V.enumFromStepN x y) +test_enumFromTo (x::Integer) = p (V.enumFromTo x) +test_enumFromThenTo (x::Integer) y = p (V.enumFromThenTo x y) + +-- Concatenation + +test_cons x = t (V.cons x) +test_snoc x = t (`V.snoc` x) +test_append_r x = t (x V.++) +test_append_l x = t (V.++ x) +test_concat = p V.concat + +-- Bulk updates + +-- bulk updates fuse as a consumers, but not as a producer +test_upd x = c (V.// x) +test_update_l x = c (`V.update` x) +test_update_r x = c (x `V.update`) +test_update__1 y z = c (\x -> V.update_ x y z) +test_update__2 y z = c (\x -> V.update_ y x z) +test_update__3 y z = c (\x -> V.update_ y z x) +test_unsafeUpd x = c (`V.unsafeUpd` x) +test_unsafeUpdate_l x = c (`V.unsafeUpdate` x) +test_unsafeUpdate_r x = c (x `V.unsafeUpdate`) +test_unsafeUpdate__1 y z = c (\x -> V.unsafeUpdate_ x y z) +test_unsafeUpdate__2 y z = c (\x -> V.unsafeUpdate_ y x z) +test_unsafeUpdate__3 y z = c (\x -> V.unsafeUpdate_ y z x) + +-- Accumulations + +test_accum f y = c (\x -> V.accum f x y) +test_accumulate_l f y = c (\x -> V.accumulate f x y) +test_accumulate_r f y = c (\x -> V.accumulate f y x) +test_accumulate__1 f y z = c (\x -> V.accumulate_ f x y z) +test_accumulate__2 f y z = c (\x -> V.accumulate_ f y x z) +test_accumulate__3 f y z = c (\x -> V.accumulate_ f y z x) +test_unsafeAccum f y = c (\x -> V.unsafeAccum f x y) +test_unsafeAccumulate_l f y = c (\x -> V.unsafeAccumulate f x y) +test_unsafeAccumulate_r f y = c (\x -> V.unsafeAccumulate f y x) +test_unsafeAccumulate__1 f y z = c (\x -> V.unsafeAccumulate_ f x y z) +test_unsafeAccumulate__2 f y z = c (\x -> V.unsafeAccumulate_ f y x z) +test_unsafeAccumulate__3 f y z = c (\x -> V.unsafeAccumulate_ f y z x) + +-- Permutations + +-- reverse is only a good producer, not a good consumer +test_reverse = p V.reverse +-- backpermute is only a good consumer in the second argument, +-- but not the first and not a good producer +test_backpermute y = c (V.backpermute y) +test_unsafeBackpermute y = c (V.unsafeBackpermute y) + +-- Elementwise operations + +test_indexed = t V.indexed +test_map f = t (V.map f) +test_imap f = t (V.imap f) +test_concatMap f = t (V.concatMap f) +-- No deep fusion? +-- test_concatMap_deep f = t (V.concatMap (\ x -> fuseHere (f x))) + +-- ... much more to come ... + +fmap (concat . reverse)$ forM + [ 'test_head + , 'test_null + , 'test_bang + , 'test_safe_bang + , 'test_head + , 'test_last + , 'test_unsafeIndex + , 'test_unsafeHead + , 'test_unsafeLast + , 'test_indexM + , 'test_headM + , 'test_lastM + , 'test_unsafeIndexM + , 'test_unsafeHeadM + , 'test_unsafeLastM + , 'test_slice + , 'test_init + , 'test_tail + , 'test_take + , 'test_drop + , 'test_unsafeSlice + , 'test_unsafeInit + , 'test_unsafeTail + , 'test_unsafeTake +-- , 'test_unsafeDrop +-- , 'test_empty + , 'test_singleton + , 'test_replicate + , 'test_generate + , 'test_iterateN + , 'test_unfoldr + , 'test_unfoldrN + , 'test_enumFromN + , 'test_enumFromStepN + , 'test_enumFromTo + , 'test_enumFromThenTo + , 'test_cons + , 'test_cons + , 'test_snoc + , 'test_append_r + , 'test_append_l + , 'test_concat + , 'test_upd + , 'test_update_l + , 'test_update_r + , 'test_update__1 + , 'test_update__2 + , 'test_update__3 + , 'test_unsafeUpd + , 'test_unsafeUpdate_l + , 'test_unsafeUpdate_r + , 'test_unsafeUpdate__1 + , 'test_unsafeUpdate__2 + , 'test_unsafeUpdate__3 + , 'test_accum + , 'test_accumulate_l + , 'test_accumulate_r + , 'test_accumulate__1 + , 'test_accumulate__2 + , 'test_accumulate__3 + , 'test_unsafeAccum + , 'test_unsafeAccumulate_l + , 'test_unsafeAccumulate_r + , 'test_unsafeAccumulate__1 + , 'test_unsafeAccumulate__2 + , 'test_unsafeAccumulate__3 + , 'test_reverse + , 'test_backpermute + , 'test_unsafeBackpermute + , 'test_indexed + , 'test_map + , 'test_imap + , 'test_concatMap + ] $ \thn -> inspect + (mkObligation thn (NoUseOf ['fuseHere, 'Yield, 'Skip, 'Done])) + { testName = Just (TH.nameBase thn) } + diff --git a/vector.cabal b/vector.cabal index 573c3ad9..a10db141 100644 --- a/vector.cabal +++ b/vector.cabal @@ -268,3 +268,14 @@ test-suite vector-tests-O2 if impl(ghc >= 8.0) && impl(ghc < 8.1) Ghc-Options: -Wno-redundant-constraints +test-suite fusion-tests + Default-Language: Haskell2010 + type: exitcode-stdio-1.0 + Main-Is: Main.hs + + other-modules: Canary + hs-source-dirs: fusion-test + Build-Depends: base >= 4.5 && < 5, vector, template-haskell, + inspection-testing >= 0.4.1 + + Ghc-Options: -dsuppress-all