Skip to content

Commit

Permalink
Merge pull request #1646 from OlegPonomaryov:batch_pad_fix
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 405671232
  • Loading branch information
copybara-github committed Oct 26, 2021
2 parents de95e6e + 00da919 commit 96cb76c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 3 additions & 1 deletion trax/data/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ def batch(generator, batch_size):
# buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)]
# batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3])
try:
- batched_example = tuple(np.stack(x) for x in zip(*buf))
+ batched_example = tuple(
+ pad_to_max_dims([np.asarray(tensor) for tensor in x])
+ for x in zip(*buf))
except ValueError as e:
for j in range(len(buf)):
logging.error('Batch[%d][%d] input shape: %r output shape: %r',
Expand Down
7 changes: 7 additions & 0 deletions trax/data/inputs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ def test_batch_data(self):
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (10,))

def test_batch_data_padding(self):
  """Checks that `data.batch` pads variable-length inputs to a common shape.

  The i-th example's input is a list of (10 - i) ones, so the ten inputs
  range in length from 10 down to 1 and cannot be stacked without padding.
  """
  dataset = (([1] * (10 - i), i + 1) for i in range(10))
  batches = data.batch(dataset, 10)
  batch = next(batches)
  # Every input must be padded up to the longest one (length 10).
  self.assertEqual(batch[0].shape, (10, 10))
  # The last input is the shortest: a single 1 followed by nine zeros of
  # padding. assert_array_equal reports the mismatching elements on failure,
  # whereas assertTrue(np.array_equal(...)) only says "False is not true".
  np.testing.assert_array_equal(batch[0][-1], np.asarray([1] + 9 * [0]))

def test_batch_exception_size(self):
dataset = ((i, i + 1) for i in range(10))
with self.assertRaises(ValueError):
Expand Down

0 comments on commit 96cb76c

Please sign in to comment.