From 526d3ad0bf17bc657d9100cbcb7a0d8682b10643 Mon Sep 17 00:00:00 2001
From: Gary Yuan
Date: Mon, 24 Aug 2020 18:17:41 +0800
Subject: [PATCH] Using bool for mask and handling possible spaces in prediction

---
 pytorch_pretrained_bert/crf.py                 | 8 ++++----
 pytorch_pretrained_bert/crf2.py                | 8 ++++----
 pytorch_pretrained_bert/modeling_transfo_xl.py | 6 +++---
 wmseg_model.py                                 | 2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pytorch_pretrained_bert/crf.py b/pytorch_pretrained_bert/crf.py
index 002b62a..f1116cd 100755
--- a/pytorch_pretrained_bert/crf.py
+++ b/pytorch_pretrained_bert/crf.py
@@ -97,7 +97,7 @@ def _calculate_PZ(self, feats, mask):
             # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
             mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
-            mask_idx = mask_idx.byte()
+            mask_idx = mask_idx.bool()
             ## effective updated partition part, only keep the partition value of mask value = 1
             masked_cur_partition = cur_partition.masked_select(mask_idx)
             ## let mask_idx broadcastable, to disable warning
@@ -143,7 +143,7 @@ def _viterbi_decode(self, feats, mask):
         partition_history = list()
         # reverse mask (bug for mask = 1- mask, use this as alternative choice)
         # mask = 1 + (-1)*mask
-        mask = (1 - mask.long()).byte()
+        mask = (1 - mask.long()).bool()
         _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
         # only need start from start_tag
         partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size)  # bat_size * to_target_size
@@ -253,7 +253,7 @@ def _score_sentence(self, scores, mask, tags):
         ### need convert tags id to search from 400 positions of scores
         tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
         ## mask transpose to (seq_len, batch_size)
-        tg_energy = tg_energy.masked_select(mask.transpose(1,0).byte())
+        tg_energy = tg_energy.masked_select(mask.transpose(1,0).bool())
 
         # ## calculate the score from START_TAG to first label
         # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
@@ -307,7 +307,7 @@ def _viterbi_decode_nbest(self, feats, mask, nbest):
         partition_history = list()
         ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
         # mask = 1 + (-1)*mask
-        mask = (1 - mask.long()).byte()
+        mask = (1 - mask.long()).bool()
         _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
         # only need start from start_tag
         partition = inivalues[:, START_TAG, :].clone()  # bat_size * to_target_size
diff --git a/pytorch_pretrained_bert/crf2.py b/pytorch_pretrained_bert/crf2.py
index 002b62a..f1116cd 100755
--- a/pytorch_pretrained_bert/crf2.py
+++ b/pytorch_pretrained_bert/crf2.py
@@ -97,7 +97,7 @@ def _calculate_PZ(self, feats, mask):
             # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
             mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
-            mask_idx = mask_idx.byte()
+            mask_idx = mask_idx.bool()
             ## effective updated partition part, only keep the partition value of mask value = 1
             masked_cur_partition = cur_partition.masked_select(mask_idx)
             ## let mask_idx broadcastable, to disable warning
@@ -143,7 +143,7 @@ def _viterbi_decode(self, feats, mask):
         partition_history = list()
         # reverse mask (bug for mask = 1- mask, use this as alternative choice)
         # mask = 1 + (-1)*mask
-        mask = (1 - mask.long()).byte()
+        mask = (1 - mask.long()).bool()
         _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
         # only need start from start_tag
         partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size)  # bat_size * to_target_size
@@ -253,7 +253,7 @@ def _score_sentence(self, scores, mask, tags):
         ### need convert tags id to search from 400 positions of scores
         tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
         ## mask transpose to (seq_len, batch_size)
-        tg_energy = tg_energy.masked_select(mask.transpose(1,0).byte())
+        tg_energy = tg_energy.masked_select(mask.transpose(1,0).bool())
 
         # ## calculate the score from START_TAG to first label
         # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
@@ -307,7 +307,7 @@ def _viterbi_decode_nbest(self, feats, mask, nbest):
         partition_history = list()
         ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
         # mask = 1 + (-1)*mask
-        mask = (1 - mask.long()).byte()
+        mask = (1 - mask.long()).bool()
         _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
         # only need start from start_tag
         partition = inivalues[:, START_TAG, :].clone()  # bat_size * to_target_size
diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py
index 534a111..7977bf2 100755
--- a/pytorch_pretrained_bert/modeling_transfo_xl.py
+++ b/pytorch_pretrained_bert/modeling_transfo_xl.py
@@ -484,7 +484,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
             self.r_w_bias = r_w_bias
 
     def _parallelogram_mask(self, h, w, left=False):
-        mask = torch.ones((h, w)).byte()
+        mask = torch.ones((h, w)).bool()
         m = min(h, w)
         mask[:m,:m] = torch.triu(mask[:m,:m])
         mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
@@ -1184,10 +1184,10 @@ def _forward(self, dec_inp, mems=None):
             else:
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
+                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
         else:
             dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
 
         hids = []
         if self.attn_type == 0: # default
diff --git a/wmseg_model.py b/wmseg_model.py
index 0e65d4b..07d1f31 100755
--- a/wmseg_model.py
+++ b/wmseg_model.py
@@ -259,7 +259,7 @@ def convert_examples_to_features(self, examples):
         tokenizer = self.bert_tokenizer if self.bert_tokenizer is not None else self.zen_tokenizer
 
         for (ex_index, example) in enumerate(examples):
-            textlist = example.text_a.split(' ')
+            textlist = example.text_a.replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').split(' ')
             labellist = example.label
             tokens = []
             labels = []
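
Why .byte() becomes .bool(): since PyTorch 1.2, mask arguments to indexing-style
ops such as masked_select() are expected to have dtype torch.bool; torch.uint8
masks trigger deprecation warnings, and newer releases may reject them outright.
A minimal sketch of the before/after behaviour, assuming nothing beyond stock
PyTorch -- tensor names are illustrative, not taken from crf.py:

import torch

scores = torch.randn(4, 5)                      # e.g. per-tag scores for a batch of 4
keep = torch.tensor([True, True, False, True])  # valid-position mask
mask = keep.view(4, 1).expand(4, 5)             # broadcast to the score shape

# bool mask: the dtype masked_select() expects on PyTorch >= 1.2
selected = scores.masked_select(mask)           # 1-D tensor of the 15 kept values
print(selected.shape)                           # torch.Size([15])

# the old pattern this patch removes, scores.masked_select(mask.byte()),
# warns on 1.2+ and may error on recent releases

The change is behaviour-preserving: the same elements are selected, only the
mask dtype moves to the one the API now requires.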
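Why the replace() chain in wmseg_model.py: prediction input can contain
non-ASCII space characters (e.g. U+00A0 NO-BREAK SPACE or U+2002 EN SPACE),
which text_a.split(' ') would leave glued to their neighbours as one token.
Normalizing them to ASCII spaces first restores the intended tokenization.
A sketch of the idea; the regex below is an illustrative generalization of
the patch's three hard-coded replacements, not the repo's code:

import re

def normalize_spaces(text):
    # collapse any run of Unicode whitespace into one ASCII space
    return re.sub(r'\s+', ' ', text).strip()

raw = 'word1\u00a0word2\u2002word3'     # tokens separated by exotic spaces
print(raw.split(' '))                    # ['word1\xa0word2\u2002word3'] -- one glued token
print(normalize_spaces(raw).split(' '))  # ['word1', 'word2', 'word3']

The patch's explicit replace() calls are narrower and therefore safer when
other whitespace (e.g. tabs or newlines) is meaningful in the input.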