Skip to content

Commit

Permalink
Support of new option offset for ignoring token score of special toke…
Browse files Browse the repository at this point in the history
…ns (#1592)

Co-authored-by: asenellart <[email protected]>
  • Loading branch information
ASenart and asenellart authored Dec 26, 2023
1 parent 4f8a4f3 commit bb6b841
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 6 deletions.
4 changes: 3 additions & 1 deletion include/ctranslate2/scoring.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace ctranslate2 {
struct ScoringOptions {
// Truncate the inputs after this many tokens (set 0 to disable truncation).
size_t max_input_length = 1024;
dim_t offset = 0;
};

struct ScoringResult {
Expand All @@ -38,6 +39,7 @@ namespace ctranslate2 {
layers::DecoderState& state,
const std::vector<std::vector<size_t>>& sequences,
const Vocabulary& vocabulary,
const dim_t preferred_size_multiple = 1);
const dim_t preferred_size_multiple = 1,
const dim_t offset=0);

}
9 changes: 8 additions & 1 deletion python/cpp/translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,12 @@ namespace ctranslate2 {
size_t max_batch_size,
const std::string& batch_type_str,
size_t max_input_length,
dim_t offset,
bool asynchronous) {
const auto batch_type = str_to_batch_type(batch_type_str);
ScoringOptions options;
options.max_input_length = max_input_length;
options.offset = offset;

std::shared_lock lock(_mutex);
assert_model_is_ready();
Expand All @@ -252,6 +254,7 @@ namespace ctranslate2 {
size_t read_batch_size,
const std::string& batch_type_str,
size_t max_input_length,
dim_t offset,
bool with_tokens_score,
const TokenizeFn& source_tokenize_fn,
const TokenizeFn& target_tokenize_fn,
Expand All @@ -263,7 +266,7 @@ namespace ctranslate2 {
const auto batch_type = str_to_batch_type(batch_type_str);
ScoringOptions options;
options.max_input_length = max_input_length;

options.offset = offset;
std::shared_lock lock(_mutex);
assert_model_is_ready();

Expand Down Expand Up @@ -592,6 +595,7 @@ namespace ctranslate2 {
py::arg("max_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("max_input_length")=1024,
py::arg("offset") = 0,
py::arg("asynchronous")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Expand All @@ -606,6 +610,7 @@ namespace ctranslate2 {
minimized.
batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
max_input_length: Truncate inputs after this many tokens (0 to disable).
offset: Ignore the first n tokens in target in score calculation.
asynchronous: Run the scoring asynchronously.
Returns:
Expand All @@ -621,6 +626,7 @@ namespace ctranslate2 {
py::arg("read_batch_size")=0,
py::arg("batch_type")="examples",
py::arg("max_input_length")=1024,
py::arg("offset")=0,
py::arg("with_tokens_score")=false,
py::arg("source_tokenize_fn")=nullptr,
py::arg("target_tokenize_fn")=nullptr,
Expand Down Expand Up @@ -649,6 +655,7 @@ namespace ctranslate2 {
batch_type: Whether :obj:`max_batch_size` and :obj:`read_batch_size` are the
number of "examples" or "tokens".
max_input_length: Truncate inputs after this many tokens (0 to disable).
offset: Ignore the first n tokens in target in score calculation.
with_tokens_score: Include the token-level scores in the output file.
source_tokenize_fn: Function to tokenize source lines.
target_tokenize_fn: Function to tokenize target lines.
Expand Down
3 changes: 2 additions & 1 deletion src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ namespace ctranslate2 {
state,
ids,
vocabulary,
_model->preferred_size_multiple());
_model->preferred_size_multiple(),
options.offset);
}

bool DecoderReplica::skip_scoring(const std::vector<std::string>& tokens,
Expand Down
3 changes: 2 additions & 1 deletion src/models/sequence_to_sequence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ namespace ctranslate2 {
state,
target_ids,
_model->get_target_vocabulary(),
_model->preferred_size_multiple());
_model->preferred_size_multiple(),
options.offset);
}

bool EncoderDecoderReplica::skip_scoring(const std::vector<std::string>& source,
Expand Down
5 changes: 3 additions & 2 deletions src/scoring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ namespace ctranslate2 {
layers::DecoderState& state,
const std::vector<std::vector<size_t>>& sequences,
const Vocabulary& vocabulary,
const dim_t preferred_size_multiple) {
const dim_t preferred_size_multiple,
const dim_t offset) {
const dim_t batch_size = sequences.size();
const Device device = decoder.device();

Expand Down Expand Up @@ -57,7 +58,7 @@ namespace ctranslate2 {
auto& result = results[b];
result.tokens.reserve(output_length);
result.tokens_score.reserve(output_length);
for (dim_t t = 0; t < output_length; ++t) {
for (dim_t t = offset; t < output_length; ++t) {
result.tokens.emplace_back(vocabulary.to_token(output_sequences[b][t]));
result.tokens_score.emplace_back(scores.at<float>({b, t}));
}
Expand Down

0 comments on commit bb6b841

Please sign in to comment.