-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4a9cbe3
commit 1e116ee
Showing
1 changed file
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from pytorch_transformers import BertForMaskedLM, BertConfig\n", | ||
"from tokenization import MecabBertTokenizer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tokenizer = MecabBertTokenizer(vocab_file='/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/vocab.txt')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"text = '今年の夏は友達と北海道に行きました。'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tokens = ['[CLS]'] + tokenizer.tokenize(text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '北海道', 'に', '行き', 'まし', 'た', '。']" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"tokens" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tokens[7] = '[MASK]'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '[MASK]', 'に', '行き', 'まし', 'た', '。']" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"tokens" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"token_ids = tokenizer.convert_tokens_to_ids(tokens)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[2, 21659, 5, 1431, 9, 13164, 13, 4, 7, 2630, 4110, 10, 8]" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"token_ids" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"token_ids = torch.tensor([token_ids])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor([[ 2, 21659, 5, 1431, 9, 13164, 13, 4, 7, 2630,\n", | ||
" 4110, 10, 8]])" | ||
] | ||
}, | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"token_ids" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"config = BertConfig.from_json_file('/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/bert_config.json')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = BertForMaskedLM(config)" | ||
] | ||
}, | ||
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"_IncompatibleKeys(missing_keys=[], unexpected_keys=['cls.seq_relationship.weight', 'cls.seq_relationship.bias'])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# map_location='cpu' lets the checkpoint load on machines that lack the device it was saved from.\n", | |
"state_dict = torch.load('/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/pytorch_model.bin', map_location='cpu')\n", | |
"# strict=False: the checkpoint contains cls.seq_relationship.* weights that BertForMaskedLM does not define\n", | |
"# (they are reported as unexpected_keys in the result below).\n", | |
"model.load_state_dict(state_dict, strict=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Inference only: a freshly constructed module is in training mode, so dropout would make\n", | |
"# the predictions vary between runs. Switch to eval mode and skip autograd bookkeeping.\n", | |
"model.eval()\n", | |
"with torch.no_grad():\n", | |
"    predictions, = model(token_ids)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Keep only the indices (element 1 of the (values, indices) pair) of the 10 highest scores\n", | |
"# along the last dimension. Indexing instead of unpacking into `_` avoids overwriting\n", | |
"# IPython's last-output variable.\n", | |
"top10_pred_ids = torch.topk(predictions, k=10, dim=2)[1]" | |
] | |
}, | |
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor([[[ 9, 6, 8, 5, 7, 14, 40, 13, 12, 73],\n", | ||
" [21659, 2442, 19791, 5249, 1431, 8936, 7030, 1485, 1353, 738],\n", | ||
" [ 5, 1390, 28443, 71, 9, 7, 6, 75, 28, 1337],\n", | ||
" [ 1431, 2442, 1337, 16045, 1390, 120, 13164, 16832, 288, 51],\n", | ||
" [ 9, 6, 7, 5, 40, 13, 11, 12, 13164, 28],\n", | ||
" [13164, 3582, 7791, 10216, 8328, 5738, 1431, 4429, 1703, 3634],\n", | ||
" [ 13, 12, 705, 6, 5, 9, 14, 1004, 1763, 7791],\n", | ||
" [ 8085, 4470, 473, 3754, 2756, 298, 13502, 1839, 5452, 19455],\n", | ||
" [ 7, 12, 119, 16, 40, 14, 13, 15, 1763, 11957],\n", | ||
" [ 2630, 522, 11957, 15, 20190, 1234, 3487, 12660, 12, 4288],\n", | ||
" [ 4110, 13222, 7025, 15, 11355, 2551, 307, 16, 17365, 2982],\n", | ||
" [ 10, 16, 17, 183, 75, 81, 15, 4110, 1520, 203],\n", | ||
" [ 8, 40, 13, 6, 38, 708, 969, 143, 5, 1989]]])" | ||
] | ||
}, | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"top10_pred_ids" | ||
] | ||
}, | ||
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['[CLS]'] ['は', '、', '。', 'の', 'に', 'が', 'から', 'と', 'で', 'お']\n", | |
"['今年'] ['今年', '冬', '昨年', '最近', '夏', '今度', '今回', '彼女', '私', '今']\n", | |
"['の'] ['の', '春', '##の', 'この', 'は', 'に', '、', 'だ', 'も', '秋']\n", | |
"['夏'] ['夏', '冬', '秋', '夏休み', '春', 'もの', '友達', 'なつ', '間', '中']\n", | |
"['は'] ['は', '、', 'に', 'の', 'から', 'と', 'を', 'で', '友達', 'も']\n", | |
"['友達'] ['友達', '友人', 'みんな', '同級生', '先輩', '皆', '夏', '妹', '友', '仲間']\n", | |
"['と'] ['と', 'で', 'だけ', '、', 'の', 'は', 'が', 'と共に', 'よく', 'みんな']\n", | |
"['[MASK]'] ['遊び', '一緒', '学校', '旅行', '食べ', '海', '会い', '公園', '過ごし', '買い物']\n", | |
"['に'] ['に', 'で', 'へ', 'て', 'から', 'が', 'と', 'し', 'よく', '帰り']\n", | |
"['行き'] ['行き', '行っ', '帰り', 'し', '行け', '来', '行く', '行か', 'で', 'いき']\n", | |
"['まし'] ['まし', 'でし', 'ませ', 'し', 'たかっ', 'ます', 'だっ', 'て', 'ましょ', 'です']\n", | |
"['た'] ['た', 'て', 'な', 'つ', 'だ', 'ない', 'し', 'まし', 'たい', 'う']\n", | |
"['。'] ['。', 'から', 'と', '、', '」', '!', 'ので', 'という', 'の', 'ね']\n" | |
] | |
} | |
], | |
"source": [ | |
"# For every input position, print the input token next to the model's 10 best candidates.\n", | |
"input_ids = token_ids[0].tolist()\n", | |
"candidate_id_rows = top10_pred_ids[0].tolist()\n", | |
"for input_id, candidate_ids in zip(input_ids, candidate_id_rows):\n", | |
"    input_token = tokenizer.convert_ids_to_tokens([input_id])\n", | |
"    candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_ids)\n", | |
"    print(input_token, candidate_tokens)" | |
] | |
} | |
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |