From 61678e702e1b96b907da542b7ee0c4c7a9471aff Mon Sep 17 00:00:00 2001
From: Daniel
Date: Sat, 22 Jan 2022 15:27:41 +0800
Subject: [PATCH] strip the below warning info, it's quite annoying:

Warning: masked_scatter_ received a mask with dtype torch.uint8, this
behavior is now deprecated
---
 main.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index f1b8ae3..bb26e30 100644
--- a/main.py
+++ b/main.py
@@ -205,7 +205,7 @@ def batchify_with_label(input_batch_list, gpu, num_layer, volatile_flag=False):
     word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
     biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
     label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
-    mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
+    mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).bool()
     ### bert seq tensor
     bert_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len+2))).long()
     bert_mask = autograd.Variable(torch.zeros((batch_size, max_seq_len+2))).long()
@@ -217,8 +217,8 @@ def batchify_with_label(input_batch_list, gpu, num_layer, volatile_flag=False):
     gaz_len = [len(gaz_chars[i][0][0][0]) for i in range(batch_size)]
     max_gaz_len = max(gaz_len)
     gaz_chars_tensor = torch.zeros(batch_size, max_seq_len, 4, max_gaz_num, max_gaz_len).long()
-    gaz_mask_tensor = torch.ones(batch_size, max_seq_len, 4, max_gaz_num).byte()
-    gazchar_mask_tensor = torch.ones(batch_size, max_seq_len, 4, max_gaz_num, max_gaz_len).byte()
+    gaz_mask_tensor = torch.ones(batch_size, max_seq_len, 4, max_gaz_num).bool()
+    gazchar_mask_tensor = torch.ones(batch_size, max_seq_len, 4, max_gaz_num, max_gaz_len).bool()
     for b, (seq, bert_id, biseq, label, seqlen, layergaz, gazmask, gazcount, gazchar, gazchar_mask, gaznum, gazlen) in enumerate(zip(words, bert_ids, biwords, labels, word_seq_lengths, layer_gazs, gaz_mask, gaz_count, gaz_chars,
 gazchar_mask, gaz_num, gaz_len)):