From e2eda8f13b872ad8b4b556c930abc9c4a2dd09fa Mon Sep 17 00:00:00 2001 From: rob Date: Mon, 1 Apr 2013 10:37:58 +0100 Subject: [PATCH] init --- db.py | 42 ++++++++++++++++++++++++++++++++++++++++++ gen.py | 33 +++++++++++++++++++++++++++++++++ markov.py | 24 ++++++++++++++++++++++++ parse.py | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+) create mode 100644 db.py create mode 100644 gen.py create mode 100644 markov.py create mode 100644 parse.py diff --git a/db.py b/db.py new file mode 100644 index 0000000..143785c --- /dev/null +++ b/db.py @@ -0,0 +1,42 @@ +import sqlite3 + +class Db: + def __init__(self, name): + self.name = name + self.conn = sqlite3.connect(name + '.db') + c = self.conn.cursor() + c.execute('CREATE TABLE IF NOT EXISTS words (word, next_word, count)') + c.execute('CREATE INDEX IF NOT EXISTS i_word ON words (word,next_word)') + + def get_word_pair_count(self, word, next_word): + c = self.conn.cursor() + c.execute('select count from words where word=? and next_word=?' , (word, next_word)) + r = c.fetchone() + if r: + return r[0] + else: + return 0 + + def add_word(self, word, next_word): + count = self.get_word_pair_count(word, next_word) + c = self.conn.cursor() + if count: + c.execute('UPDATE words SET count=? WHERE word=? AND next_word=?', (count + 1, word, next_word)) + else: + c.execute('INSERT INTO words (word, next_word, count) VALUES (?,?,?)', (word, next_word, 1)) + + def commit(self): + self.conn.commit() + + def get_word_count(self, word): + c = self.conn.cursor() + counts = {} + for row in c.execute('SELECT next_word, count FROM words WHERE word=?', (word,)): + counts[row[0]] = row[1] + + return counts + + def reset(self): + c = self.conn.cursor() + c.execute('delete from words') + self.conn.commit() diff --git a/gen.py b/gen.py new file mode 100644 index 0000000..1249584 --- /dev/null +++ b/gen.py @@ -0,0 +1,33 @@ +import sqlite3 +from parse import Parser +from random import randint + +class Generator: + def __init__(self, name, db): + self.name = name + self.db = db + + def get_next_word(self, word): + candidate_words = self.db.get_word_count(word) + total_next_words = sum(candidate_words.values()) + i = randint(1, total_next_words) + t=0 + for w in candidate_words.keys(): + t += candidate_words[w] + if (i <= t): + return w + assert False + + def make_sentence(self): + word = self.get_next_word(Parser.SENTENCE_START_SYMBOL) + sentence = [] + + while word != Parser.SENTENCE_END_SYMBOL: + sentence.append(word) + word = self.get_next_word(word) + + return ' '.join(sentence) + + def generate(self, count): + for i in range(0, count): + print self.make_sentence() \ No newline at end of file diff --git a/markov.py b/markov.py new file mode 100644 index 0000000..75bf459 --- /dev/null +++ b/markov.py @@ -0,0 +1,24 @@ +from db import Db +from gen import Generator +from parse import Parser +import sys + + +if __name__ == '__main__': + args = sys.argv + usage = 'Usage: %s (parse |gen )' % (args[0], ) + + if (len(args) != 4): + raise ValueError(usage) + + mode = args[1] + name = args[2] + db = Db(name) + if mode == 'parse': + file_name = args[3] + Parser(name, db).parse(file_name) + elif mode == 'gen': + count = int(args[3]) + Generator(name, db).generate(count) + else: + raise ValueError(usage) \ No newline at end of file diff --git a/parse.py b/parse.py new file mode 100644 index 0000000..ed59cf6 --- /dev/null +++ b/parse.py @@ -0,0 +1,36 @@ +import sqlite3 +import codecs +import sys + +class Parser: + SENTENCE_START_SYMBOL = '^' + SENTENCE_END_SYMBOL = '$' + + def __init__(self, name, db): + self.name = name + self.db = db + + def save_word_pair(self, word1, word2): + self.db.add_word(word1, word2) + + def parse(self, file_name): + txt = codecs.open(file_name, 'r', 'utf-8').read() + sentences = txt.split('\n') + i = 0 + + for sentence in sentences: + words = sentence.split() + prev_word = Parser.SENTENCE_START_SYMBOL + + for word in words: + self.save_word_pair(prev_word, word) + prev_word = word + + self.save_word_pair(prev_word, Parser.SENTENCE_END_SYMBOL) + self.db.commit() + i += 1 + if i % 1000 == 0: + print i + sys.stdout.flush() + +