# coding: utf-8
# NBC.py: Naive Bayes classification of Chinese text with jieba + nltk/sklearn
import os
import random

import jieba
import nltk
from sklearn.naive_bayes import MultinomialNB
import matplotlib.pyplot as plt
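# The imports above assume the usual PyPI package names; a typical setup is:
#   pip install jieba nltk scikit-learn matplotlib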

def MakeWordsSet(words_file):
    """Read a word list (one word per line, UTF-8) into a set."""
    words_set = set()
    with open(words_file, 'r', encoding='utf-8') as fp:
        for line in fp:
            word = line.strip()
            if len(word) > 0 and word not in words_set:  # deduplicate
                words_set.add(word)
    return words_set
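# Minimal usage sketch (reuses the stopword file path from __main__ below):
#
#   stopwords_set = MakeWordsSet('./stopwords_cn.txt')
#   print(len(stopwords_set))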

def TextProcessing(folder_path, test_size=0.2):
    """Segment every sample under folder_path, split train/test, and return
    the frequency-ranked vocabulary plus the four data/label lists."""
    folder_list = os.listdir(folder_path)
    data_list = []
    class_list = []
    # outer loop: one subfolder per class
    for folder in folder_list:
        new_folder_path = os.path.join(folder_path, folder)
        files = os.listdir(new_folder_path)
        # inner loop: samples within a class
        j = 1
        for file in files:
            if j > 100:  # at most 100 text samples per class
                break
            with open(os.path.join(new_folder_path, file), 'r', encoding='utf-8') as fp:
                raw = fp.read()
            ## --------------------------------------------------------------------------------
            ## jieba word segmentation
            # jieba.enable_parallel(4)  # parallel mode (argument = number of processes); not supported on Windows
            word_cut = jieba.cut(raw, cut_all=False)  # accurate mode; returns an iterable generator
            word_list = list(word_cut)  # materialize the generator as a list of words
            # jieba.disable_parallel()  # switch parallel mode off again
            ## --------------------------------------------------------------------------------
            data_list.append(word_list)
            class_list.append(folder)
            j += 1

    ## split into training and test sets
    # alternative: sklearn.model_selection.train_test_split(data_list, class_list, test_size=test_size)
    data_class_list = list(zip(data_list, class_list))
    random.shuffle(data_class_list)
    index = int(len(data_class_list) * test_size) + 1
    train_list = data_class_list[index:]
    test_list = data_class_list[:index]
    train_data_list, train_class_list = zip(*train_list)
    test_data_list, test_class_list = zip(*test_list)

    # count word frequencies into all_words_dict
    all_words_dict = {}
    for word_list in train_data_list:
        for word in word_list:
            all_words_dict[word] = all_words_dict.get(word, 0) + 1

    # sort by frequency in descending order (sorted needs a list of items)
    all_words_tuple_list = sorted(all_words_dict.items(), key=lambda f: f[1], reverse=True)
    all_words_list = [word for word, _ in all_words_tuple_list]
    return all_words_list, train_data_list, test_data_list, train_class_list, test_class_list
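# The manual shuffle-and-slice split above can also be delegated to
# scikit-learn, as the commented line inside TextProcessing suggests; a
# sketch assuming sklearn >= 0.18, where cross_validation was renamed to
# model_selection:
#
#   from sklearn.model_selection import train_test_split
#   train_data_list, test_data_list, train_class_list, test_class_list = \
#       train_test_split(data_list, class_list, test_size=test_size)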

def words_dict(all_words_list, deleteN, stopwords_set=set()):
    # select feature words: skip the deleteN most frequent words, then keep
    # non-digit, non-stopword tokens of length 2-4
    feature_words = []
    n = 1
    for t in range(deleteN, len(all_words_list), 1):
        if n > 1000:  # cap feature_words at 1000 dimensions
            break
        # print(all_words_list[t])
        if not all_words_list[t].isdigit() and all_words_list[t] not in stopwords_set and 1 < len(all_words_list[t]) < 5:
            feature_words.append(all_words_list[t])
            n += 1
    return feature_words
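# Toy illustration of the deleteN head-trimming (hypothetical word list, not
# from the real corpus):
#
#   vocab = ['我们', '12', '经济', '发展']
#   words_dict(vocab, deleteN=1)  # -> ['经济', '发展']
#   # '我们' is cut off by deleteN=1; '12' is rejected by isdigit()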

def TextFeatures(train_data_list, test_data_list, feature_words, flag='nltk'):
    def text_features(text, feature_words):
        text_words = set(text)
        ## -----------------------------------------------------------------------------------
        if flag == 'nltk':
            ## nltk features: dict of word -> 0/1
            features = {word: 1 if word in text_words else 0 for word in feature_words}
        elif flag == 'sklearn':
            ## sklearn features: 0/1 list in feature_words order
            features = [1 if word in text_words else 0 for word in feature_words]
        else:
            features = []
        ## -----------------------------------------------------------------------------------
        return features
    train_feature_list = [text_features(text, feature_words) for text in train_data_list]
    test_feature_list = [text_features(text, feature_words) for text in test_data_list]
    return train_feature_list, test_feature_list
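# Shape sketch for one document under each backend (toy values; the inner
# text_features helper is what builds these):
#
#   feature_words = ['经济', '体育']
#   # the document ['经济', '增长'] becomes
#   #   flag='nltk'    -> {'经济': 1, '体育': 0}   (dict for nltk classifiers)
#   #   flag='sklearn' -> [1, 0]                   (row of the design matrix)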

def TextClassifier(train_feature_list, test_feature_list, train_class_list, test_class_list, flag='nltk'):
    ## -----------------------------------------------------------------------------------
    if flag == 'nltk':
        ## nltk classifier
        train_flist = list(zip(train_feature_list, train_class_list))
        test_flist = list(zip(test_feature_list, test_class_list))
        classifier = nltk.classify.NaiveBayesClassifier.train(train_flist)
        # print(classifier.classify_many(test_feature_list))
        # for test_feature in test_feature_list:
        #     print(classifier.classify(test_feature), end=' ')
        # print('')
        test_accuracy = nltk.classify.accuracy(classifier, test_flist)
    elif flag == 'sklearn':
        ## sklearn classifier
        classifier = MultinomialNB().fit(train_feature_list, train_class_list)
        # print(classifier.predict(test_feature_list))
        # for test_feature in test_feature_list:
        #     print(classifier.predict([test_feature])[0], end=' ')
        # print('')
        test_accuracy = classifier.score(test_feature_list, test_class_list)
    else:
        test_accuracy = []
    return test_accuracy
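# Standalone sketch of the sklearn path, assuming the 0/1 rows produced by
# TextFeatures with flag='sklearn' (toy two-class data, hypothetical labels):
#
#   clf = MultinomialNB().fit([[1, 0], [0, 1]], ['econ', 'sport'])
#   clf.predict([[1, 0]])                          # -> array(['econ'], ...)
#   clf.score([[1, 0], [0, 1]], ['econ', 'sport']) # -> 1.0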

if __name__ == '__main__':
    print("start")

    ## text preprocessing
    folder_path = './Database/SogouC/Sample'
    all_words_list, train_data_list, test_data_list, train_class_list, test_class_list = TextProcessing(folder_path, test_size=0.2)

    # build stopwords_set
    stopwords_file = './stopwords_cn.txt'
    stopwords_set = MakeWordsSet(stopwords_file)

    ## feature extraction and classification
    # flag = 'nltk'
    flag = 'sklearn'
    deleteNs = range(0, 1000, 20)
    test_accuracy_list = []
    for deleteN in deleteNs:
        # feature_words = words_dict(all_words_list, deleteN)
        feature_words = words_dict(all_words_list, deleteN, stopwords_set)
        train_feature_list, test_feature_list = TextFeatures(train_data_list, test_data_list, feature_words, flag)
        test_accuracy = TextClassifier(train_feature_list, test_feature_list, train_class_list, test_class_list, flag)
        test_accuracy_list.append(test_accuracy)
    print(test_accuracy_list)

    # evaluate: plot test accuracy against deleteN
    plt.figure()
    plt.plot(deleteNs, test_accuracy_list)
    plt.title('Relationship of deleteNs and test_accuracy')
    plt.xlabel('deleteNs')
    plt.ylabel('test_accuracy')
    plt.savefig('result.png')
    print("finished")