Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scanr, scanl, scanr1, scanl1 #92

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions vec/src/Data/Vec/DataFamily/SpineStrict.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
Expand Down Expand Up @@ -91,6 +92,13 @@ module Data.Vec.DataFamily.SpineStrict (
ifoldMap1,
foldr,
ifoldr,
-- * Scans
scanr,
scanl,
scanl',
scanr1,
scanl1,
scanl1',
-- * Special folds
length,
null,
Expand Down Expand Up @@ -582,10 +590,10 @@ last :: forall n a. N.SNatI n => Vec ('S n) a -> a
last xs = getLast (N.induction1 start step) xs where
start :: Last 'Z a
start = Last $ \(x:::VNil) -> x

step :: Last m a -> Last ('S m) a
step (Last rec) = Last $ \(_ ::: ys) -> rec ys


newtype Last n a = Last { getLast :: Vec ('S n) a -> a }

Expand All @@ -596,7 +604,7 @@ init :: forall n a. N.SNatI n => Vec ('S n) a -> Vec n a
init xs = getInit (N.induction1 start step) xs where
start :: Init 'Z a
start = Init (const VNil)

step :: Init m a -> Init ('S m) a
step (Init rec) = Init $ \(y ::: ys) -> y ::: rec ys

Expand Down Expand Up @@ -845,6 +853,62 @@ ifoldr = getIFoldr $ N.induction1 start step where

newtype IFoldr a n b = IFoldr { getIFoldr :: (Fin n -> a -> b -> b) -> b -> Vec n a -> b }

-- | Right-to-left scan.
scanr :: forall a b n. N.SNatI n => (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = getScanr $ N.induction1 start step where
start :: Scanr a 'Z b
start = Scanr $ \_ -> singleton z

step :: Scanr a m b -> Scanr a ('S m) b
step (Scanr go) = Scanr $ \(x ::: xs) -> let ys@(y ::: _) = go xs in f x y ::: ys

newtype Scanr a n b = Scanr { getScanr :: Vec n a -> Vec ('S n) b }

-- | Left-to-right scan.
scanl :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \acc (x ::: xs) -> acc ::: go (f acc x) xs

-- | Left-to-right scan with strict accumulator.
scanl' :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl' f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \(!acc) (x ::: xs) -> acc ::: go (f acc x) xs

newtype Scanl a n b = Scanl { getScanl :: b -> Vec n a -> Vec ('S n) b }

-- | Right-to-left scan with no starting value.
scanr1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = getScanr1 $ N.induction1 start step where
start :: Scanr1 'Z a
start = Scanr1 $ \_ -> VNil

step :: forall m. N.SNatI m => Scanr1 m a -> Scanr1 ('S m) a
step (Scanr1 go) = Scanr1 $ \(x ::: xs) -> case N.snat :: N.SNat m of
N.SZ -> x ::: VNil
N.SS -> let ys@(y ::: _) = go xs in f x y ::: ys

newtype Scanr1 n a = Scanr1 { getScanr1 :: Vec n a -> Vec n a }

-- | Left-to-right scan with no starting value.
scanl1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1 f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> let (y ::: ys) = xs in scanl f y ys

-- | Left-to-right scan with no starting value, and with strict accumulator.
scanl1' :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1' f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> let (y ::: ys) = xs in scanl' f y ys

-- | Yield the length of a 'Vec'. /O(n)/
length :: forall n a. N.SNatI n => Vec n a -> Int
length _ = getLength l where
Expand Down
46 changes: 46 additions & 0 deletions vec/src/Data/Vec/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ module Data.Vec.Lazy (
foldr,
ifoldr,
foldl',
-- * Scans
scanr,
scanl,
scanl',
scanr1,
scanl1,
scanl1',
-- * Special folds
length,
null,
Expand Down Expand Up @@ -691,6 +698,45 @@ foldl' f z = go z where
go !acc VNil = acc
go !acc (x ::: xs) = go (f acc x) xs

-- | Right-to-left scan.
scanr :: forall a b n. (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = go where
go :: Vec m a -> Vec ('S m) b
go VNil = singleton z
go (x ::: xs) = case go xs of ys@(y ::: _) -> f x y ::: ys

-- | Left-to-right scan.
scanl :: forall a b n. (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = go where
go :: b -> Vec m a -> Vec ('S m) b
go acc VNil = acc ::: VNil
go acc (x ::: xs) = acc ::: go (f acc x) xs

-- | Left-to-right scan with strict accumulator.
scanl' :: forall a b n. (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl' f = go where
go :: b -> Vec m a -> Vec ('S m) b
go !acc VNil = acc ::: VNil
go !acc (x ::: xs) = acc ::: go (f acc x) xs

-- | Right-to-left scan with no starting value.
scanr1 :: forall a n. (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = go where
go :: Vec m a -> Vec m a
go VNil = VNil
go (x ::: VNil) = x ::: VNil
go (x ::: xs@(_ ::: _)) = case go xs of ys@(y ::: _) -> f x y ::: ys

-- | Left-to-right scan with no starting value.
scanl1 :: forall a n. (a -> a -> a) -> Vec n a -> Vec n a
scanl1 _ VNil = VNil
scanl1 f (x ::: xs) = scanl f x xs

-- | Left-to-right scan with no starting value, and with strict accumulator.
scanl1' :: forall a n. (a -> a -> a) -> Vec n a -> Vec n a
scanl1' _ VNil = VNil
scanl1' f (x ::: xs) = scanl' f x xs

-- | Yield the length of a 'Vec'. /O(n)/
length :: Vec n a -> Int
length = go 0 where
Expand Down
71 changes: 68 additions & 3 deletions vec/src/Data/Vec/Lazy/Inline.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
Expand Down Expand Up @@ -51,6 +52,13 @@ module Data.Vec.Lazy.Inline (
ifoldMap1,
foldr,
ifoldr,
-- * Scans
scanr,
scanl,
scanl',
scanr1,
scanl1,
scanl1',
-- * Special folds
length,
null,
Expand Down Expand Up @@ -260,10 +268,10 @@ last :: forall n a. N.SNatI n => Vec ('S n) a -> a
last xs = getLast (N.induction1 start step) xs where
start :: Last 'Z a
start = Last $ \(x:::VNil) -> x

step :: Last m a -> Last ('S m) a
step (Last rec) = Last $ \(_ ::: ys) -> rec ys


newtype Last n a = Last { getLast :: Vec ('S n) a -> a }

Expand All @@ -274,7 +282,7 @@ init :: forall n a. N.SNatI n => Vec ('S n) a -> Vec n a
init xs = getInit (N.induction1 start step) xs where
start :: Init 'Z a
start = Init (const VNil)

step :: Init m a -> Init ('S m) a
step (Init rec) = Init $ \(y ::: ys) -> y ::: rec ys

Expand Down Expand Up @@ -520,6 +528,63 @@ ifoldr = getIFoldr $ N.induction1 start step where

newtype IFoldr a n b = IFoldr { getIFoldr :: (Fin n -> a -> b -> b) -> b -> Vec n a -> b }

-- | Right-to-left scan.
scanr :: forall a b n. N.SNatI n => (a -> b -> b) -> b -> Vec n a -> Vec ('S n) b
scanr f z = getScanr $ N.induction1 start step where
start :: Scanr a 'Z b
start = Scanr $ \_ -> singleton z

step :: Scanr a m b -> Scanr a ('S m) b
step (Scanr go) = Scanr $ \(x ::: xs) -> case go xs of
ys@(y ::: _) -> f x y ::: ys

newtype Scanr a n b = Scanr { getScanr :: Vec n a -> Vec ('S n) b }

-- | Left-to-right scan.
scanl :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \acc (x ::: xs) -> acc ::: go (f acc x) xs

-- | Left-to-right scan with strict accumulator.
scanl' :: forall a b n. N.SNatI n => (b -> a -> b) -> b -> Vec n a -> Vec ('S n) b
scanl' f = getScanl $ N.induction1 start step where
start :: Scanl a 'Z b
start = Scanl $ \z VNil -> singleton z

step :: Scanl a m b -> Scanl a ('S m) b
step (Scanl go) = Scanl $ \(!acc) (x ::: xs) -> acc ::: go (f acc x) xs
phadej marked this conversation as resolved.
Show resolved Hide resolved

newtype Scanl a n b = Scanl { getScanl :: b -> Vec n a -> Vec ('S n) b }

-- | Right-to-left scan with no starting value.
scanr1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanr1 f = getScanr1 $ N.induction1 start step where
start :: Scanr1 'Z a
start = Scanr1 $ \_ -> VNil

step :: forall m. N.SNatI m => Scanr1 m a -> Scanr1 ('S m) a
step (Scanr1 go) = Scanr1 $ \(x ::: xs) -> case N.snat :: N.SNat m of
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't feel right. You shouldn't need to check length in the step case. I can take a look myself if you cannot find a way to avoid it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In scanr1, the last element is special (it is the zero element), and so the 0 -> 1 step and the m -> m + 1 step (where m > 0) are different. But yeah please let me know if there's a better way to write this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phadej - let me know if you have any updates on this!

N.SZ -> x ::: VNil
N.SS -> case go xs of ys@(y ::: _) -> f x y ::: ys

newtype Scanr1 n a = Scanr1 { getScanr1 :: Vec n a -> Vec n a }

-- | Left-to-right scan with no starting value.
scanl1 :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1 f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> case xs of y ::: ys -> scanl f y ys

-- | Left-to-right scan with no starting value, and with strict accumulator.
scanl1' :: forall a n. N.SNatI n => (a -> a -> a) -> Vec n a -> Vec n a
scanl1' f xs = case N.snat :: N.SNat n of
N.SZ -> VNil
N.SS -> case xs of y ::: ys -> scanl' f y ys

-- | Yield the length of a 'Vec'. /O(n)/
length :: forall n a. N.SNatI n => Vec n a -> Int
length _ = getLength l where
Expand Down
2 changes: 1 addition & 1 deletion vec/src/Data/Vec/Pull.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
--
-- The module tries to have same API as "Data.Vec.Lazy", missing bits:
-- @withDict@, @toPull@, @fromPull@, @traverse@ (and variants),
-- @(++)@, @concat@ and @split@.
-- @scanr@ (and variants), @(++)@, @concat@ and @split@.
module Data.Vec.Pull (
Vec (..),
-- * Construction
Expand Down
100 changes: 98 additions & 2 deletions vec/test/Inspection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ lhsLast = I.last $ 'a' ::: 'b' ::: 'c' ::: VNil
lhsLast' :: Char
lhsLast' = L.last $ 'a' ::: 'b' ::: 'c' :::VNil

rhsLast :: Char
rhsLast :: Char
rhsLast = 'c'

inspect $ 'lhsLast === 'rhsLast
Expand Down Expand Up @@ -167,4 +167,100 @@ rhsToNonEmpty :: NonEmpty Char
rhsToNonEmpty = 'a' :| ['b', 'c']

inspect $ 'lhsToNonEmpty === 'rhsToNonEmpty
inspect $ 'lhsToNonEmpty' =/= 'rhsToNonEmpty
inspect $ 'lhsToNonEmpty' =/= 'rhsToNonEmpty

-------------------------------------------------------------------------------
-- scanr
-------------------------------------------------------------------------------

lhsScanr :: Vec N.Nat5 Int
lhsScanr = I.scanr (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanr' :: Vec N.Nat5 Int
lhsScanr' = L.scanr (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr :: Vec N.Nat5 Int
rhsScanr = (-2) ::: 3 ::: (-1) ::: 4 ::: 0 ::: VNil

inspect $ 'lhsScanr === 'rhsScanr
inspect $ 'lhsScanr' =/= 'rhsScanr

-------------------------------------------------------------------------------
-- scanl
-------------------------------------------------------------------------------

lhsScanl :: Vec N.Nat5 Int
lhsScanl = I.scanl (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl0 :: Vec N.Nat5 Int
lhsScanl0 = L.scanl (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl :: Vec N.Nat5 Int
rhsScanl = 0 ::: (-1) ::: (-3) ::: (-6) ::: (-10) ::: VNil

inspect $ 'lhsScanl === 'rhsScanl
inspect $ 'lhsScanl0 =/= 'rhsScanl

-------------------------------------------------------------------------------
-- scanl'
-------------------------------------------------------------------------------

lhsScanl' :: Vec N.Nat5 Int
lhsScanl' = I.scanl' (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl'0 :: Vec N.Nat5 Int
lhsScanl'0 = L.scanl' (-) 0 $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl' :: Vec N.Nat5 Int
rhsScanl' = 0 ::: (-1) ::: (-3) ::: (-6) ::: (-10) ::: VNil

inspect $ 'lhsScanl' === 'rhsScanl'
inspect $ 'lhsScanl'0 =/= 'rhsScanl'

-------------------------------------------------------------------------------
-- scanr1
-------------------------------------------------------------------------------

lhsScanr1 :: Vec N.Nat4 Int
lhsScanr1 = I.scanr1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanr1' :: Vec N.Nat4 Int
lhsScanr1' = L.scanr1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanr1 :: Vec N.Nat4 Int
rhsScanr1 = (-2) ::: 3 ::: (-1) ::: 4 ::: VNil

inspect $ 'lhsScanr1 === 'rhsScanr1
inspect $ 'lhsScanr1' =/= 'rhsScanr1

-------------------------------------------------------------------------------
-- scanl1
-------------------------------------------------------------------------------

lhsScanl1 :: Vec N.Nat4 Int
lhsScanl1 = I.scanl1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl10 :: Vec N.Nat4 Int
lhsScanl10 = L.scanl1 (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl1 :: Vec N.Nat4 Int
rhsScanl1 = 1 ::: (-1) ::: (-4) ::: (-8) ::: VNil

inspect $ 'lhsScanl1 === 'rhsScanl1
inspect $ 'lhsScanl10 =/= 'rhsScanl1

-------------------------------------------------------------------------------
-- scanl1'
-------------------------------------------------------------------------------

lhsScanl1' :: Vec N.Nat4 Int
lhsScanl1' = I.scanl1' (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

lhsScanl1'0 :: Vec N.Nat4 Int
lhsScanl1'0 = L.scanl1' (-) $ 1 ::: 2 ::: 3 ::: 4 ::: VNil

rhsScanl1' :: Vec N.Nat4 Int
rhsScanl1' = 1 ::: (-1) ::: (-4) ::: (-8) ::: VNil

inspect $ 'lhsScanl1' === 'rhsScanl1'
inspect $ 'lhsScanl1'0 =/= 'rhsScanl1'
Loading