diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index cecab3cf..08410ba2 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -13,8 +13,8 @@ 'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'take_nth', 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', - 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', - 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') + 'pad', 'sliding_window', 'partition', 'partition_all', 'count', + 'pluck', 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') def remove(predicate, seq): @@ -653,6 +653,53 @@ def iterate(func, x): x = func(x) +def pad(seq, before=0, after=0, fill=None): + """ Pads a sequence by a fill value before and/or after. + + Pads the sequence before and after using the fill value provided + by ``fill`` up to the lengths specified by ``before`` and + ``after``. If either ``before`` or ``after`` is ``None``, pad + the fill value infinitely on the respective end. + + Note: + If ``before``is ``None``, the sequence will only be the fill + value. + + Args: + + seq(iterable): Sequence to pad. + before(integral): Amount to pad before. + after(integral): Amount to pad after. + fill(any): Some value to pad with. + + Returns: + + iterable: A sequence that has been padded. + + Examples: + + >>> list(pad(range(2, 4), before=1, after=2, fill=0)) + [0, 2, 3, 0, 0] + + """ + + all_seqs = [] + + if before is None: + return itertools.repeat(fill) + elif before > 0: + all_seqs.append(itertools.repeat(fill, before)) + + all_seqs.append(seq) + + if after is None: + all_seqs.append(itertools.repeat(fill)) + elif after > 0: + all_seqs.append(itertools.repeat(fill, after)) + + return concat(all_seqs) + + def sliding_window(n, seq): """ A sequence of overlapping subsequences diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 93aa856d..ba180ef8 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -10,7 +10,7 @@ mapcat, isdistinct, first, second, nth, take, tail, drop, interpose, get, rest, last, cons, frequencies, - reduceby, iterate, accumulate, + reduceby, iterate, pad, accumulate, sliding_window, count, partition, partition_all, take_nth, pluck, join, diff, topk, peek, random_sample) @@ -296,6 +296,28 @@ def test_accumulate_works_on_consumable_iterables(): assert list(accumulate(add, iter((1, 2, 3)))) == [1, 3, 6] +def test_pad(): + assert list(pad([1,2,3])) == [1, 2, 3] + assert list(pad([1,2,3], before=1)) == [None, 1, 2, 3] + assert list(pad([1,2,3], after=2)) == [1, 2, 3, None, None] + assert list(pad([1,2,3], before=1, after=2, fill=0)) == [0, 1, 2, 3, + 0, 0] + assert list(zip(range(3), pad([1,2,3], before=None))) == [(0, None), + (1, None), + (2, None)] + assert list(zip(range(6), pad([1,2,3], after=None))) == [(0, 1), + (1, 2), + (2, 3), + (3, None), + (4, None), + (5, None)] + + padded = pad([1,2,3], before=None, after=None) + assert list(zip(range(3), padded)) == [(0, None), + (1, None), + (2, None)] + + def test_sliding_window(): assert list(sliding_window(2, [1, 2, 3, 4])) == [(1, 2), (2, 3), (3, 4)] assert list(sliding_window(3, [1, 2, 3, 4])) == [(1, 2, 3), (2, 3, 4)]