Skip to content

Commit

Permalink
Add a file
Browse files Browse the repository at this point in the history
  • Loading branch information
singletongue committed Sep 2, 2019
1 parent 4a9cbe3 commit 1e116ee
Showing 1 changed file with 292 additions and 0 deletions.
292 changes: 292 additions & 0 deletions masked_lm_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pytorch_transformers import BertForMaskedLM, BertConfig\n",
"from tokenization import MecabBertTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = MecabBertTokenizer(vocab_file='/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/vocab.txt')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"text = '今年の夏は友達と北海道に行きました。'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"tokens = ['[CLS]'] + tokenizer.tokenize(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '北海道', 'に', '行き', 'まし', 'た', '。']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"tokens[7] = '[MASK]'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '[MASK]', 'に', '行き', 'まし', 'た', '。']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"token_ids = tokenizer.convert_tokens_to_ids(tokens)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2, 21659, 5, 1431, 9, 13164, 13, 4, 7, 2630, 4110, 10, 8]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_ids"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"token_ids = torch.tensor([token_ids])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2, 21659, 5, 1431, 9, 13164, 13, 4, 7, 2630,\n",
" 4110, 10, 8]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_ids"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"config = BertConfig.from_json_file('/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/bert_config.json')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model = BertForMaskedLM(config)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"_IncompatibleKeys(missing_keys=[], unexpected_keys=['cls.seq_relationship.weight', 'cls.seq_relationship.bias'])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/pytorch_model.bin'), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"predictions, = model(token_ids)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"_, top10_pred_ids = torch.topk(predictions, k=10, dim=2)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 9, 6, 8, 5, 7, 14, 40, 13, 12, 73],\n",
" [21659, 2442, 19791, 5249, 1431, 8936, 7030, 1485, 1353, 738],\n",
" [ 5, 1390, 28443, 71, 9, 7, 6, 75, 28, 1337],\n",
" [ 1431, 2442, 1337, 16045, 1390, 120, 13164, 16832, 288, 51],\n",
" [ 9, 6, 7, 5, 40, 13, 11, 12, 13164, 28],\n",
" [13164, 3582, 7791, 10216, 8328, 5738, 1431, 4429, 1703, 3634],\n",
" [ 13, 12, 705, 6, 5, 9, 14, 1004, 1763, 7791],\n",
" [ 8085, 4470, 473, 3754, 2756, 298, 13502, 1839, 5452, 19455],\n",
" [ 7, 12, 119, 16, 40, 14, 13, 15, 1763, 11957],\n",
" [ 2630, 522, 11957, 15, 20190, 1234, 3487, 12660, 12, 4288],\n",
" [ 4110, 13222, 7025, 15, 11355, 2551, 307, 16, 17365, 2982],\n",
" [ 10, 16, 17, 183, 75, 81, 15, 4110, 1520, 203],\n",
" [ 8, 40, 13, 6, 38, 708, 969, 143, 5, 1989]]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top10_pred_ids"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['[CLS]'] ['は', '、', '。', 'の', 'に', 'が', 'から', 'と', 'で', 'お']\n",
"['今年'] ['今年', '冬', '昨年', '最近', '夏', '今度', '今回', '彼女', '私', '今']\n",
"['の'] ['の', '春', '##の', 'この', 'は', 'に', '、', 'だ', 'も', '秋']\n",
"['夏'] ['夏', '冬', '秋', '夏休み', '春', 'もの', '友達', 'なつ', '間', '中']\n",
"['は'] ['は', '、', 'に', 'の', 'から', 'と', 'を', 'で', '友達', 'も']\n",
"['友達'] ['友達', '友人', 'みんな', '同級生', '先輩', '皆', '夏', '妹', '友', '仲間']\n",
"['と'] ['と', 'で', 'だけ', '、', 'の', 'は', 'が', 'と共に', 'よく', 'みんな']\n",
"['[MASK]'] ['遊び', '一緒', '学校', '旅行', '食べ', '海', '会い', '公園', '過ごし', '買い物']\n",
"['に'] ['に', 'で', 'へ', 'て', 'から', 'が', 'と', 'し', 'よく', '帰り']\n",
"['行き'] ['行き', '行っ', '帰り', 'し', '行け', '来', '行く', '行か', 'で', 'いき']\n",
"['まし'] ['まし', 'でし', 'ませ', 'し', 'たかっ', 'ます', 'だっ', 'て', 'ましょ', 'です']\n",
"['た'] ['た', 'て', 'な', 'つ', 'だ', 'ない', 'し', 'まし', 'たい', 'う']\n",
"['。'] ['。', 'から', 'と', '、', '」', '!', 'ので', 'という', 'の', 'ね']\n"
]
}
],
"source": [
"for correct_id, pred_ids in zip(token_ids[0], top10_pred_ids[0]):\n",
" correct_token = tokenizer.convert_ids_to_tokens([correct_id.item()])\n",
" pred_tokens = tokenizer.convert_ids_to_tokens(pred_ids.tolist())\n",
" print(correct_token, pred_tokens)\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 1e116ee

Please sign in to comment.