Skip to content

Commit

Permalink
fix add token to vocab
Browse files Browse the repository at this point in the history
  • Loading branch information
pavanchhatpar committed May 10, 2020
1 parent 8b4ae66 commit a3b65c4
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
3 changes: 2 additions & 1 deletion copynet_tf/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self,
super(Decoder, self).__init__(**kwargs)
self.vocab = vocab
self.logger = logging.getLogger(__name__)
self._copy_index = self.vocab.add_token(copy_token, "target")
self._copy_index = self.vocab.get_token_id(
copy_token, "target")
self._unk_index = self.vocab.get_token_id(
self.vocab._unk_token, "target")
self._start_index = self.vocab.get_token_id(
Expand Down
7 changes: 4 additions & 3 deletions copynet_tf/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,19 @@ def get_token_id(self, token, namespace):
else:
raise ValueError(f"Unknown namespace: {namespace}")

def _add_token(self, token, token2index):
def _add_token(self, token, token2index, index2token):
if token in token2index:
return token2index[token]
i = len(token2index)
token2index[token] = i
index2token[i] = token
return i

def add_token(self, token, namespace):
if namespace == 'source':
return self._add_token(token, self._source)
return self._add_token(token, self._source, self._source_inverse)
elif namespace == 'target':
return self._add_token(token, self._target)
return self._add_token(token, self._target, self._target_inverse)
else:
raise ValueError(f"Unknown namespace: {namespace}")

Expand Down
4 changes: 3 additions & 1 deletion examples/greetings/greetings/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(self,
cfg.TSEQ_LEN,
cfg.VOCAB_SAVE
)
copy_token = "@COPY@"
self.vocab.add_token(copy_token, "target")
self.searcher = BeamSearch(
10, self.vocab.get_token_id(self.vocab._end_token, "target"),
cfg.TSEQ_LEN - 1)
Expand All @@ -39,7 +41,7 @@ def __init__(self,
[target_emb_mat, tf.zeros(target_vocab_size)])
self.decoder = CopyNetDecoder(
self.vocab, self.encoder.get_output_dim(),
self.searcher, self.decoder_output_layer)
self.searcher, self.decoder_output_layer, copy_token=copy_token)
emb_mat = tf.convert_to_tensor(
self.vocab.get_embedding_matrix("source"))
self.source_embedder = FixedEmbedding(
Expand Down
Loading

0 comments on commit a3b65c4

Please sign in to comment.