diff --git a/searcher.py b/searcher.py index 6c31190..381bd44 100644 --- a/searcher.py +++ b/searcher.py @@ -18,6 +18,9 @@ class Searcher: SHOW_ANSWER = 10 + MAGIC_NUMBER = 20 + IDF_BOUND = 1 + COS_BOUND = 0.1 def __init__(self, db_name): self.con = sqlite.connect(db_name) @@ -33,19 +36,61 @@ def separate_words(text): splitter = re.compile('\\W*') return [s.lower() for s in splitter.split(text) if s != ''] - @staticmethod - def get_top_words(words, n): - '''Return top n words in text - Return list of pairs (word, num of repetition)''' + + def get_top_words(self, words, n): + '''Return top n tf * idf words in text + Return list of words ''' words_top = {word: 0 for word in words} for word in words: words_top[word] += 1 + for word in words_top: + word_idf = self.con.execute( + "select idf from word_list where word = '%s'" % word).fetchone() + if word_idf is None: + words_top[word] = 0 + else: + word_idf = word_idf[0] + if word_idf > self.IDF_BOUND: + words_top[word] = words_top[word] * word_idf + else: + words_top[word] = 0 + + words_top = {word: words_top[word] for word in words_top if words_top[word] > 0} + sorted_top = sorted(words_top.iteritems(), key=operator.itemgetter(1), reverse=True) - if len(sorted_top) <= n: - return sorted_top + clear_list = [pair[0] for pair in sorted_top] + if len(clear_list) <= n: + return clear_list else: - return sorted_top[0: n] + return clear_list[0: n] + + + def find_rows(self, words): + '''Find documents which contain one of words''' + if len(words) == 0: + return [] + + word_id_list = [] + table_num = 0 + clause_list = '' + + for word in words: + word_row = self.con.execute( + "select rowid from word_list where word = '%s'" % word).fetchone() + if word_row is not None: + word_id = word_row[0] + #print "word_id: %d" % word_id + word_id_list.append(word_id) + if table_num > 0: + clause_list += ' or ' + clause_list += 'word_id = %d' % word_id + table_num += 1 + + query = 'select distinct url_id from word_location where %s ' % clause_list + result = self.con.execute(query) + rows = [row for row in result] + return rows def tf(self, words): @@ -133,15 +178,21 @@ def cos_search(self, text, n): text_words_tf_idf = self.get_top_tf_idf(text_words, n) text_length = Searcher.count_length(text_words_tf_idf) - url_ids = self.con.execute("select rowid from url_list").fetchall() - url_ids = [url_id[0] for url_id in url_ids] + top_text_words = self.get_top_words(text_words, self.MAGIC_NUMBER) + url_ids = self.find_rows(top_text_words) + url_ids = [url_id[0] for url_id in url_ids] url_count = len(url_ids) - print "Number of documents is %d" % url_count + + url_full_count = self.con.execute("select count(rowid) from url_list").fetchone()[0] + print "Number of documents: %d " % url_full_count + + print "Number of documents after cutting: %d " % url_count print "Searching..." heap = [] for url_id in url_ids: + #print url_id url_words = self.con.execute("select word from word_list join word_location on " " word_list.rowid = word_location.word_id where " " word_location.url_id = %s" % url_id).fetchall() @@ -151,6 +202,9 @@ def cos_search(self, text, n): url_length = self.con.execute("select length from url_list where rowid = %d" % url_id).fetchone()[0] url_cos = Searcher.cos_distance(url_words_tf_idf, url_length, text_words_tf_idf, text_length) + if url_cos < self.COS_BOUND: + continue + if len(heap) < self.SHOW_ANSWER: heapq.heappush(heap, (url_cos, url_id)) else: