Feats/bucket lord #23

Merged 13 commits on Oct 23, 2023
120 changes: 78 additions & 42 deletions mammoth/inputters/dataloader.py
@@ -1,5 +1,6 @@
import collections
import itertools
import math
import random

import torch
@@ -30,15 +31,18 @@ def numel_fn(_):
elif batch_type == 'tokens':

def bucket_fn(example_dict):
"""map example dict to bucket index"""
# subtract two for bos/eos
src_len = min(len(example_dict['src']), n_buckets) - 2
if 'tgt' in example_dict:
# subtract four for bos/eos on both sides
true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
else:
true_size = len(example_dict['src']) + 2
tgt_len = src_len
# maybe dump it in the last bucket if it's just too long
return min(n_buckets - 1, true_size)
return src_len, tgt_len

def numel_fn(example_dict):
"""count tokens in example"""
if 'tgt' in example_dict:
true_size = len(example_dict['src']) + len(example_dict['tgt'])
else:
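To illustrate the change in this hunk: bucketing is now two-dimensional, so each example maps to a pair of bucket indices (one for the source length, one for the target length) rather than a single combined index. A minimal standalone sketch of that mapping, following my reading of the diff (the clamping constant and the names are taken from the hunk, not from the full file):

def make_bucket_fn(n_buckets):
    def bucket_fn(example_dict):
        # subtract two for bos/eos, clamp long sequences into the last buckets
        src_len = min(len(example_dict['src']), n_buckets) - 2
        if 'tgt' in example_dict:
            tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
        else:
            tgt_len = src_len
        return src_len, tgt_len
    return bucket_fn

bucket_fn = make_bucket_fn(n_buckets=32)
print(bucket_fn({'src': list(range(10)), 'tgt': list(range(7))}))  # (8, 5)
print(bucket_fn({'src': list(range(100))}))                        # (30, 30)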
@@ -74,8 +78,21 @@ def __iter__(self):
class LookAheadBucketing():
def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
self.examples_stream = examples_stream
self._buckets = [[] for _ in range(n_buckets)]
self._lens = [0 for _ in range(n_buckets)]
self.n_buckets = n_buckets
self._buckets = [
[
[]
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
self._lens = [
[
0
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
self.look_ahead_size = look_ahead_size
self.batch_size = batch_size
self.bucket_fn = bucket_fn
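For reference, the reservoir above becomes an n_buckets x n_buckets grid: a nested list of example buckets plus a parallel grid of counts, filled from the first look_ahead_size examples of the stream. A toy sketch under that reading (the stream contents and sizes here are made up):

import itertools

n_buckets, look_ahead_size = 6, 5
examples_stream = iter(
    {'src': [0] * s, 'tgt': [0] * t}
    for s, t in [(4, 3), (5, 5), (4, 3), (3, 6), (6, 4), (2, 2)]
)

buckets = [[[] for _ in range(n_buckets)] for _ in range(n_buckets)]
lens = [[0 for _ in range(n_buckets)] for _ in range(n_buckets)]

for example in itertools.islice(examples_stream, look_ahead_size):
    # same clamping as bucket_fn in the hunk above: two positions reserved for bos/eos
    s_bucket = min(len(example['src']), n_buckets) - 2
    t_bucket = min(len(example['tgt']), n_buckets) - 2
    buckets[s_bucket][t_bucket].append(example)
    lens[s_bucket][t_bucket] += 1

print(lens[2][1])  # 2: both (4, 3) examples land in the same cell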
@@ -86,40 +103,75 @@ def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, buck
def _init(self):
logger.info('LookAheadBucketing: initialization start')
for example in itertools.islice(self.examples_stream, self.look_ahead_size):
bucket_idx = self.bucket_fn(example)
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
s_bucket, t_bucket = self.bucket_fn(example)
self._buckets[s_bucket][t_bucket].append(example)
self._lens[s_bucket][t_bucket] += 1
logger.info('LookAheadBucketing: initialization done')

def maybe_replenish(self) -> bool:
"""look up one more example to add to this reservoir."""
try:
example = next(self.examples_stream)
bucket_idx = self.bucket_fn(example)
creates_new_bucket = self._lens[bucket_idx] == 0
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
s_bucket, t_bucket = self.bucket_fn(example)
creates_new_bucket = self._lens[s_bucket][t_bucket] == 0
self._buckets[s_bucket][t_bucket].append(example)
self._lens[s_bucket][t_bucket] += 1
return creates_new_bucket
except StopIteration:
return None

def bucket_is_empty(self, bucket_idx) -> bool:
return self._lens[bucket_idx] == 0
def bucket_is_empty(self, s_idx, t_idx) -> bool:
return self._lens[s_idx][t_idx] == 0

def _choose_and_prepare_bucket(self, bucket_idx=None):
"""pick a bucket (at random unless specified) and prepare examples for iteration"""
if bucket_idx is None:
bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0]
buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
weights = [self._lens[s][t] for s in range(self.n_buckets) for t in range(self.n_buckets)]
s_bucket, t_bucket = random.choices(buckets, weights=weights, k=1)[0]
random.shuffle(self._buckets[bucket_idx])
return bucket_idx
random.shuffle(self._buckets[s_bucket][t_bucket])
return s_bucket, t_bucket
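In other words, bucket choice is now a single weighted draw over all (source, target) cells, with each cell weighted by how many examples it currently holds. A self-contained sketch of that draw (the function name and the toy grid are mine):

import random

def choose_bucket(lens, n_buckets):
    cells = [(s, t) for s in range(n_buckets) for t in range(n_buckets)]
    weights = [lens[s][t] for s, t in cells]
    # random.choices samples proportionally to the weights, so fuller cells are
    # drawn more often; at least one weight must be non-zero
    return random.choices(cells, weights=weights, k=1)[0]

lens = [[0, 3], [1, 0]]
print(choose_bucket(lens, n_buckets=2))  # most often (0, 1)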

def is_empty(self):
return all(size == 0 for row in self._lens for size in row)

def _spiralling(self, s_idx, t_idx):
def _seq():
# from https://math.stackexchange.com/questions/163080/
# on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361
for n in itertools.count(1):
k = math.ceil((math.sqrt(n) - 1) / 2.0)
t = 2 * k + 1
m = t ** 2
t = t - 1
if n >= m - t:
yield k - (m - n), -k
else:
m = m - t
if n >= m - t:
yield -k, -k + (m - n)
else:
m = m - t
if n >= m - t:
yield -k + (m - n), k
else:
yield k, k - (m - n - t)

offsets = map(lambda tup: (tup[0] + s_idx, tup[1] + t_idx), _seq())
offsets = filter(
lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
offsets,
)
offsets = filter(
lambda tup: self._lens[tup[0]][tup[1]] > 0,
offsets,
)
yield from offsets
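The generator above walks the grid outward from the chosen cell in a square spiral, so when a cell runs out of examples the nearest non-empty cell, i.e. the one holding examples of the most similar lengths, is tried next. A standalone demo of the spiral order using the same closed-form formula from the linked math.stackexchange answer; here the offsets are relative to the starting cell and would still need the bounds and occupancy filters shown above:

import itertools
import math

def spiral_offsets():
    for n in itertools.count(1):
        k = math.ceil((math.sqrt(n) - 1) / 2.0)
        t = 2 * k + 1
        m = t ** 2
        t = t - 1
        if n >= m - t:
            yield k - (m - n), -k
        else:
            m = m - t
            if n >= m - t:
                yield -k, -k + (m - n)
            else:
                m = m - t
                if n >= m - t:
                    yield -k + (m - n), k
                else:
                    yield k, k - (m - n - t)

print(list(itertools.islice(spiral_offsets(), 9)))
# [(0, 0), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]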

def __iter__(self):
while True:
# 1. maybe we've exhausted the stream and the buckets
@@ -132,32 +184,16 @@ def __iter__(self):
# 3. build batch
batch_is_complete = False
while not batch_is_complete:
assert not self.is_empty(), 'Stream should never end!'
# maybe switch buckets
if self.bucket_is_empty(current_bucket_idx):
if self.is_empty():
logger.info('Reached end of stream') # should not happen
if accum:
yield self.collate_fn(accum)
break
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
if self._lens[bucket_idx] != 0
)
_ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
current_bucket_idx = smallest_bucket_idx
next_indices = self._spiralling(*current_bucket_idx)
while self.bucket_is_empty(*current_bucket_idx):
current_bucket_idx = next(next_indices)
# retrieve and process the example
example = self._buckets[current_bucket_idx].pop()
self._lens[current_bucket_idx] -= 1
s_bucket, t_bucket = current_bucket_idx
example = self._buckets[s_bucket][t_bucket].pop()
self._lens[s_bucket][t_bucket] -= 1
accum.append(example)
numel = self.numel_fn(example)
cur_batch_size += numel
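Putting the pieces together, the batching loop in __iter__ can be read as: draw a starting cell weighted by occupancy, pop examples from it, move to the nearest non-empty cell whenever the current one runs dry, and stop once the token budget is reached. A toy end-to-end illustration of that strategy under my reading of the diff; it is not the mammoth implementation, and the nearest-first walk below is a simplification of the spiral (cells sorted by Chebyshev distance):

import random

n_buckets, batch_size = 4, 20
grid = [[[] for _ in range(n_buckets)] for _ in range(n_buckets)]
lens = [[0] * n_buckets for _ in range(n_buckets)]
for s, t in [(1, 1)] * 3 + [(2, 2)] * 2 + [(0, 1)] * 4:
    grid[s][t].append({'src': [0] * (s + 2), 'tgt': [0] * (t + 2)})
    lens[s][t] += 1

cells = [(s, t) for s in range(n_buckets) for t in range(n_buckets)]
s_idx, t_idx = random.choices(cells, weights=[lens[s][t] for s, t in cells], k=1)[0]

def nearest_nonempty(s0, t0):
    # stand-in for _spiralling: non-empty cells, nearest (Chebyshev) first
    for s, t in sorted(cells, key=lambda c: max(abs(c[0] - s0), abs(c[1] - t0))):
        if lens[s][t] > 0:
            yield s, t

accum, cur_size = [], 0
walk = nearest_nonempty(s_idx, t_idx)
while cur_size < batch_size and any(n for row in lens for n in row):
    if lens[s_idx][t_idx] == 0:
        s_idx, t_idx = next(walk)
    example = grid[s_idx][t_idx].pop()
    lens[s_idx][t_idx] -= 1
    accum.append(example)
    cur_size += len(example['src']) + len(example['tgt'])

print(len(accum), cur_size)  # one length-homogeneous batch of roughly batch_size tokens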