Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip/translation api #211

Merged
merged 7 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions genienlp/data_utils/almond_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def return_sentences(text, regex_pattern, src_char_spans, is_cjk=False):


def split_text_into_sentences(text, lang, src_char_spans):
# text = '''the . " ${field} " . of . " ${value} " .'''
if lang in ['en']:
sentences = return_sentences(text, '(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[\.!?])\s', src_char_spans)
elif lang in ['zh', 'ja', 'ko']:
Expand Down
8 changes: 7 additions & 1 deletion genienlp/data_utils/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ def from_examples(examples: Iterable[Example], numericalizer):
all_context_plus_question_features = []

for ex in examples:
context_plus_question = ex.context + sep_token + ex.question if len(ex.question) else ex.context
if not len(ex.question):
context_plus_question = ex.context
elif not len(ex.context):
context_plus_question = ex.question
else:
context_plus_question = ex.context + sep_token + ex.question

all_context_plus_questions.append(context_plus_question)

# concatenate question and context features with a separator, but no need for a separator if there are no features to begin with
Expand Down
42 changes: 23 additions & 19 deletions genienlp/data_utils/numericalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.max_generative_vocab = max_generative_vocab
self._cache = args.embeddings
self._tokenizer = None
self.config = config

self._preprocess_special_tokens = args.preprocess_special_tokens

Expand All @@ -126,6 +127,8 @@ def __init__(

self._init_tokenizer(save_dir, config, src_lang, tgt_lang)

self.update_language_dependent_properties(src_lang, tgt_lang)

if save_dir is not None:
logger.info(f'Loading the accompanying numericalizer from {save_dir}')
self.load_extras(save_dir)
Expand Down Expand Up @@ -179,25 +182,6 @@ def _init_tokenizer(self, save_dir, config, src_lang, tgt_lang):

self._tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)

# some tokenizers like Mbart do not set src_lang and tgt_lan when initialized; take care of it here
self._tokenizer.src_lang = src_lang
self._tokenizer.tgt_lang = tgt_lang

# define input prefix to add before every input text
input_prefix = ''
if isinstance(config, MarianConfig) and tgt_lang:
input_prefix = f'>>{tgt_lang}<< '
# only older T5 models need task-specific input prefix
elif self._pretrained_name in T5_PRETRAINED_CONFIG_ARCHIVE_MAP.keys():
assert src_lang == 'en'
if tgt_lang == 'en':
t5_task = 'summarization'
else:
t5_task = f'translation_en_to_{tgt_lang}'
input_prefix = config.task_specific_params[t5_task]['prefix']

self.input_prefix = input_prefix

# We only include the base tokenizers since `isinstance` checks for inheritance
if isinstance(self._tokenizer, (BertTokenizer, BertTokenizerFast)):
self._tokenizer.is_piece_fn = lambda wp: wp.startswith('##')
Expand All @@ -223,6 +207,26 @@ def _init_tokenizer(self, save_dir, config, src_lang, tgt_lang):
# make sure we assigned is_piece_fn
assert self._tokenizer.is_piece_fn

def update_language_dependent_properties(self, src_lang, tgt_lang):
    """Refresh every numericalizer property that depends on the language pair.

    Sets ``src_lang`` / ``tgt_lang`` on the underlying tokenizer and recomputes
    ``self.input_prefix``, the text added before every input.

    Args:
        src_lang: source language code to store on the tokenizer.
        tgt_lang: target language code to store on the tokenizer; also used
            to build the input prefix.

    Raises:
        AssertionError: if the pretrained model is listed in
            ``T5_PRETRAINED_CONFIG_ARCHIVE_MAP`` and ``src_lang`` is not ``'en'``.
    """
    # some tokenizers like Mbart do not set src_lang and tgt_lang when initialized; take care of it here
    tokenizer = self._tokenizer
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang

    # work out the prefix to add before every input text
    if isinstance(self.config, MarianConfig) and tgt_lang:
        prefix = f'>>{tgt_lang}<< '
    elif self._pretrained_name in T5_PRETRAINED_CONFIG_ARCHIVE_MAP.keys():
        # only older T5 models need task-specific input prefix
        assert src_lang == 'en'
        task_key = 'summarization' if tgt_lang == 'en' else f'translation_en_to_{tgt_lang}'
        prefix = self.config.task_specific_params[task_key]['prefix']
    else:
        # all other models take the input text as-is
        prefix = ''

    self.input_prefix = prefix

