-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathmodel.py
123 lines (97 loc) · 4.25 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from collections import OrderedDict
import numpy as np
from tqdm import tqdm
from paddlehub.common.logger import logger
from slda_weibo.vocab import Vocab, WordCount
class TopicModel(object):
    """Storage structure of a topic model: vocabulary plus word-topic counts.

    Both the vocabulary file and the word-topic file are read from
    ``model_dir`` when the object is constructed; the accessor methods then
    expose the loaded parameters.
    """

    def __init__(self, model_dir, config):
        """Initialise the hyper-parameters from ``config`` and load the model.

        Args:
            model_dir: the path of model directory.
            config: ModelConfig class. The attributes read here are
                ``num_topics``, ``alpha``, ``beta``, ``type``,
                ``word_topic_file`` and ``vocab_file``.
        """
        self.__word_topic = None  # Model parameter of word topic.
        self.__vocab = Vocab()  # Vocab data structure of model.
        self.__num_topics = config.num_topics  # Number of topics.
        self.__alpha = config.alpha
        self.__alpha_sum = self.__alpha * self.__num_topics
        self.__beta = config.beta
        self.__beta_sum = None  # Set in load_model once the vocab size is known.
        self.__type = config.type  # Model type.
        # Accumulated count of each topic over the whole word-topic table.
        self.__topic_sum = np.zeros(self.__num_topics, dtype="int64")
        # One list per topic holding a WordCount(term_id, count) per table entry.
        self.__topic_words = [[] for _ in range(self.__num_topics)]

        word_topic_path = os.path.join(model_dir, config.word_topic_file)
        vocab_path = os.path.join(model_dir, config.vocab_file)
        self.load_model(word_topic_path, vocab_path)

    def term_id(self, term):
        """Return the id the vocabulary assigns to ``term``."""
        return self.__vocab.get_id(term)

    def load_model(self, word_topic_path, vocab_path):
        """Load the vocabulary, then the word-topic counts.

        Args:
            word_topic_path: path of the word-topic count file.
            vocab_path: path of the vocabulary file.
        """
        # Loading vocabulary
        self.__vocab.load(vocab_path)

        self.__beta_sum = self.__beta * self.__vocab.size()
        # List of dicts: one {topic_id: count} mapping per vocabulary entry.
        self.__word_topic = [{} for _ in range(self.__vocab.size())]
        self.__load_word_dict(word_topic_path)
        logger.info("Model Info: #num_topics=%d #vocab_size=%d alpha=%f beta=%f" %
                    (self.num_topics(), self.vocab_size(), self.alpha(), self.beta()))

    def word_topic_value(self, word_id, topic_id):
        """Return value of specific word under specific topic in the model.

        Returns 0 when the word has no entry for ``topic_id``.
        """
        return self.__word_topic[word_id].get(topic_id, 0)

    def word_topic(self, term_id):
        """Return the topic distribution of a word.

        The result is the internal ``{topic_id: count}`` mapping itself, not a
        copy.
        """
        return self.__word_topic[term_id]

    def topic_sum_value(self, topic_id):
        """Return the accumulated count of ``topic_id``."""
        return self.__topic_sum[topic_id]

    def topic_sum(self):
        """Return the array of accumulated counts, indexed by topic id."""
        return self.__topic_sum

    def num_topics(self):
        """Return the number of topics."""
        return self.__num_topics

    def vocab_size(self):
        """Return the size reported by the vocabulary."""
        return self.__vocab.size()

    def alpha(self):
        """Return ``config.alpha`` as given at construction."""
        return self.__alpha

    def alpha_sum(self):
        """Return ``alpha * num_topics``."""
        return self.__alpha_sum

    def beta(self):
        """Return ``config.beta`` as given at construction."""
        return self.__beta

    def beta_sum(self):
        """Return ``beta * vocab_size`` (``None`` until the vocabulary is loaded)."""
        return self.__beta_sum

    def type(self):
        """Return the model type taken from ``config.type``."""
        return self.__type

    @staticmethod
    def _parse_word_topic_line(line, vocab_size, num_topics):
        """Parse and validate one line of the word-topic file.

        A line has the form ``term_id topic_id:count topic_id:count ...`` with
        fields separated by single spaces.

        Args:
            line: raw line as read from the file.
            vocab_size: exclusive upper bound of a valid term id.
            num_topics: exclusive upper bound of a valid topic id.

        Returns:
            ``None`` for a blank line, otherwise ``(term_id, pairs)`` where
            ``pairs`` is the list of ``(topic_id, count)`` tuples in file order.

        Raises:
            AssertionError: if an id is out of range, a count is negative or a
                field is not of the form ``topic_id:count``. It is raised
                explicitly rather than through ``assert`` so the checks still
                run under ``python -O``; the exception type is kept for
                backward compatibility.
            ValueError: if a numeric field is not an integer literal.
        """
        stripped = line.strip()
        if not stripped:
            # Blank lines (e.g. a trailing newline at EOF) carry no data.
            return None

        fields = stripped.split(" ")
        term_id = int(fields[0])
        if term_id < 0 or term_id >= vocab_size:
            raise AssertionError("Term id out of range!")

        pairs = []
        for field in fields[1:]:
            topic_count = field.split(":")
            if len(topic_count) != 2:
                raise AssertionError("Topic count format error!")
            topic_id = int(topic_count[0])
            if topic_id < 0 or topic_id >= num_topics:
                raise AssertionError("Topic out of range!")
            count = int(topic_count[1])
            if count < 0:
                raise AssertionError("Topic count error!")
            pairs.append((topic_id, count))
        return term_id, pairs

    def __load_word_dict(self, word_dict_path):
        """Load the word topic parameters.

        Fills the per-word topic counts, the per-topic totals and the per-topic
        word lists from the file at ``word_dict_path``. Each word's mapping is
        stored as an ``OrderedDict`` sorted by topic id.

        Args:
            word_dict_path: path of the word-topic count file (UTF-8 text).

        Raises:
            AssertionError: if a line fails validation.
            ValueError: if a numeric field is not an integer literal.
        """
        logger.info("Loading word topic.")
        vocab_size = self.vocab_size()
        with open(word_dict_path, 'r', encoding='utf-8') as f:
            # A list (rather than the file iterator) has a length, which tqdm
            # can use as the total of the progress bar.
            for line in tqdm(f.readlines()):
                parsed = self._parse_word_topic_line(line, vocab_size, self.__num_topics)
                if parsed is None:
                    continue
                term_id, pairs = parsed

                counts = self.__word_topic[term_id]
                for topic_id, count in pairs:
                    counts[topic_id] = count
                    self.__topic_sum[topic_id] += count
                    self.__topic_words[topic_id].append(WordCount(term_id, count))
                # Keep the topics of every word ordered by topic id.
                self.__word_topic[term_id] = OrderedDict(sorted(counts.items()))

    def get_vocab(self):
        """Return whatever the vocabulary object exposes as its vocabulary."""
        return self.__vocab.vocabulary()

    def topic_words(self):
        """Return the per-topic lists of ``WordCount`` entries."""
        return self.__topic_words