Fixes a bug introduced with the revised data iterator: max_observed_{source,target}_len should only be computed on sentence pairs added to the buckets (#184)
fhieber authored Nov 6, 2017
1 parent 4c091e1 commit 20587b9
Showing 3 changed files with 14 additions and 9 deletions.
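
For context, the bug class this commit fixes: corpus statistics were accumulated before the bucket check, so sentence pairs later discarded for exceeding the bucket sizes (via `--max_seq_len`) still inflated max_observed_{source,target}_len. A minimal self-contained sketch of the corrected pattern (hypothetical names, not Sockeye's actual API):

```python
# Hypothetical sketch: track max observed lengths only for pairs that fit a bucket.
from typing import List, Optional, Tuple

def assign_and_track(pairs: List[Tuple[List[int], List[int]]],
                     buckets: List[Tuple[int, int]]):
    max_src = max_tgt = ndiscard = 0
    for source, target in pairs:
        # Smallest bucket that fits both sides (assumes buckets sorted
        # ascending), or None if the pair is too long for every bucket.
        bucket: Optional[Tuple[int, int]] = next(
            (b for b in buckets
             if len(source) <= b[0] and len(target) <= b[1]), None)
        if bucket is None:
            ndiscard += 1
            continue  # discarded pairs must not affect the statistics
        max_src = max(max_src, len(source))
        max_tgt = max(max_tgt, len(target))
    return max_src, max_tgt, ndiscard
```

With buckets `[(10, 10)]` and a pair of lengths (50, 50), the pair is discarded and the maxima reflect only the pairs actually kept, which is the behavior the diff below restores.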
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 For each item we will potentially have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.10.3]
+### Changed
+- Fixed a bug with max_observed_{source,target}_len being computed on the complete data set, not only on the
+  sentences actually added to the buckets based on `--max_seq_len`.
+
 ## [1.10.2]
 ### Added
 - `--max-num-epochs` flag to train for a maximum number of passes through the training data.
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.10.1'
+__version__ = '1.10.3'
16 changes: 8 additions & 8 deletions sockeye/data_io.py
@@ -523,6 +523,11 @@ def _assign_to_buckets(self, source_sentences, target_sentences):
         for source, target in zip(source_sentences, target_sentences):
             source_len = len(source)
             target_len = len(target)
+            buck_idx, buck = get_parallel_bucket(self.buckets, source_len, target_len)
+            if buck is None:
+                ndiscard += 1
+                continue  # skip this sentence pair
+
             tokens_source += source_len
             tokens_target += target_len
             if source_len > self.max_observed_source_len:
@@ -533,18 +538,13 @@ def _assign_to_buckets(self, source_sentences, target_sentences):
             num_of_unks_source += source.count(self.unk_id)
             num_of_unks_target += target.count(self.unk_id)
 
-            buck_idx, buck = get_parallel_bucket(self.buckets, len(source), len(target))
-            if buck is None:
-                ndiscard += 1
-                continue
-
             buff_source = np.full((buck[0],), self.pad_id, dtype=self.dtype)
             buff_target = np.full((buck[1],), self.pad_id, dtype=self.dtype)
-            buff_source[:len(source)] = source
-            buff_target[:len(target)] = target
+            buff_source[:source_len] = source
+            buff_target[:target_len] = target
             self.data_source[buck_idx].append(buff_source)
             self.data_target[buck_idx].append(buff_target)
-            self.data_target_average_len[buck_idx] += len(target)
+            self.data_target_average_len[buck_idx] += target_len
 
         # Average number of non-padding elements in target sequence per bucket
         for buck_idx, buck in enumerate(self.buckets):
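
The hunk above calls `get_parallel_bucket`, defined elsewhere in `sockeye/data_io.py`. Judging from the call site, it returns a `(bucket_index, bucket)` pair, with `buck is None` signalling that no bucket fits. A rough sketch of that contract, assuming buckets are sorted from smallest to largest (an assumption, not the module's actual implementation):

```python
# Assumed contract: return the index and shape of the first (smallest) bucket
# that fits both lengths, or (None, None) if the pair exceeds every bucket.
def get_parallel_bucket(buckets, source_len, target_len):
    for idx, (source_max, target_max) in enumerate(buckets):
        if source_len <= source_max and target_len <= target_max:
            return idx, (source_max, target_max)
    return None, None
```

The returned bucket shape then sizes the padded buffers at the call site: `np.full((buck[0],), self.pad_id, dtype=self.dtype)` allocates a source buffer of the bucket's full source length filled with padding, and the sentence is copied into its prefix.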
