forked from liu-nlper/NER-LSTM-CRF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
73 lines (62 loc) · 2.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""Train the NER (named entity recognition) model.

Running this script reads ``./config.yml`` and trains a
``SequenceLabelingModel`` on the training data referenced by that config.
"""
__author__ = '[email protected]'
import yaml
import pickle
from load_data import load_vocs, init_data
from model import SequenceLabelingModel
def _build_embed_param_dicts(feature_names, embed_params):
    """Collect per-feature embedding shapes, dropout rates and pre-trained weights.

    Args:
        feature_names: names of the input features (``model_params.feature_names``
            in the config).
        embed_params: the ``model_params.embed_params`` mapping. For every
            feature it must provide ``shape`` and ``dropout_rate``; ``path``
            optionally names a pickled pre-trained embedding.

    Returns:
        A ``(shape_dict, dropout_dict, init_weight_dict)`` tuple keyed by
        feature name. ``init_weight_dict`` only holds features whose ``path``
        is set to a non-empty value.
    """
    shape_dict, dropout_dict, init_weight_dict = {}, {}, {}
    for feature_name in feature_names:
        params = embed_params[feature_name]
        shape_dict[feature_name] = params['shape']
        dropout_dict[feature_name] = params['dropout_rate']
        path_pre_train = params.get('path')
        if path_pre_train:
            # NOTE(security): pickle.load can execute arbitrary code while
            # unpickling; only point `path` at embedding files you trust.
            with open(path_pre_train, 'rb') as file_r:
                init_weight_dict[feature_name] = pickle.load(file_r)
    return shape_dict, dropout_dict, init_weight_dict


def _resolve_separator(sep_str):
    """Map the config's ``data_params.sep`` name to the column separator.

    ``'table'`` selects a tab character and ``'space'`` a single space.

    Raises:
        ValueError: if ``sep_str`` is neither ``'table'`` nor ``'space'``.
    """
    # An explicit exception instead of `assert`, which is stripped under -O.
    if sep_str == 'table':
        return '\t'
    if sep_str == 'space':
        return ' '
    raise ValueError(
        "data_params.sep must be 'table' or 'space', got {!r}".format(sep_str))


def main(path_config='./config.yml'):
    """Train the sequence-labeling (NER) model described by a YAML config.

    Reads the config, builds the per-feature embedding settings (loading any
    pre-trained embeddings), loads the vocabularies and the training data,
    then constructs ``SequenceLabelingModel`` and calls its ``fit``.

    Args:
        path_config: path to the YAML config file. Defaults to
            ``./config.yml``, relative to the current working directory.

    Raises:
        ValueError: if ``data_params.sep`` is neither ``'table'`` nor ``'space'``.
    """
    # Load the config. safe_load refuses arbitrary Python object construction
    # and, unlike a bare yaml.load(stream), still works on PyYAML >= 6 where
    # the Loader argument is mandatory. Binary mode lets PyYAML detect the
    # UTF-8/UTF-16 encoding itself instead of using the locale default.
    with open(path_config, 'rb') as file_config:
        config = yaml.safe_load(file_config)
    model_params = config['model_params']
    feature_names = model_params['feature_names']

    # Embedding shapes and dropout rates; pre-trained embeddings are loaded here too.
    (feature_weight_shape_dict, feature_weight_dropout_dict,
     feature_init_weight_dict) = _build_embed_param_dicts(
        feature_names, model_params['embed_params'])

    # Vocabulary paths: one per feature, followed by the label vocabulary.
    data_params = config['data_params']
    voc_params = data_params['voc_params']
    path_vocs = [voc_params[feature_name]['path'] for feature_name in feature_names]
    path_vocs.append(voc_params['label']['path'])
    vocs = load_vocs(path_vocs)

    # Training data.
    sep = _resolve_separator(data_params['sep'])
    data_dict = init_data(
        path=data_params['path_train'], feature_names=feature_names, sep=sep,
        vocs=vocs, max_len=model_params['sequence_length'], model='train')

    # Build and train the model.
    model = SequenceLabelingModel(
        sequence_length=model_params['sequence_length'],
        nb_classes=model_params['nb_classes'],
        nb_hidden=model_params['bilstm_params']['num_units'],
        feature_weight_shape_dict=feature_weight_shape_dict,
        feature_init_weight_dict=feature_init_weight_dict,
        feature_weight_dropout_dict=feature_weight_dropout_dict,
        dropout_rate=model_params['dropout_rate'],
        nb_epoch=model_params['nb_epoch'], feature_names=feature_names,
        batch_size=model_params['batch_size'],
        train_max_patience=model_params['max_patience'],
        use_crf=model_params['use_crf'],
        l2_rate=model_params['l2_rate'],
        rnn_unit=model_params['rnn_unit'],
        learning_rate=model_params['learning_rate'],
        path_model=model_params['path_model'])
    model.fit(data_dict=data_dict, dev_size=model_params['dev_size'])
if __name__ == '__main__':
    # Start training only when run as a script; importing the module is side-effect free.
    main()