def load_extras(self, save_dir):
if self.max_generative_vocab is not None:
with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp:
Expand Down
41 changes: 22 additions & 19 deletions genienlp/models/transformer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory
If `save_directory` is None, will initialize a new model and numericalizer, otherwise, will load them from `save_directory`
"""
config = AutoConfig.from_pretrained(args.pretrained_model, cache_dir=args.embeddings)
self.config = config
super().__init__(config)
self.args = args
args.dimension = config.d_model
Expand All @@ -57,7 +58,7 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory
# call this function after task is recognized
if tasks:
self.set_generation_output_options(tasks)

# Only used for Marian models. The adjusted language codes passed to the numericalizer will be None for models trained on a single language pair.
self.orig_src_lang, self.orig_tgt_lang = kwargs.get('src_lang', 'en'), kwargs.get('tgt_lang', 'en')
self.src_lang, self.tgt_lang = adjust_language_code(
Expand All @@ -81,26 +82,9 @@ def __init__(self, config=None, *inputs, args, tasks, vocab_sets, save_directory
tasks=tasks,
)

self.update_language_dependent_configs(self.tgt_lang)
self.model.resize_token_embeddings(self.numericalizer.num_tokens)

# set decoder_start_token_id for mbart
if self.model.config.decoder_start_token_id is None and isinstance(
self.numericalizer._tokenizer, (MBartTokenizer, MBartTokenizerFast)
):
if isinstance(self.numericalizer._tokenizer, MBartTokenizer):
self.model.config.decoder_start_token_id = self.numericalizer._tokenizer.lang_code_to_id[self.tgt_lang]
else:
self.model.config.decoder_start_token_id = self.numericalizer._tokenizer.convert_tokens_to_ids(self.tgt_lang)

# check decoder_start_token_id is set
if self.model.config.decoder_start_token_id is None:
raise ValueError("Make sure that decoder_start_token_id for the model is defined")

# set forced_bos_token_id for certain multilingual models
if isinstance(self.numericalizer._tokenizer, MULTILINGUAL_TOKENIZERS):
forced_bos_token_id = self.numericalizer._tokenizer.lang_code_to_id[self.tgt_lang]
self.model.config.forced_bos_token_id = forced_bos_token_id

if args.dropper_ratio > 0:
# lazy import since dropper is an optional dependency
from loss_dropper import LossDropper
Expand All @@ -115,6 +99,25 @@ def add_new_vocab_from_data(self, tasks, resize_decoder=False):
super().add_new_vocab_from_data(tasks, resize_decoder)
self.model.resize_token_embeddings(self.numericalizer.num_tokens)

def update_language_dependent_configs(self, tgt_lang):
    """Update the config fields that depend on the target language.

    Fills in ``decoder_start_token_id`` for MBart tokenizers when it is not
    already set, and sets ``forced_bos_token_id`` for tokenizers listed in
    ``MULTILINGUAL_TOKENIZERS``.

    Args:
        tgt_lang: target language code, looked up in the tokenizer.

    Raises:
        ValueError: if ``decoder_start_token_id`` is still None afterwards.
    """
    config = self.config
    tokenizer = self.numericalizer._tokenizer

    # NOTE(review): decoder_start_token_id is only filled in while it is None, so a
    # later call with a different tgt_lang leaves the earlier value in place -- confirm
    # this is intended for callers that switch languages between calls.
    if config.decoder_start_token_id is None:
        if isinstance(tokenizer, MBartTokenizer):
            config.decoder_start_token_id = tokenizer.lang_code_to_id[tgt_lang]
        elif isinstance(tokenizer, MBartTokenizerFast):
            config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # by this point the id must come either from the config or from the MBart branch above
    if config.decoder_start_token_id is None:
        raise ValueError("Make sure that decoder_start_token_id for the model is defined")

    # certain multilingual models also need forced_bos_token_id set to the target language id
    if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
        config.forced_bos_token_id = tokenizer.lang_code_to_id[tgt_lang]

def forward(self, *input, **kwargs):
if self.training or kwargs.get('train', False):
batch = input[0]
Expand Down
11 changes: 7 additions & 4 deletions genienlp/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,13 @@ def parse_argv(parser):
type=str,
nargs='+',
dest='pred_src_languages',
default=['en'],
help='Specify dataset source languages used during prediction for multilingual tasks'
'multiple languages for each task should be concatenated with +',
)
parser.add_argument(
'--pred_tgt_languages',
type=str,
nargs='+',
default=['en'],
help='Specify dataset target languages used during prediction for multilingual tasks'
'multiple languages for each task should be concatenated with +',
)
Expand Down Expand Up @@ -254,10 +252,15 @@ def set_default_values(args):

def check_args(args):

if not args.pred_src_languages:
setattr(args, 'pred_src_languages', [args.eval_src_languages])
if not args.pred_tgt_languages:
setattr(args, 'pred_tgt_languages', [args.eval_tgt_languages])

if len(args.task_names) != len(args.pred_src_languages):
raise ValueError(
'You have to define prediction languages for each task'
'Use None for single language tasks. Also provide languages in the same order you provided the tasks.'
'You have to define prediction languages for each task.'
' Use None for single language tasks. Also provide languages in the same order you provided the tasks.'
)

if getattr(args, 'do_ned', False) and getattr(args, 'ned_retrieve_method', None) == 'bootleg':
Expand Down
62 changes: 51 additions & 11 deletions genienlp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


import asyncio
import copy
import json
import logging
import os
Expand All @@ -44,11 +45,30 @@
from .data_utils.example import Example, NumericalizedExamples
from .ned.ned_utils import init_ned_model
from .tasks.registry import get_tasks
from .util import get_devices, load_config_json, log_model_size, set_seed
from .util import adjust_language_code, get_devices, load_config_json, log_model_size, set_seed
from .validate import generate_with_model

logger = logging.getLogger(__name__)

# Option names a request may override through its 'options' field.
# handle_request logs a warning for any other key and leaves it unchanged.
GENERATION_ARGUMENTS = {
    # generation hyperparameters
    'num_beams',
    'num_beam_groups',
    'diversity_penalty',
    'num_outputs',
    'no_repeat_ngram_size',
    'top_p',
    'top_k',
    'repetition_penalty',
    'temperature',
    'max_output_length',
    # locales
    'src_locale',
    'tgt_locale',
    # alignment and translation options
    'do_alignment',
    'align_preserve_input_quotation',
    'align_remove_output_quotation',
    'translate_example_split',
}


def parse_argv(parser):
parser.add_argument('--path', type=str, required=True)
Expand All @@ -63,8 +83,8 @@ def parse_argv(parser):
parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on')
parser.add_argument('--stdin', action='store_true', help='Interact on stdin/stdout instead of TCP')
parser.add_argument('--database_dir', type=str, help='Database folder containing all relevant files')
parser.add_argument('--src_locale', default='en', help='locale tag of the input language to parse')
parser.add_argument('--tgt_locale', default='en', help='locale tag of the target language to generate')
parser.add_argument('--src_locale', help='locale tag of the input language to parse')
parser.add_argument('--tgt_locale', help='locale tag of the target language to generate')
parser.add_argument('--inference_name', default='nlp', help='name used by kfserving inference service, alphanumeric only')

# These are generation hyperparameters. Each one can be a list of values in which case, we generate `num_outputs` outputs for each set of hyperparameters.
Expand Down Expand Up @@ -120,8 +140,24 @@ def numericalize_examples(self, ex):
return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device)

def handle_request(self, request):
args = copy.deepcopy(self.args)
generation_options = request.get('options', {})
for k, v in generation_options.items():
if k not in GENERATION_ARGUMENTS:
logger.warning(f'{k} is not a generation option and cannot be overriden')
continue
setattr(args, k, v)

# TODO handle this better by decoupling numericalizer and model
if hasattr(args, 'src_locale') and hasattr(args, 'tgt_locale'):
src_locale, tgt_locale = adjust_language_code(
self.model.config, self.args.pretrained_model, args.src_locale, args.tgt_locale
)
self.numericalizer.update_language_dependent_properties(src_locale, tgt_locale)
self.model.update_language_dependent_configs(tgt_locale)
Mehrad0711 marked this conversation as resolved.
Show resolved Hide resolved

task_name = request['task'] if 'task' in request else 'generic'
task = list(get_tasks([task_name], self.args, self._cached_task_names).values())[0]
task = list(get_tasks([task_name], args, self._cached_task_names).values())[0]
if task_name not in self._cached_task_names:
self._cached_task_names[task_name] = task

Expand Down Expand Up @@ -151,7 +187,7 @@ def handle_request(self, request):
question = task.default_question

ex = Example.from_raw(
str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower
str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=args.lower
)
examples.append(ex)

Expand All @@ -165,18 +201,18 @@ def handle_request(self, request):

try:
with torch.no_grad():
if self.args.calibrator_paths is not None:
if args.calibrator_paths is not None:
output = generate_with_model(
self.model,
[batch],
self.numericalizer,
task,
self.args,
args,
output_predictions_only=True,
confidence_estimators=self.confidence_estimators,
)
response = []
if sum(self.args.num_outputs) > 1:
if sum(args.num_outputs) > 1:
for idx, predictions in enumerate(output.predictions):
candidates = []
for cand in predictions:
Expand All @@ -193,9 +229,9 @@ def handle_request(self, request):
response.append(instance)
else:
output = generate_with_model(
self.model, [batch], self.numericalizer, task, self.args, output_predictions_only=True
self.model, [batch], self.numericalizer, task, args, output_predictions_only=True
)
if sum(self.args.num_outputs) > 1:
if sum(args.num_outputs) > 1:
response = []
for idx, predictions in enumerate(output.predictions):
candidates = []
Expand All @@ -222,7 +258,7 @@ def handle_json_request(self, line: str) -> str:
assert len(response) == 1
response = response[0]
response['id'] = request['id']
return json.dumps(response) + '\n'
return json.dumps(response, ensure_ascii=False) + '\n'

async def handle_client(self, client_reader, client_writer):
try:
Expand Down Expand Up @@ -274,6 +310,10 @@ def run(self):
def init(args):
load_config_json(args)
check_and_update_generation_args(args)
if not args.src_locale:
args.src_locale = args.eval_src_languages
if not args.tgt_locale:
args.tgt_locale = args.eval_tgt_languages
set_seed(args)

devices = get_devices()
Expand Down
Loading