Feats/bucket lord #23

Merged
merged 13 commits on Oct 23, 2023
193 changes: 105 additions & 88 deletions mammoth/inputters/dataloader.py
@@ -1,5 +1,6 @@
import collections
import itertools
import math
import random

import torch
@@ -8,16 +9,11 @@
from mammoth.utils.logging import logger


def infinite_iterator(iterable):
return itertools.chain.from_iterable(itertools.repeat(iterable))


def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
"""Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
if not cycle:
loader = InferenceBatcher(dataset, batch_size)
else:
examples_stream = infinite_iterator(dataset)
if batch_type == 'sents':
n_buckets = 1

@@ -30,23 +26,25 @@ def numel_fn(_):
elif batch_type == 'tokens':

def bucket_fn(example_dict):
"""map example dict to bucket index"""
# subtract two for bos/eos
src_len = min(len(example_dict['src']), n_buckets) - 2
if 'tgt' in example_dict:
# subtract four for bos/eos on both sides
true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
else:
true_size = len(example_dict['src']) + 2
tgt_len = src_len
# maybe dump it in the last bucket if it's just too long
return min(n_buckets - 1, true_size)
return src_len, tgt_len

def numel_fn(example_dict):
"""count tokens in example"""
if 'tgt' in example_dict:
true_size = len(example_dict['src']) + len(example_dict['tgt'])
else:
true_size = len(example_dict['src'])
return true_size

collate_fn = dataset.collate_fn
loader = LookAheadBucketing(examples_stream, pool_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn)
loader = LookAheadBucketing(dataset, pool_size, n_buckets, batch_size, bucket_fn, numel_fn)
return iter(loader) if as_iter else loader
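
To make the token-mode bucketing concrete, here is a minimal standalone sketch (the two helpers restated outside build_dataloader, with n_buckets = 4 assumed): each example maps to a 2-D (src_len, tgt_len) bucket index, clipped to the grid and offset by 2 to discount bos/eos, while numel_fn reports the true token count used for batch sizing.

    # Standalone illustration, not part of the PR; n_buckets = 4 is assumed.
    n_buckets = 4

    def bucket_fn(example_dict):
        # subtract two for bos/eos, clip overly long examples to the last bucket
        src_len = min(len(example_dict['src']), n_buckets) - 2
        if 'tgt' in example_dict:
            tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
        else:
            tgt_len = src_len
        return src_len, tgt_len

    def numel_fn(example_dict):
        # count tokens in the example
        if 'tgt' in example_dict:
            return len(example_dict['src']) + len(example_dict['tgt'])
        return len(example_dict['src'])

    example = {'src': ('<s>', 'a', 'b', '</s>'), 'tgt': ('<s>', 'c', '</s>')}
    print(bucket_fn(example))  # (2, 1)
    print(numel_fn(example))   # 7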


@@ -72,117 +70,136 @@ def __iter__(self):


class LookAheadBucketing():
def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
self.examples_stream = examples_stream
self._buckets = [[] for _ in range(n_buckets)]
self._lens = [0 for _ in range(n_buckets)]
def __init__(self, dataset, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn):
self.dataset = dataset
# actual generator of examples
self.examples_stream = iter([])
# tracks whether the stream needs to be restarted
self._is_exhausted = True
self.n_buckets = n_buckets
self._buckets = [
[
[]
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
self.look_ahead_size = look_ahead_size
self.batch_size = batch_size
self.bucket_fn = bucket_fn
self.numel_fn = numel_fn
self.collate_fn = collate_fn
self.collate_fn = dataset.collate_fn
self._init()

def _init(self):
logger.info('LookAheadBucketing: initialization start')
for example in itertools.islice(self.examples_stream, self.look_ahead_size):
bucket_idx = self.bucket_fn(example)
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
self.examples_stream = iter(self.dataset)
for _ in range(self.look_ahead_size):
self.maybe_replenish()
if self._is_exhausted:
break
assert not self.is_empty(), 'Dataset contains no usable example!'
logger.info('LookAheadBucketing: initialization done')

def maybe_replenish(self) -> bool:
"""look up one more example to add to this reservoir."""
def maybe_replenish(self):
"""try to look up one more example to add to this reservoir."""
try:
example = next(self.examples_stream)
bucket_idx = self.bucket_fn(example)
creates_new_bucket = self._lens[bucket_idx] == 0
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
return creates_new_bucket
s_idx, t_idx = self.bucket_fn(example)
self._buckets[s_idx][t_idx].append(example)
self._is_exhausted = False
except StopIteration:
return None

def bucket_is_empty(self, bucket_idx) -> bool:
return self._lens[bucket_idx] == 0

def _choose_and_prepare_bucket(self, bucket_idx=None):
"""pick a bucket (at random unless specified) and prepare examples for iteration"""
if bucket_idx is None:
bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0]
# if bucket_idx >= len(self._buckets):
# import pdb; pdb.set_trace()
# if len(self._prefetched[self._buckets[bucket_idx]]) == 0:
# import pdb; pdb.set_trace()
random.shuffle(self._buckets[bucket_idx])
self._is_exhausted = True

def bucket_is_empty(self, s_idx: int, t_idx: int) -> bool:
"""check if this bucket is empty"""
return len(self._buckets[s_idx][t_idx]) == 0

def _choose_bucket(self):
"""pick a bucket at random"""
buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
weights = [len(self._buckets[s][t]) for s in range(self.n_buckets) for t in range(self.n_buckets)]
bucket_idx = random.choices(buckets, weights=weights, k=1)[0]
return bucket_idx

def is_empty(self):
return all(size == 0 for size in self._lens)
def _select_from_bucket(self, s_idx: int, t_idx: int) -> object:
"""randomly select an item from a bucket"""
bucket = self._buckets[s_idx][t_idx]
obj_idx = random.randrange(len(bucket))
# swap to last to get O(1) deletion
bucket[obj_idx], bucket[-1] = bucket[-1], bucket[obj_idx]
return bucket.pop()
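
_select_from_bucket relies on the standard swap-and-pop idiom: list.pop() from the tail is O(1), while popping an interior index shifts every element after it. A tiny self-contained illustration (toy values, not PR code):

    import random

    bucket = ['a', 'b', 'c', 'd']
    obj_idx = random.randrange(len(bucket))
    # swap the chosen item to the tail, then pop it off in O(1)
    bucket[obj_idx], bucket[-1] = bucket[-1], bucket[obj_idx]
    item = bucket.pop()
    print(item, bucket)  # removed item, remaining three in arbitrary order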

def is_empty(self) -> bool:
"""check if all buckets are empty"""
return all(len(bucket) == 0 for bucket in itertools.chain.from_iterable(self._buckets))

def _spiralling(self, s_idx: int, t_idx: int):
def _seq():
# from https://math.stackexchange.com/questions/163080/on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361 # noqa: E501
for n in itertools.count(1):
k = math.ceil((math.sqrt(n) - 1) / 2.0)
t = 2 * k + 1
m = t ** 2
t = t - 1
if n >= m - t:
yield k - (m - n), k
else:
m = m - t
if n >= m - t:
yield -k, k - (m - n)
else:
m = m - t
if n >= m - t:
yield -k + (m - n), -k
else:
yield k, -k + (m - n - t)

offsets = ((s_idx + x, t_idx + y) for x, y in _seq())
# offsets = itertools.takewhile(
# # this far out is obviously too far out
# lambda tup: (tup[0] < self.n_buckets * 2 + 1) and (tup[1] < self.n_buckets * 2 + 1),
# offsets,
# )
offsets = filter(
lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
offsets,
)
# maybe more brittle than the takewhile a few lines above
offsets = itertools.islice(offsets, self.n_buckets ** 2)
yield from offsets
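
The probing order is easiest to see by printing the first few offsets. This standalone copy of the _seq formula (for illustration only) shows that the walk starts on the chosen bucket itself, i.e. offset (0, 0), then visits its ring of eight neighbours before spiralling further out; _spiralling shifts each offset by (s_idx, t_idx) and discards anything outside the grid.

    import itertools
    import math

    def spiral_offsets():
        # same formula as _seq above (math.stackexchange answer 3448361)
        for n in itertools.count(1):
            k = math.ceil((math.sqrt(n) - 1) / 2.0)
            t = 2 * k + 1
            m = t ** 2
            t = t - 1
            if n >= m - t:
                yield k - (m - n), k
            else:
                m = m - t
                if n >= m - t:
                    yield -k, k - (m - n)
                else:
                    m = m - t
                    if n >= m - t:
                        yield -k + (m - n), -k
                    else:
                        yield k, -k + (m - n - t)

    print(list(itertools.islice(spiral_offsets(), 9)))
    # [(0, 0), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1)]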

def __iter__(self):
while True:
# 1. maybe we've exhausted the stream and the buckets
if self.is_empty():
break
# 1. maybe we've exhausted both the stream and the buckets:
# if so, we restart the example stream
if self.is_empty() and self._is_exhausted:
self._init()
accum, cur_batch_size = [], 0
# 2. pick a length at random
smallest_bucket_idx = self._choose_and_prepare_bucket()
smallest_bucket_idx = self._choose_bucket()
current_bucket_idx = smallest_bucket_idx
# 3. build batch
batch_is_complete = False
while not batch_is_complete:
# stop either when batch is built or when it can't be built
while not (batch_is_complete or self.is_empty()):
# maybe switch buckets
if self.bucket_is_empty(current_bucket_idx):
if self.is_empty():
logger.info('Reached end of stream') # should not happen
if accum:
yield self.collate_fn(accum)
break
try:
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
if self._lens[bucket_idx] != 0
)
except StopIteration:
logger.warning(
'StopIteration when trying to pick a bucket in a smart way. '
'Doing something stupid instead. Please check me.'
)
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(len(self._lens))
if self._lens[bucket_idx] != 0
)
_ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
current_bucket_idx = smallest_bucket_idx
next_indices = self._spiralling(*current_bucket_idx)
while self.bucket_is_empty(*current_bucket_idx):
current_bucket_idx = next(next_indices)
# retrieve and process the example
example = self._buckets[current_bucket_idx].pop()
self._lens[current_bucket_idx] -= 1
example = self._select_from_bucket(*current_bucket_idx)
accum.append(example)
numel = self.numel_fn(example)
cur_batch_size += numel
batch_is_complete = cur_batch_size >= self.batch_size

# 4. try to replenish reservoir if possible
# if not, this will also update self._is_exhausted
self.maybe_replenish()
# if (new_bucket is not None) and (new_bucket <= bucket):
# assert self._buckets[bucket_idx] != bucket
# bucket_idx += 1

yield self.collate_fn(accum)
# if self.bucket_is_empty(bucket_idx):
# del self._buckets[bucket_idx]


class DynamicDatasetIter(object):
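
As a rough usage sketch (not code from this PR): after the refactor, LookAheadBucketing only needs the dataset to be iterable over example dicts and to expose a collate_fn, which is exactly the contract the MockStream in the new test below satisfies. ToyDataset and all parameter values here are hypothetical.

    from mammoth.inputters.dataloader import build_dataloader

    class ToyDataset:
        def __init__(self, items):
            self.items = items

        def __iter__(self):
            return iter(self.items)

        def collate_fn(self, items):
            return items  # a real dataset would pad and tensorize here

    data = ToyDataset([{'src': ('a',) * i, 'tgt': ('b',) * i} for i in range(2, 7)])
    loader = build_dataloader(
        data,
        batch_size=8,         # in tokens, since batch_type='tokens'
        batch_type='tokens',
        pool_size=4,          # look-ahead reservoir size
        n_buckets=4,          # bucket grid is n_buckets x n_buckets
        cycle=True,           # restart the stream when it runs dry
    )
    batch = next(loader)      # an endless iterator of collated batches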
78 changes: 78 additions & 0 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -0,0 +1,78 @@
from itertools import product

import unittest
from mammoth.inputters.dataloader import (
build_dataloader,
LookAheadBucketing,
InferenceBatcher,
)


class hashabledict(dict):
def __hash__(self):
return hash(tuple(sorted(self.items())))


class MockStream():
def __init__(self, items):
self.items = items

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]

def __iter__(self):
return iter(self.items)

def collate_fn(self, items):
return items


class TestLookAheadBucketing(unittest.TestCase):

def test_all_read(self):
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
examples_read = []
batches = iter(lab)
while not (lab._is_exhausted and lab.is_empty()):
examples_read.extend(next(batches))
sorted_src_ref = sorted([ex['src'] for ex in stream.items])
sorted_src_obs = sorted([ex['src'] for ex in examples_read])
self.assertTrue(sorted_src_ref == sorted_src_obs)
sorted_tgt_ref = sorted([ex['tgt'] for ex in stream.items])
sorted_tgt_obs = sorted([ex['tgt'] for ex in examples_read])
self.assertTrue(sorted_tgt_ref == sorted_tgt_obs)

def test_reroutes(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=True, as_iter=False)
self.assertTrue(type(lab) is LookAheadBucketing)
not_lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=False, as_iter=False)
self.assertTrue(type(not_lab) is InferenceBatcher)

def test_always_continues(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
was_exhausted = False
stopped_exhaustion = False
lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
batches = iter(lab)
all_items = []
for _ in range(len(stream) * 3 // 2):
all_items.extend(next(batches))
was_exhausted = was_exhausted or lab._is_exhausted
if was_exhausted:
stopped_exhaustion = stopped_exhaustion or not lab._is_exhausted

self.assertTrue(was_exhausted)
self.assertTrue(stopped_exhaustion)
self.assertTrue(len(all_items) > len(stream))
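
Assuming the repository root is on PYTHONPATH, the new suite should be runnable with the standard unittest runner:

    python -m unittest mammoth.tests.test_look_ahead_bucketing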