From 2d1e7820ff1d44bb0d762c2aebdeed01056195e9 Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld
Date: Mon, 6 Apr 2020 16:30:39 -0700
Subject: [PATCH] Get a vocab object from the reader

---
 allennlp/commands/train.py                      | 33 ++++++++++++-----
 .../data/dataset_readers/dataset_reader.py      |  6 ++++
 allennlp/data/vocabulary.py                     | 16 ++++++++
 3 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py
index 6514bbd689b..b72f9fcf34a 100644
--- a/allennlp/commands/train.py
+++ b/allennlp/commands/train.py
@@ -552,12 +552,12 @@ def from_partial_objects(
         model: Lazy[Model],
         data_loader: Lazy[DataLoader],
         trainer: Lazy[Trainer],
-        vocabulary: Lazy[Vocabulary] = None,
-        datasets_for_vocab_creation: List[str] = None,
-        validation_dataset_reader: DatasetReader = None,
-        validation_data_path: str = None,
-        validation_data_loader: Lazy[DataLoader] = None,
-        test_data_path: str = None,
+        vocabulary: Optional[Lazy[Vocabulary]] = None,
+        datasets_for_vocab_creation: Optional[List[str]] = None,
+        validation_dataset_reader: Optional[DatasetReader] = None,
+        validation_data_path: Optional[str] = None,
+        validation_data_loader: Optional[Lazy[DataLoader]] = None,
+        test_data_path: Optional[str] = None,
         evaluate_on_test: bool = False,
     ) -> "TrainModel":
         """
@@ -633,21 +633,38 @@ def from_partial_objects(
             test_data_path=test_data_path,
         )

-        if datasets_for_vocab_creation:
+        if datasets_for_vocab_creation is None:
+            datasets_for_vocab_creation = list(datasets.keys())
+        else:
             for key in datasets_for_vocab_creation:
                 if key not in datasets:
                     raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {key}")

+        key_to_dataset_reader = {
+            "train": dataset_reader,
+            "test": validation_dataset_reader or dataset_reader,
+            "validation": validation_dataset_reader or dataset_reader,
+        }
+        vocabulary_from_readers = None
+        for key in datasets_for_vocab_creation:
+            reader_vocab = key_to_dataset_reader[key].get_vocabulary()
+            if vocabulary_from_readers is None:
+                vocabulary_from_readers = reader_vocab
+            elif reader_vocab is not None:
+                vocabulary_from_readers.extend_from_vocab(reader_vocab)
+
         instance_generator = (
             instance
             for key, dataset in datasets.items()
-            if not datasets_for_vocab_creation or key in datasets_for_vocab_creation
+            if key in datasets_for_vocab_creation
             for instance in dataset
         )

         vocabulary_ = vocabulary.construct(instances=instance_generator)
         if not vocabulary_:
             vocabulary_ = Vocabulary.from_instances(instance_generator)
+        if vocabulary_from_readers is not None:
+            vocabulary_.extend_from_vocab(vocabulary_from_readers)
         model_ = model.construct(vocab=vocabulary_)

         # Initializing the model can have side effect of expanding the vocabulary.
diff --git a/allennlp/data/dataset_readers/dataset_reader.py b/allennlp/data/dataset_readers/dataset_reader.py
index 22f2265ea98..0cb122709b6 100644
--- a/allennlp/data/dataset_readers/dataset_reader.py
+++ b/allennlp/data/dataset_readers/dataset_reader.py
@@ -229,6 +229,12 @@ def _read(self, file_path: str) -> Iterable[Instance]:
         """
         raise NotImplementedError

+    def get_vocabulary(self) -> Optional[Vocabulary]:
+        """Returns the vocabulary used in the created instances. By default, this
+        returns `None`, which causes the vocabulary to be automatically discovered
+        before training."""
+        return None
+
     def _instances_from_cache_file(self, cache_filename: str) -> Iterable[Instance]:
         with open(cache_filename, "r") as cache_file:
             for line in cache_file:
diff --git a/allennlp/data/vocabulary.py b/allennlp/data/vocabulary.py
index 470e700f404..1c18a787120 100644
--- a/allennlp/data/vocabulary.py
+++ b/allennlp/data/vocabulary.py
@@ -366,6 +366,14 @@ def from_files_and_instances(
         )
         return vocab

+    @classmethod
+    def from_transformers(cls, model_name: str, namespace: str = "tokens") -> "Vocabulary":
+        import transformers
+        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+        vocab = cls.empty()
+        vocab.extend_from_dictionary(tokenizer.get_vocab(), namespace)
+        return vocab
+
     @classmethod
     def empty(cls) -> "Vocabulary":
         """
@@ -455,6 +463,14 @@ def extend_from_vocab(self, vocab: "Vocabulary") -> None:
             for token in vocab.get_token_to_index_vocabulary(namespace):
                 self.add_token_to_namespace(token, namespace)

+    def extend_from_dictionary(self, encoding_dictionary: Dict[str, int], namespace: str = "from_transformers") -> None:
+        """
+        Populates the given namespace with a precomputed token-to-index mapping, e.g. from a pretrained transformer tokenizer.
+        """
+        for word, idx in encoding_dictionary.items():
+            self._token_to_index[namespace][word] = idx
+            self._index_to_token[namespace][idx] = word
+
     def _extend(
         self,
         counter: Dict[str, Dict[str, int]] = None,
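
Usage sketch (not part of the patch): with this change a reader can override the
new `get_vocabulary()` hook to hand the trainer a precomputed vocabulary instead
of having one discovered from instances. The reader class and model name below
are hypothetical, for illustration only.

    from typing import Iterable, Optional

    from allennlp.data import DatasetReader, Instance, Vocabulary

    class TransformerBackedReader(DatasetReader):  # hypothetical example
        def _read(self, file_path: str) -> Iterable[Instance]:
            ...  # read the file and yield instances as usual

        def get_vocabulary(self) -> Optional[Vocabulary]:
            # Supply the tokenizer's vocab up front; the trainer merges it
            # into the final vocabulary via extend_from_vocab().
            return Vocabulary.from_transformers("bert-base-uncased")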
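
Similarly, the new `Vocabulary.from_transformers` can be used on its own; it
copies the tokenizer's token-to-id mapping into the given namespace. The model
name is again illustrative, and the `transformers` package must be installed.

    from allennlp.data import Vocabulary

    vocab = Vocabulary.from_transformers("bert-base-uncased", namespace="tokens")
    # Indices match the tokenizer's ids exactly, e.g. for special tokens:
    print(vocab.get_token_index("[CLS]", namespace="tokens"))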