diff --git a/Makefile b/Makefile index 53e5a13d1..d76cfc462 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ MALLET_DIR = $(shell pwd) JAVAC = javac JAVA_FLAGS = \ --classpath "$(MALLET_DIR)/class:$(MALLET_DIR)/lib/mallet-deps.jar:$(MALLET_DIR)/lib/jdom-1.0.jar:$(MALLET_DIR)/lib/grmm-deps.jar:$(MALLET_DIR)/lib/weka.jar " \ +-classpath "$(MALLET_DIR)/lib/wordnet.jar:$(MALLET_DIR)/class:$(MALLET_DIR)/lib/mallet-deps.jar:$(MALLET_DIR)/lib/jdom-1.0.jar:$(MALLET_DIR)/lib/grmm-deps.jar:$(MALLET_DIR)/lib/weka.jar " \ -sourcepath "$(MALLET_DIR)/src" \ -g:lines,vars,source \ -d $(MALLET_DIR)/class \ diff --git a/lib/.DS_Store b/lib/.DS_Store new file mode 100644 index 000000000..5008ddfcf Binary files /dev/null and b/lib/.DS_Store differ diff --git a/lib/wordnet.jar b/lib/wordnet.jar new file mode 100755 index 000000000..51f6c9d30 Binary files /dev/null and b/lib/wordnet.jar differ diff --git a/src/cc/mallet/topics/tree/CorpusWriter.java b/src/cc/mallet/topics/tree/CorpusWriter.java new file mode 100644 index 000000000..47888fe8a --- /dev/null +++ b/src/cc/mallet/topics/tree/CorpusWriter.java @@ -0,0 +1,157 @@ +package cc.mallet.topics.tree; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.util.ArrayList; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; + +public class CorpusWriter { + + public static void writeCorpus(InstanceList training, String outfilename, String vocabname) throws FileNotFoundException { + + ArrayList vocab = loadVocab(vocabname); + + PrintStream out = new PrintStream (new File(outfilename)); + + int count = -1; + for (Instance instance : training) { + count++; + if (count % 1000 == 0) { + System.out.println("Processed " + count + " number of documents!"); + } + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + + TIntArrayList tokens = new TIntArrayList(original_tokens.getLength()); + TIntIntHashMap topicCounts = new TIntIntHashMap (); + TIntArrayList topics = new TIntArrayList(original_tokens.getLength()); + TIntArrayList paths = new TIntArrayList(original_tokens.getLength()); + + String doc = ""; + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int token = vocab.indexOf(word); + doc += word + " "; + //if(token != -1) { + // doc += word + " "; + //} + } + System.out.println(name); + System.out.println(doc); + + if (!doc.equals("")) { + out.println(doc); + } + } + + out.close(); + } + + public static void writeCorpusMatrix(InstanceList training, String outfilename, String vocabname) throws FileNotFoundException { + + // each document is represented in a vector (vocab size), and each entry is the frequency of a word. 
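+ // (writes one tab-separated count row per document, with columns in vocab-file order; note this assumes every token appears in the vocab file, since vocab.indexOf is used without a -1 check)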
+ + ArrayList vocab = loadVocab(vocabname); + + PrintStream out = new PrintStream (new File(outfilename)); + + int count = -1; + for (Instance instance : training) { + count++; + if (count % 1000 == 0) { + System.out.println("Processed " + count + " number of documents!"); + } + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + + int[] tokens = new int[vocab.size()]; + for (int jj = 0; jj < tokens.length; jj++) { + tokens[jj] = 0; + } + + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int index = vocab.indexOf(word); + tokens[index] += 1; + } + + String doc = ""; + for (int jj = 0; jj < tokens.length; jj++) { + doc += tokens[jj] + "\t"; + } + + System.out.println(name); + System.out.println(doc); + + if (!doc.equals("")) { + out.println(doc); + } + } + + out.close(); + } + + public static ArrayList loadVocab(String vocabFile) { + + ArrayList vocab = new ArrayList(); + + try { + FileInputStream infstream = new FileInputStream(vocabFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + vocab.add(str[1]); + } else { + System.out.println("Error! " + strLine); + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No vocab file Found!"); + } + return vocab; + } + + + public static void main(String[] args) { + //String input = "input/nyt/nyt-topic-input.mallet"; + //String corpus = "../../pylda/variational/data/20_news/doc.dat"; + //String vocab = "../../pylda/variational/data/20_news/voc.dat"; + + String input = "input/synthetic/synthetic-topic-input.mallet"; + //String corpus = "../../spectral/input/synthetic-ordered.dat"; + //String vocab = "../../spectral/input/synthetic-ordered.voc"; + String corpus = "../../spectral/input/synthetic.dat"; + String vocab = "../../spectral/input/synthetic.voc"; + + //String input = "../../itm-evaluation/results/govtrack-109/input/govtrack-109-topic-input.mallet"; + //String corpus = "../../pylda/variational/data/20_news/doc.dat"; + //String vocab = "../../itm-evaluation/results/govtrack-109/input/govtrack-109.voc"; + + try{ + InstanceList data = InstanceList.load (new File(input)); + writeCorpusMatrix(data, corpus, vocab); + } catch (Exception e) { + e.printStackTrace(); + } + } + +} diff --git a/src/cc/mallet/topics/tree/HIntIntDoubleHashMap.java b/src/cc/mallet/topics/tree/HIntIntDoubleHashMap.java new file mode 100755 index 000000000..4b4bf47bf --- /dev/null +++ b/src/cc/mallet/topics/tree/HIntIntDoubleHashMap.java @@ -0,0 +1,82 @@ +package cc.mallet.topics.tree; + +import java.io.Serializable; + +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; + +/** + * This class defines a two level hashmap, so a value will be indexed by two keys. + * The value is double, and two keys are both int. + * + * @author Yuening Hu + */ + +public class HIntIntDoubleHashMap implements Serializable{ + TIntObjectHashMap data; + + public HIntIntDoubleHashMap() { + this.data = new TIntObjectHashMap (); + } + + /** + * If keys do not exist, insert value. + * Else update with the new value. + */ + public void put(int key1, int key2, double value) { + if(! 
this.data.contains(key1)) { + this.data.put(key1, new TIntDoubleHashMap()); + } + TIntDoubleHashMap tmp = this.data.get(key1); + tmp.put(key2, value); + } + + /** + * Return the HashMap indexed by the first key. + */ + public TIntDoubleHashMap get(int key1) { + return this.data.get(key1); + } + + /** + * Return the value indexed by key1 and key2. + */ + public double get(int key1, int key2) { + if (this.data.contains(key1)) { + TIntDoubleHashMap tmp1 = this.data.get(key1); + if (tmp1.contains(key2)) { + return tmp1.get(key2); + } + } + System.out.println("HIntIntDoubleHashMap: key does not exist!"); + return -1; + } + + /** + * Return the first key set. + */ + public int[] getKey1Set() { + return this.data.keys(); + } + + /** + * Check whether key1 is contained in the first key set or not. + */ + public boolean contains(int key1) { + return this.data.contains(key1); + } + + /** + * Check whether the key pair (key1, key2) is contained or not. + */ + public boolean contains(int key1, int key2) { + if (this.data.contains(key1)) { + return this.data.get(key1).contains(key2); + } else { + return false; + } + } + +} + diff --git a/src/cc/mallet/topics/tree/HIntIntIntHashMap.java b/src/cc/mallet/topics/tree/HIntIntIntHashMap.java new file mode 100755 index 000000000..42ecded4c --- /dev/null +++ b/src/cc/mallet/topics/tree/HIntIntIntHashMap.java @@ -0,0 +1,121 @@ +package cc.mallet.topics.tree; + +import java.io.Serializable; + +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; + +/** + * This class defines a two level hashmap, so a value will be indexed by two keys. + * The value is int, and two keys are both int. + * + * @author Yuening Hu + */ + +public class HIntIntIntHashMap implements Serializable{ + + TIntObjectHashMap data; + + public HIntIntIntHashMap() { + this.data = new TIntObjectHashMap (); + } + + /** + * If keys do not exist, insert value. + * Else update with the new value. + */ + public void put(int key1, int key2, int value) { + if(! this.data.contains(key1)) { + this.data.put(key1, new TIntIntHashMap()); + } + TIntIntHashMap tmp = this.data.get(key1); + tmp.put(key2, value); + } + + /** + * Return the HashMap indexed by the first key. + */ + public TIntIntHashMap get(int key1) { + if(this.contains(key1)) { + return this.data.get(key1); + } + return null; + } + + /** + * Return the value indexed by key1 and key2. + */ + public int get(int key1, int key2) { + if (this.contains(key1, key2)) { + return this.data.get(key1).get(key2); + } else { + System.out.println("HIntIntIntHashMap: key does not exist!"); + return 0; + } + } + + /** + * Return the first key set. + */ + public int[] getKey1Set() { + return this.data.keys(); + } + + /** + * Check whether key1 is contained in the first key set or not. + */ + public boolean contains(int key1) { + return this.data.contains(key1); + } + + /** + * Check whether the key pair (key1, key2) is contained or not. + */ + public boolean contains(int key1, int key2) { + if (this.data.contains(key1)) { + return this.data.get(key1).contains(key2); + } else { + return false; + } + } + + /** + * Adjust the value indexed by the key pair (key1, key2) by the specified amount. + */ + public void adjustValue(int key1, int key2, int increment) { + int old = this.get(key1, key2); + this.put(key1, key2, old+increment); + } + + + /** + * If the key pair (key1, key2) exists, adjust the value by the specified amount, + * Or insert the new value. 
+ */ + public void adjustOrPutValue(int key1, int key2, int increment, int newvalue) { + if (this.contains(key1, key2)) { + int old = this.get(key1, key2); + this.put(key1, key2, old+increment); + } else { + this.put(key1, key2, newvalue); + } + } + + /** + * Remove the first key + */ + public void removeKey1(int key1) { + this.data.remove(key1); + } + + /** + * Remove the second key + */ + public void removeKey2(int key1, int key2) { + if (this.data.contains(key1)) { + this.data.get(key1).remove(key2); + } + } + +} diff --git a/src/cc/mallet/topics/tree/HIntIntObjectHashMap.java b/src/cc/mallet/topics/tree/HIntIntObjectHashMap.java new file mode 100755 index 000000000..57c3ce80e --- /dev/null +++ b/src/cc/mallet/topics/tree/HIntIntObjectHashMap.java @@ -0,0 +1,79 @@ +package cc.mallet.topics.tree; + +import java.io.Serializable; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; + +/** + * This class defines a two level hashmap, so a value will be indexed by two keys. + * The value is int, and two keys are both int. + * + * @author Yuening Hu + */ + +public class HIntIntObjectHashMap implements Serializable{ + TIntObjectHashMap> data; + + public HIntIntObjectHashMap () { + this.data = new TIntObjectHashMap>(); + } + + /** + * If keys do not exist, insert value. + * Else update with the new value. + */ + public void put(int key1, int key2, V value) { + if(! this.data.contains(key1)) { + this.data.put(key1, new TIntObjectHashMap()); + } + TIntObjectHashMap tmp = this.data.get(key1); + tmp.put(key2, value); + } + + /** + * Return the HashMap indexed by the first key. + */ + public TIntObjectHashMap get(int key1) { + return this.data.get(key1); + } + + /** + * Return the value indexed by key1 and key2. + */ + public V get(int key1, int key2) { + if (this.contains(key1, key2)) { + return this.data.get(key1).get(key2); + } else { + System.out.println("HIntIntObjectHashMap: key does not exist! " + key1 + " " + key2); + return null; + } + } + + /** + * Return the first key set. + */ + public int[] getKey1Set() { + return this.data.keys(); + } + + /** + * Check whether key1 is contained in the first key set or not. + */ + public boolean contains(int key1) { + return this.data.contains(key1); + } + + /** + * Check whether the key pair (key1, key2) is contained or not. + */ + public boolean contains(int key1, int key2) { + if (this.data.contains(key1)) { + return this.data.get(key1).contains(key2); + } else { + return false; + } + } + +} diff --git a/src/cc/mallet/topics/tree/Node.java b/src/cc/mallet/topics/tree/Node.java new file mode 100755 index 000000000..9249f1fd8 --- /dev/null +++ b/src/cc/mallet/topics/tree/Node.java @@ -0,0 +1,194 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; + +/** + * This class defines a node, which might have children, + * and a distribution scaled by the node prior over the children. + * A node is a synset, which might have children nodes and words + * at the same time. 
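+ * Children are referenced by their offsets, and the node carries a transition prior over its children (or over its own words, for a leaf node) that is normalized and scaled by transitionScalor.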
+ * + * @author Yuening Hu + */ + +public class Node { + int offset; + double rawCount; + double hypoCount; + String hyperparamName; + + TIntArrayList words; + TDoubleArrayList wordsCount; + TIntArrayList childOffsets; + + int numChildren; + int numPaths; + int numWords; + + double transitionScalor; + TDoubleArrayList transitionPrior; + + public Node() { + this.words = new TIntArrayList (); + this.wordsCount = new TDoubleArrayList (); + this.childOffsets = new TIntArrayList (); + this.transitionPrior = new TDoubleArrayList (); + this.numChildren = 0; + this.numWords = 0; + this.numPaths = 0; + } + + /** + * Initialize the prior distribution. + */ + public void initializePrior(int size) { + for (int ii = 0; ii < size; ii++ ) { + this.transitionPrior.add(0.0); + } + } + + /** + * Initialize the prior distribution. + */ + public void setOffset(int val) { + this.offset = val; + } + + /** + * set the raw count. + */ + public void setRawCount(double count) { + this.rawCount = count; + } + + /** + * set the hypo count. + */ + public void setHypoCount(double count) { + this.hypoCount = count; + } + + /** + * set the hyperparameter name of this node. + */ + public void setHyperparamName(String name) { + this.hyperparamName = name; + } + + /** + * set the prior scaler. + */ + public void setTransitionScalor(double val) { + this.transitionScalor = val; + } + + /** + * set the prior for the given child index. + */ + public void setPrior(int index, double value) { + this.transitionPrior.set(index, value); + } + + /** + * Add a child, which is defined by the offset. + */ + public void addChildrenOffset(int childOffset) { + this.childOffsets.add(childOffset); + this.numChildren += 1; + } + + /** + * Add a word. + */ + public void addWord(int wordIndex, double wordCount) { + this.words.add(wordIndex); + this.wordsCount.add(wordCount); + this.numWords += 1; + } + + /** + * Increase the number of paths. + */ + public void addPaths(int inc) { + this.numPaths += inc; + } + + /** + * return the offset of current node. + */ + public int getOffset() { + return this.offset; + } + + /** + * return the number of children. + */ + public int getNumChildren() { + return this.numChildren; + } + + /** + * return the number of words. + */ + public int getNumWords() { + return this.numWords; + } + + /** + * return the child offset given the child index. + */ + public int getChild(int child_index) { + return this.childOffsets.get(child_index); + } + + /** + * return the word given the word index. + */ + public int getWord(int word_index) { + return this.words.get(word_index); + } + + /** + * return the word count given the word index. + */ + public double getWordCount(int word_index) { + return this.wordsCount.get(word_index); + } + + /** + * return the hypocount of the node. + */ + public double getHypoCount() { + return this.hypoCount; + } + + /** + * return the transition scalor. + */ + public double getTransitionScalor() { + return this.transitionScalor; + } + + /** + * return the scaled transition prior distribution. + */ + public TDoubleArrayList getTransitionPrior() { + return this.transitionPrior; + } + + /** + * normalize the prior to be a distribution and then scale it. 
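+ * After this call the entries sum to transitionScalor rather than to 1.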
+ */ + public void normalizePrior() { + double norm = 0; + for (int ii = 0; ii < this.transitionPrior.size(); ii++) { + norm += this.transitionPrior.get(ii); + } + for (int ii = 0; ii < this.transitionPrior.size(); ii++) { + double tmp = this.transitionPrior.get(ii) / norm; + tmp *= this.transitionScalor; + this.transitionPrior.set(ii, tmp); + } + } +} diff --git a/src/cc/mallet/topics/tree/NonZeroPath.java b/src/cc/mallet/topics/tree/NonZeroPath.java new file mode 100755 index 000000000..b58dc5b61 --- /dev/null +++ b/src/cc/mallet/topics/tree/NonZeroPath.java @@ -0,0 +1,31 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; + + +/** + * This class defines a structure for recording nonzeropath. + * key1: type; key2: path id; value: count. + * + * @author Yuening Hu + */ + +public class NonZeroPath { + + HIntIntIntHashMap data; + + public NonZeroPath () { + this.data = new HIntIntIntHashMap(); + } + + public void put(int key1, int key2, int value) { + this.data.put(key1, key2, value); + } + + public void get(int key1, int key2) { + this.data.get(key1, key2); + } + +} diff --git a/src/cc/mallet/topics/tree/OntologyWriter.java b/src/cc/mallet/topics/tree/OntologyWriter.java new file mode 100755 index 000000000..360d355c6 --- /dev/null +++ b/src/cc/mallet/topics/tree/OntologyWriter.java @@ -0,0 +1,905 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; + +import java.io.FileOutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.Stack; +import java.util.TreeMap; +import java.util.TreeSet; + +import topicmod_projects_ldawn.WordnetFile.WordNetFile; +import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset; +import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset.Word; + + +/** + * Converts a set of user-selected constraints into Protocol Buffer form. + * This is an adaptation of Yuening's Python code that does the same thing. + * Following the style of Brianna's original code of OntologyWriter.java. 
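+ * The result is serialized as protocol-buffer WordNetFile parts: one file holding the leaf synsets and one holding the internal synsets, with the root offset stored in the internal file.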
+ * + * @author Yuening Hu + */ + +public class OntologyWriter { + private Map> parents; + + private int numFiles; + private String filename; + + private int root; + private Map> vocab; + private boolean propagateCounts; + + private int maxLeaves; + private Map leafSynsets; + private Map internalSynsets; + private WordNetFile.Builder leafWn; + private WordNetFile.Builder internalWn; + private boolean finalized; + + static class WordTuple { + public int id; + public int language; + public String word; + public double count; + } + + static class VocabEntry { + public int index; + public int language; + public int flag; + } + + static class Constraint { + public ArrayList cl; + public ArrayList ml; + } + + static class Node { + public int index; + public boolean rootChild; + public String linkType; + public ArrayList children; + public int[] words; + } + + final static int ENGLISH_ID = 0; + + private OntologyWriter(String filename, boolean propagateCounts) { + this.filename = filename; + this.propagateCounts = propagateCounts; + + vocab = new TreeMap>(); + parents = new TreeMap>(); + + root = -1; + + maxLeaves = 10000; + leafSynsets = new TreeMap(); + internalSynsets = new TreeMap(); + leafWn = WordNetFile.newBuilder(); + leafWn.setRoot(-1); + internalWn = WordNetFile.newBuilder(); + internalWn.setRoot(-1); + finalized = false; + } + + private void addParent(int childId, int parentId) { + if (!parents.containsKey(childId)) { + parents.put(childId, new TreeSet()); + } + parents.get(childId).add(parentId); + } + + private List getParents(int id) { + List parentList = new ArrayList(); + if (!parents.containsKey(id) || parents.get(id).size() == 0) { + if (this.root < 0) + this.root = id; + return new ArrayList(); + } else { + parentList.addAll(parents.get(id)); + for (int parentId : parents.get(id)) { + parentList.addAll(getParents(parentId)); + } + } + return parentList; + } + + private int getTermId(int language, String term) { + if (!vocab.containsKey(language)) { + vocab.put(language, new TreeMap()); + } + if (!vocab.get(language).containsKey(term)) { + int length = vocab.get(language).size(); + vocab.get(language).put(term, length); + } + return vocab.get(language).get(term); + } + + private void findRoot(Map synsets) { + for (int synsetId : synsets.keySet()) { + if (synsetId % 1000 == 0) { + System.out.println("Finalizing " + synsetId); + } + for (int parentId : getParents(synsetId)) { + if (propagateCounts) { + double hypCount = this.internalSynsets.get(parentId).getHyponymCount(); + double rawCount = synsets.get(synsetId).getRawCount(); + this.internalSynsets.get(parentId).setHyponymCount(hypCount + rawCount); + } + } + } + } + + // Named this so it doesn't conflict with Object.finalize + private void finalizeMe() throws Exception { + findRoot(this.leafSynsets); + for(int id : this.leafSynsets.keySet()) { + this.leafWn.addSynsets(this.leafSynsets.get(id)); + } + write(this.leafWn); + + findRoot(this.internalSynsets); + if(this.root < 0) { + System.out.println("No root has been found!"); + throw new Exception(); + } + this.internalWn.setRoot(this.root); + for(int id : this.internalSynsets.keySet()) { + this.internalWn.addSynsets(this.internalSynsets.get(id)); + } + write(this.internalWn); + } + + private void write(WordNetFile.Builder wnFile) { + try { + String newFilename = filename + "." 
+ numFiles; + WordNetFile builtFile = wnFile.build(); + builtFile.writeTo(new FileOutputStream(newFilename)); + System.out.println("Serialized version written to: " + newFilename); + this.numFiles ++; + } catch (Exception e) { + e.printStackTrace(); + } + } + + private void addSynset(int numericId, String senseKey, List children, + List words) { + Synset.Builder synset = Synset.newBuilder(); + + double rawCount = 0.0; + synset.setOffset(numericId); + synset.setKey(senseKey); + + if(senseKey.startsWith("ML_")) { + synset.setHyperparameter("ML_"); + } else if(senseKey.startsWith("CL_")) { + synset.setHyperparameter("CL_"); + } else if(senseKey.startsWith("NL_")) { + synset.setHyperparameter("NL_"); + } else if(senseKey.startsWith("ROOT")) { + synset.setHyperparameter("NL_"); + } else if(senseKey.startsWith("LEAF_")) { + synset.setHyperparameter("NL_"); + } else { + synset.setHyperparameter("DEFAULT_"); + } + + if(children != null) { + for (int child : children){ + addParent(child, numericId); + synset.addChildrenOffsets(child); + } + } + + if(words != null) { + for (WordTuple tuple : words) { + Word.Builder word = Word.newBuilder(); + word.setLangId(tuple.language); + //word.setTermId(getTermId(tuple.language, tuple.word)); + word.setTermId(tuple.id); + word.setTermStr(tuple.word); + word.setCount(tuple.count); + rawCount += tuple.count; + synset.addWords(word); + synset.setRawCount(rawCount); + } + } + + synset.setHyponymCount(rawCount + 0.0); + + if(children != null && children.size() > 0) { + //this.internalWn.addSynsets(synset.clone()); + this.internalSynsets.put(numericId, synset); + } else { + //this.leafWn.addSynsets(synset.clone()); + this.leafSynsets.put(numericId, synset); + } + } + + private static ArrayList getVocab(String filename) { + ArrayList vocab = new ArrayList(); + int index = 0; + try { + List lines = Utils.readAll(filename); + for (String line : lines) + { + String[] words = line.trim().split("\t"); + if (words.length > 1) { + vocab.add(words[1]); + } else { + System.out.println("Error! 
" + index); + } + index++; + } + } catch (Exception e) { + e.printStackTrace(); + } + return vocab; + } + + private static void readConstraints(String consfilename, ArrayList vocab, + ArrayList ml, ArrayList cl) { + //List constraints = new ArrayList(); + + try { + List lines = Utils.readAll(consfilename); + for (String line : lines) { + String[] words = line.trim().split("\t"); + int[] indexWords = new int[words.length - 1]; + for(int ii = 1; ii < words.length; ii++) { + int index = vocab.indexOf(words[ii]); + if (index == -1) { + System.out.println("Found words that not contained in the vocab: " + words[ii]); + throw new Exception(); + } + indexWords[ii-1] = index; + } + + //for(int ii = 0; ii < indexWords.length; ii++) { + // System.out.print(indexWords[ii] + " "); + //} + + if (words[0].equals("SPLIT_")) { + cl.add(indexWords); + } else if (words[0].equals("MERGE_")) { + ml.add(indexWords); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + + } + + private static void generateGraph(ArrayList cons, int value, HIntIntIntHashMap graph) { + for (int[] con : cons) { + for (int w1 : con) { + for (int w2 : con) { + if ( w1 != w2) { + graph.put(w1, w2, value); + graph.put(w2, w1, value); + } + } + } + } + } + + private static ArrayList BFS(HIntIntIntHashMap graph, int[] consWords, int choice) { + ArrayList connected = new ArrayList (); + TIntIntHashMap visited = new TIntIntHashMap (); + for (int word : consWords) { + visited.put(word, 0); + } + + for (int word : consWords) { + if (visited.get(word) == 0) { + Stack queue = new Stack (); + queue.push(word); + TIntHashSet component = new TIntHashSet (); + while (queue.size() > 0) { + int node = queue.pop(); + component.add(node); + for(int neighbor : graph.get(node).keys()) { + if (choice == -1) { + if (graph.get(node, neighbor) > 0 && visited.get(neighbor) == 0) { + visited.adjustValue(neighbor, 1); + queue.push(neighbor); + } + } else { + if (graph.get(node, neighbor) == choice && visited.get(neighbor) == 0) { + visited.adjustValue(neighbor, 1); + queue.push(neighbor); + } + } + } + } + connected.add(component.toArray()); + } + } + + return connected; + + } + + private static ArrayList mergeML(ArrayList ml) { + HIntIntIntHashMap graph = new HIntIntIntHashMap (); + generateGraph(ml, 1, graph); + int[] consWords = graph.getKey1Set(); + + ArrayList ml_merged = BFS(graph, consWords, -1); + return ml_merged; + } + + private static void mergeCL(ArrayList cl_ml_merged, + ArrayList ml_remained, HIntIntIntHashMap graph) throws Exception { + int[] consWords = graph.getKey1Set(); + // get connected components + ArrayList connectedComp = BFS(graph, consWords, -1); + + // merge ml to cl + for(int[] comp : connectedComp) { + ArrayList cl_tmp = BFS(graph, comp, 1); + ArrayList ml_tmp = BFS(graph, comp, 2); + + ArrayList cl_new = new ArrayList (); + for(int[] cons : cl_tmp) { + if (cons.length > 1) { + cl_new.add(cons); + } + } + + ArrayList ml_new = new ArrayList (); + for(int[] cons : ml_tmp) { + if (cons.length > 1) { + ml_new.add(cons); + } + } + + if(cl_new.size() > 0) { + Constraint cons = new Constraint(); + cons.cl = cl_new; + cons.ml = ml_new; + cl_ml_merged.add(cons); + } else { + if (ml_new.size() != 1) { + System.out.println("ml_new.size != 1 && cl_new.size == 0"); + throw new Exception(); + } + Constraint cons = new Constraint(); + cons.cl = null; + cons.ml = ml_new; + ml_remained.add(cons); + } + } + } + + private static HIntIntIntHashMap flipGraph(ArrayList cl_merged, ArrayList ml_merged, + HIntIntIntHashMap graph) { + + 
HIntIntIntHashMap flipped = new HIntIntIntHashMap (); + TIntHashSet consWordsSet = getConsWords(cl_merged); + TIntHashSet set2 = getConsWords(ml_merged); + consWordsSet.addAll(set2.toArray()); + int[] consWords = consWordsSet.toArray(); + + for(int word : consWords) { + TIntHashSet cl_neighbor = new TIntHashSet (); + for(int neighbor : graph.get(word).keys()) { + if(graph.get(word, neighbor) == 1) { + cl_neighbor.add(neighbor); + } + } + consWordsSet.removeAll(cl_neighbor.toArray()); + for(int nonConnected : consWordsSet.toArray()) { + flipped.put(word, nonConnected, 1); + } + for(int neighbor : graph.get(word).keys()) { + flipped.put(word, neighbor, 0); + } + consWordsSet.addAll(cl_neighbor.toArray()); + } + + //printGraph(flipped, "flipped half"); + for(int[] ml : ml_merged) { + for(int w1 : ml) { + for(int w2 : flipped.get(w1).keys()) { + if(flipped.get(w1, w2) > 0) { + for(int w3 : ml) { + if(w1 != w3 && flipped.get(w3, w2) == 0) { + flipped.put(w1, w2, 0); + flipped.put(w2, w1, 0); + } + } + } + } + + for(int w2 : ml) { + if (w1 != w2) { + flipped.put(w1, w2, 2); + flipped.put(w2, w1, 2); + } + } + } + } + return flipped; + } + + private static TIntHashSet getUnion(TIntHashSet set1, TIntHashSet set2) { + TIntHashSet union = new TIntHashSet (); + union.addAll(set1.toArray()); + union.addAll(set2.toArray()); + return union; + } + + private static TIntHashSet getDifference(TIntHashSet set1, TIntHashSet set2) { + TIntHashSet diff = new TIntHashSet (); + diff.addAll(set1.toArray()); + diff.removeAll(set2.toArray()); + return diff; + } + + private static TIntHashSet getIntersection(TIntHashSet set1, TIntHashSet set2) { + TIntHashSet inter = new TIntHashSet (); + for(int ww : set1.toArray()) { + if(set2.contains(ww)) { + inter.add(ww); + } + } + return inter; + } + + private static void BronKerBosch_v2(TIntHashSet R, TIntHashSet P, TIntHashSet X, + HIntIntIntHashMap G, ArrayList C) { + if(P.size() == 0 && X.size() == 0) { + if(R.size() > 0) { + C.add(R.toArray()); + } + return; + } + + int d = 0; + int pivot = -1; + + TIntHashSet unionPX = getUnion(P, X); + for(int v : unionPX.toArray()) { + TIntHashSet neighbors = new TIntHashSet (); + for(int node : G.get(v).keys()) { + if(G.get(v, node) > 0 && v != node) { + neighbors.add(node); + } + } + if(neighbors.size() > d) { + d = neighbors.size(); + pivot = v; + } + } + + TIntHashSet neighbors = new TIntHashSet (); + if(pivot != -1) { + for(int node : G.get(pivot).keys()) { + if (G.get(pivot, node) > 0 && pivot != node) { + neighbors.add(node); + } + } + } + + TIntHashSet diffPN = getDifference(P, neighbors); + for(int v : diffPN.toArray()) { + TIntHashSet newNeighbors = new TIntHashSet(); + for(int node : G.get(v).keys()) { + if(G.get(v, node) > 0 && v != node) { + newNeighbors.add(node); + } + } + + TIntHashSet unionRV = new TIntHashSet(); + unionRV.add(v); + unionRV.addAll(R.toArray()); + BronKerBosch_v2(unionRV, getIntersection(P, newNeighbors), getIntersection(X, newNeighbors), G, C); + + P.remove(v); + X.add(v); + } + } + + private static ArrayList generateCliques(HIntIntIntHashMap graph) { + + TIntHashSet R = new TIntHashSet (); + TIntHashSet P = new TIntHashSet (); + TIntHashSet X = new TIntHashSet (); + ArrayList cliques = new ArrayList (); + P.addAll(graph.getKey1Set()); + + BronKerBosch_v2(R, P, X, graph, cliques); + + return cliques; + } + + private static int generateCLTree(ArrayList cl_merged, + HIntIntIntHashMap graph, TIntObjectHashMap subtree) { + // the index of root is 0 + int index = 0; + for(Constraint con : cl_merged) { + 
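+ // each merged cannot-link component becomes one CL_ node attached to the root; its children are the cliques of words that are still allowed to co-occur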
ArrayList cl = con.cl; + ArrayList ml = con.ml; + HIntIntIntHashMap flipped = flipGraph(cl, ml, graph); + //printGraph(flipped, "flipped graph"); + ArrayList cliques = generateCliques(flipped); + //printArrayList(cliques, "cliques found from flipped graph"); + + Node cl_node = new Node(); + cl_node.index = ++index; + cl_node.rootChild = true; + cl_node.linkType = "CL_"; + cl_node.children = new ArrayList(); + for(int[] clique : cliques) { + TIntHashSet clique_remained = new TIntHashSet(clique); + //printHashSet(clique_remained, "clique_remained"); + ArrayList ml_tmp = BFS(graph, clique, 2); + ArrayList ml_new = new ArrayList (); + for(int[] ml_con : ml_tmp) { + if (ml_con.length > 1) { + ml_new.add(ml_con); + for(int word : ml_con) { + clique_remained.remove(word); + } + } + } + //printHashSet(clique_remained, "clique_remained"); + + Node node = new Node(); + node.index = ++index; + node.rootChild = false; + cl_node.children.add(node.index); + if(ml_new.size() == 0) { + node.linkType = "NL_"; + node.children = null; + node.words = clique_remained.toArray(); + } else if(clique_remained.size() == 0 && ml_new.size() == 1) { + node.linkType = "ML_"; + node.children = null; + node.words = ml_new.get(0); + } else { + node.linkType = "NL_IN_"; + node.children = new ArrayList(); + node.words = null; + for(int[] ml_clique : ml_new) { + Node child_node = new Node(); + child_node.index = ++index; + node.rootChild = false; + child_node.linkType = "ML_"; + child_node.children = null; + child_node.words = ml_clique; + node.children.add(index); + subtree.put(child_node.index, child_node); + } + if(clique_remained.size() > 0) { + Node child_node = new Node(); + child_node.index = ++index; + node.rootChild = false; + child_node.linkType = "NL_"; + child_node.children = null; + child_node.words = clique_remained.toArray(); + node.children.add(index); + subtree.put(child_node.index, child_node); + } + } + subtree.put(node.index, node); + } + subtree.put(cl_node.index, cl_node); + } + + return index; + } + + private static int generateMLTree(ArrayList ml_remained, + TIntObjectHashMap subtree, int index) { + //printConstraintList(ml_remained, "remained"); + int ml_index = index; + for(Constraint con : ml_remained) { + for(int[] ml : con.ml) { + Node node = new Node(); + node.index = ++ml_index; + node.rootChild = true; + node.linkType = "ML_"; + node.children = null; + node.words = ml; + subtree.put(node.index, node); + } + } + return ml_index; + } + + private static TIntHashSet getConsWords(ArrayList cons) { + TIntHashSet consWords = new TIntHashSet(); + for(int[] con : cons) { + consWords.addAll(con); + } + return consWords; + } + + private static TIntObjectHashMap mergeAllConstraints(ArrayList ml, ArrayList cl) { + //printArrayList(ml, "read in ml"); + //printArrayList(cl, "read in cl"); + + // merge ml constraints + ArrayList ml_merged = mergeML(ml); + + // generate graph + HIntIntIntHashMap graph = new HIntIntIntHashMap (); + generateGraph(cl, 1, graph); + generateGraph(ml_merged, 2, graph); + //printGraph(graph, "original graph"); + + // merge cl: notice some ml can be merged into cl, the remained ml will be kept + ArrayList cl_ml_merged = new ArrayList (); + ArrayList ml_remained = new ArrayList (); + try { + mergeCL(cl_ml_merged, ml_remained, graph); + } catch (Exception e) { + e.printStackTrace(); + } + + //printConstraintList(cl_ml_merged, "cl ml merged"); + //printConstraintList(ml_remained, "ml_remained"); + + TIntObjectHashMap subtree = new TIntObjectHashMap(); + int index = 
generateCLTree(cl_ml_merged, graph, subtree); + int new_index = generateMLTree(ml_remained, subtree, index); + + return subtree; + } + + private static TIntObjectHashMap noMergeConstraints(ArrayList ml, ArrayList cl) { + TIntObjectHashMap subtree = new TIntObjectHashMap(); + int index = 0; + for(int[] cons : ml) { + Node node = new Node(); + node.index = ++index; + node.rootChild = true; + node.linkType = "ML_"; + node.children = null; + node.words = cons; + subtree.put(node.index, node); + } + + for(int[] cons : cl) { + Node node = new Node(); + node.index = ++index; + node.rootChild = true; + node.linkType = "CL_"; + node.children = null; + node.words = cons; + subtree.put(node.index, node); + } + + return subtree; + } + + private static void printHashSet(TIntHashSet set, String title){ + String tmp = title + ": "; + for(int word : set.toArray()) { + tmp += word + " "; + } + } + + private static void printConstraintList(ArrayList constraints, String title) { + System.out.println(title + ": "); + for(Constraint cons : constraints) { + String tmp = ""; + if(cons.ml != null) { + tmp = "ml: "; + for(int[] ml : cons.ml) { + tmp += "( "; + for(int ww : ml) { + tmp += ww + " "; + } + tmp += ") "; + } + } + if(cons.cl != null) { + tmp += "cl: "; + for(int[] cl : cons.cl) { + tmp += "( "; + for(int ww : cl) { + tmp += ww + " "; + } + tmp += ") "; + } + } + System.out.println(tmp); + } + } + + private static void printGraph(HIntIntIntHashMap graph, String title) { + System.out.println(title + ": "); + for(int w1 : graph.getKey1Set()) { + String tmp = ""; + for(int w2 : graph.get(w1).keys()) { + tmp += "( " + w1 + " " + w2 + " : " + graph.get(w1, w2) + " ) "; + } + System.out.println(tmp); + } + } + + private static void printArrayList(ArrayList result, String title) { + System.out.println(title + ": "); + for(int[] sent : result) { + String line = ""; + for (int word : sent) { + line += word + " "; + } + System.out.println(line); + } + } + + private static void printSubTree(TIntObjectHashMap subtree) { + for(int index : subtree.keys()) { + Node node = subtree.get(index); + String tmp = "Node " + index + " : "; + tmp += node.linkType + " "; + if(node.children != null) { + tmp += "chilren ["; + for(int child : node.children) { + tmp += child + " "; + } + tmp += "]"; + } + if (node.words != null) { + tmp += " words [ "; + for(int word : node.words) { + tmp += word + " "; + } + tmp += "]"; + } + System.out.println(tmp); + } + } + + /** + * This is the top-level method that creates the ontology from a set of + * Constraint objects. 
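+ * @param consFilename the constraints file (tab-separated lines starting with SPLIT_ or MERGE_), or null if there are no constraints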
+ * @param vocabFilename the .voc file corresponding to the data set + * being used + * @throws Exception + */ + public static void createOntology(String consFilename, String vocabFilename, + String outputDir, boolean mergeConstraints) throws Exception { + + // load vocab + int LANG_ID = 0; + ArrayList vocab = getVocab(vocabFilename); + System.out.println("Load vocab size: " + vocab.size()); + // load constraints and make sure all constraints words are contained in vocab + ArrayList ml = new ArrayList (); + ArrayList cl = new ArrayList (); + if(consFilename != null) { + readConstraints(consFilename, vocab, ml, cl); + } + + // merge constraints + TIntObjectHashMap subtree; + if (mergeConstraints) { + subtree = mergeAllConstraints(ml, cl); + } else { + subtree = noMergeConstraints(ml, cl); + } + printSubTree(subtree); + + // get constraint count (If count == 0, it is unconstraint words) + int[] vocabFlag = new int[vocab.size()]; + for(int ii = 0; ii < vocabFlag.length; ii++) { + vocabFlag[ii] = 0; + } + for(int index : subtree.keys()) { + Node node = subtree.get(index); + if(node.words != null) { + for(int wordIndex : node.words) { + vocabFlag[wordIndex]++; + } + } + } + + ///////////////// + + OntologyWriter writer = new OntologyWriter(outputDir, true); + List rootChildren = new ArrayList(); + + int leafIndex = subtree.size(); + for(int index : subtree.keys()) { + Node node = subtree.get(index); + List nodeChildren = null; + ArrayList nodeWords = null; + if(node.rootChild) { + rootChildren.add(node.index); + } + if(node.children != null && node.words != null) { + System.out.println("A node has both children and words! Wrong!"); + throw new Exception(); + } else if(node.children != null) { + nodeChildren = node.children; + } else if(node.words != null) { + if(node.words.length == 1) { + nodeWords = new ArrayList (); + WordTuple wt = new WordTuple(); + wt.id = node.words[0]; + wt.language = LANG_ID; + wt.word = vocab.get(wt.id); + wt.count = 1.0 / vocabFlag[wt.id]; + nodeWords.add(wt); + } else { + nodeChildren = new ArrayList(); + for(int wordIndex : node.words) { + leafIndex++; + nodeChildren.add(leafIndex); + List leafChildren = null; + ArrayList leafWords = new ArrayList (); + WordTuple wt = new WordTuple(); + wt.id = wordIndex; + wt.language = LANG_ID; + wt.word = vocab.get(wordIndex); + wt.count = 1.0 / vocabFlag[wordIndex]; + leafWords.add(wt); + String name = "LEAF_" + leafIndex + "_" + wt.word; + writer.addSynset(leafIndex, name, leafChildren, leafWords); + } + } + } + + if(node.words != null && node.words.length == 1) { + node.linkType = "LEAF_"; + String name = node.linkType + node.index + "_" + vocab.get(node.words[0]); + writer.addSynset(node.index, name, nodeChildren, nodeWords); + } else { + writer.addSynset(node.index, node.linkType + node.index, nodeChildren, nodeWords); + } + + } + + // Unused words + for(int wordIndex = 0; wordIndex < vocabFlag.length; wordIndex++) { + if (vocabFlag[wordIndex] == 0) { + rootChildren.add(++leafIndex); + List leafChildren = null; + ArrayList leafWords = new ArrayList (); + WordTuple wt = new WordTuple(); + wt.id = wordIndex; + wt.language = LANG_ID; + wt.word = vocab.get(wordIndex); + wt.count = 1.0; + leafWords.add(wt); + String name = "LEAF_" + leafIndex + "_" + wt.word; + writer.addSynset(leafIndex, name, leafChildren, leafWords); + } + } + + writer.addSynset(0, "ROOT", rootChildren, null); + writer.finalizeMe(); + } + + + public static void main(String[] args) { + String vocabFn = "input/toy/toy.voc"; + String consFile = 
"input/toy/toy.cons"; + String outputFn = "input/toy/toy_test.wn"; + boolean mergeConstraints = true; + try { + createOntology(consFile, vocabFn, outputFn, mergeConstraints); + } catch (Exception e){ + e.printStackTrace(); + } + } +} diff --git a/src/cc/mallet/topics/tree/Path.java b/src/cc/mallet/topics/tree/Path.java new file mode 100755 index 000000000..abcbf7e7a --- /dev/null +++ b/src/cc/mallet/topics/tree/Path.java @@ -0,0 +1,53 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; + +/** + * This class defines a path. + * A path is a list of nodes, and the last node emits a word. + * + * @author Yuening Hu + */ + +public class Path { + + TIntArrayList nodes; + //TIntArrayList children; + int finalWord; + + public Path () { + this.nodes = new TIntArrayList(); + this.finalWord = -1; + } + + /** + * Add nodes to this path. + */ + public void addNodes (TIntArrayList innodes) { + for (int ii = 0; ii < innodes.size(); ii++) { + int node_index = innodes.get(ii); + this.nodes.add(node_index); + } + } + + /** + * Add the final word of this path. + */ + public void addFinalWord(int word) { + this.finalWord = word; + } + + /** + * return the node list. + */ + public TIntArrayList getNodes() { + return this.nodes; + } + + /** + * return the final word. + */ + public int getFinalWord() { + return this.finalWord; + } +} diff --git a/src/cc/mallet/topics/tree/PriorTree.java b/src/cc/mallet/topics/tree/PriorTree.java new file mode 100755 index 000000000..93a190974 --- /dev/null +++ b/src/cc/mallet/topics/tree/PriorTree.java @@ -0,0 +1,363 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntObjectHashMap; +import gnu.trove.TIntObjectIterator; +import gnu.trove.TObjectDoubleHashMap; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FilenameFilter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; + +import cc.mallet.types.Alphabet; +import cc.mallet.types.InstanceList; + +import topicmod_projects_ldawn.WordnetFile.WordNetFile; +import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset; +import topicmod_projects_ldawn.WordnetFile.WordNetFile.Synset.Word; + +/** + * This class loads the prior tree structure from the proto buffer files of tree structure. 
+ * Main entrance: initialize() + * + * @author Yuening Hu + */ + +public class PriorTree { + + int root; + int maxDepth; + + TObjectDoubleHashMap hyperparams; + TIntObjectHashMap nodes; + TIntObjectHashMap> wordPaths; + + public PriorTree () { + this.hyperparams = new TObjectDoubleHashMap (); + this.nodes = new TIntObjectHashMap (); + this.wordPaths = new TIntObjectHashMap> (); + } + + /** + * Get the input tree file lists from the given tree file names + */ + private ArrayList getFileList(String tree_files) { + + int split_index = tree_files.lastIndexOf('/'); + String dirname = tree_files.substring(0, split_index); + String fileprefix = tree_files.substring(split_index+1); + fileprefix = fileprefix.replace("*", ""); + + //System.out.println(dirname); + //System.out.println(fileprefix); + + File dir = new File(dirname); + String[] children = dir.list(); + ArrayList filelist = new ArrayList(); + + for (int i = 0; i < children.length; i++) { + if (children[i].startsWith(fileprefix)) { + System.out.println("Found one: " + dirname + "/" + children[i]); + String filename = dirname + "/" + children[i]; + filelist.add(filename); + } + } + return filelist; + } + + /** + * Load hyper parameters from the given file + */ + private void loadHyperparams(String hyperFile) { + try { + FileInputStream infstream = new FileInputStream(hyperFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split(" "); + if (str.length != 2) { + System.out.println("Hyperparameter file is not in the correct format!"); + System.exit(0); + } + double tmp = Double.parseDouble(str[1]); + hyperparams.put(str[0], tmp); + } + in.close(); + +// Iterator> it = hyperparams.entrySet().iterator(); +// while (it.hasNext()) { +// Map.Entry entry = it.next(); +// System.out.println(entry.getKey()); +// System.out.println(entry.getValue()); +// } + + } catch (IOException e) { + System.out.println("No hyperparameter file Found!"); + } + } + + /** + * Load tree nodes one by one: load the children, words of each node + */ + private void loadTree(String tree_files, ArrayList vocab) { + + ArrayList filelist = getFileList(tree_files); + + for (int ii = 0; ii < filelist.size(); ii++) { + String filename = filelist.get(ii); + WordNetFile tree = null; + try { + tree = WordNetFile.parseFrom(new FileInputStream(filename)); + } catch (IOException e) { + System.out.println("Cannot find tree file: " + filename); + } + + int new_root = tree.getRoot(); + assert( (new_root == -1) || (this.root == -1) || (new_root == this.root)); + if (new_root >= 0) { + this.root = new_root; + } + + for (int jj = 0; jj < tree.getSynsetsCount(); jj++) { + Synset synset = tree.getSynsets(jj); + Node n = new Node(); + n.setOffset(synset.getOffset()); + n.setRawCount(synset.getRawCount()); + n.setHypoCount(synset.getHyponymCount()); + + double transition = hyperparams.get(synset.getHyperparameter()); + n.setTransitionScalor(transition); + for (int cc = 0; cc < synset.getChildrenOffsetsCount(); cc++) { + n.addChildrenOffset(synset.getChildrenOffsets(cc)); + } + + for (int ww = 0; ww < synset.getWordsCount(); ww++) { + Word word = synset.getWords(ww); + int term_id = vocab.indexOf(word.getTermStr()); + //int term_id = vocab.lookupIndex(word.getTermStr()); + double word_count = word.getCount(); + n.addWord(term_id, word_count); + } + + 
nodes.put(n.getOffset(), n); + } + } + + assert(this.root >= 0) : "Cannot find a root node in the tree file. Have you provided " + + "all tree files instead of a single tree file? (e.g., use 'tree.wn' instead of 'tree.wn.0')"; + + } + + /** + * Get all the paths in the tree, keep the (word, path) pairs + * Note the word in the pair is actually the word of the leaf node + */ + private int searchDepthFirst(int depth, + int node_index, + TIntArrayList traversed, + TIntArrayList next_pointers) { + int max_depth = depth; + traversed.add(node_index); + Node current_node = this.nodes.get(node_index); + current_node.addPaths(1); + + // go over the words that current node emits + for (int ii = 0; ii < current_node.getNumWords(); ii++) { + int word = current_node.getWord(ii); + Path p = new Path(); + p.addNodes(traversed); + // p.addChildren(next_pointers); + p.addFinalWord(word); + if (! this.wordPaths.contains(word)) { + this.wordPaths.put(word, new ArrayList ()); + } + ArrayList tmp = this.wordPaths.get(word); + tmp.add(p); + } + + // go over the child nodes of the current node + for (int ii = 0; ii < current_node.getNumChildren(); ii++) { + int child = current_node.getChild(ii); + next_pointers.add(child); + int child_depth = this.searchDepthFirst(depth+1, child, traversed, next_pointers); + next_pointers.remove(next_pointers.size()-1); + max_depth = max_depth >= child_depth ? max_depth : child_depth; + } + + traversed.remove(traversed.size()-1); + return max_depth; + } + + /** + * Set the scaled prior distribution of each node + * According to the hypoCount of the nodes' children, generate a Multinomial + * distribution, then scaled by transitionScalor + */ + private void setPrior() { + for (TIntObjectIterator it = this.nodes.iterator(); it.hasNext(); ) { + it.advance(); + Node n = it.value(); + int numChildren = n.getNumChildren(); + int numWords = n.getNumWords(); + + // firstly set the hypoCount for each child + if (numChildren > 0) { + assert numWords == 0; + n.initializePrior(numChildren); + for (int ii = 0; ii < numChildren; ii++) { + int child = n.getChild(ii); + n.setPrior(ii, this.nodes.get(child).getHypoCount()); + } + } + + // this step is for tree structures whose leaf nodes contain more than one words + // if the leaf node contains multiple words, we will treat each word + // as a "leaf node" and set a multinomial over all the words + // if the leaf node contains only one word, so this step will be jumped over. 
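+ // (e.g., a leaf synset holding two words with counts 3 and 1 gets the prior (0.75, 0.25) before scaling)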
+ if (numWords > 1) { + assert numChildren == 0; + n.initializePrior(numWords); + for (int ii = 0; ii < numWords; ii++) { + n.setPrior(ii, n.getWordCount(ii)); + } + } + + // then normalize and scale + n.normalizePrior(); + } + } + + /** + * the entrance of this class + */ + public void initialize(String treeFiles, String hyperFile, ArrayList vocab) { + this.loadHyperparams(hyperFile); + this.loadTree(treeFiles, vocab); + + TIntArrayList traversed = new TIntArrayList (); + TIntArrayList next_pointers = new TIntArrayList (); + //this.maxDepth = this.searchDepthFirst(0, 0, traversed, next_pointers); + this.maxDepth = this.searchDepthFirst(0, this.root, traversed, next_pointers); + this.setPrior(); + + //System.out.println("**************************"); + // check the word paths + System.out.println("Number of words: " + this.wordPaths.size()); + //System.out.println("Initialized paths"); + + /* + for (TIntObjectIterator> it = this.wordPaths.iterator(); it.hasNext(); ) { + it.advance(); + ArrayList paths = it.value(); + System.out.print(it.key() + ", " + vocab.get(it.key())); + //System.out.print(it.key() + ", " + vocab.lookupObject(it.key())); + for (int ii = 0; ii < paths.size(); ii++) { + Path p = paths.get(ii); + System.out.print(", Path " + ii); + System.out.print(", Path nodes list: " + p.getNodes()); + System.out.println(", Path final word: " + p.getFinalWord()); + } + } + System.out.println("**************************"); + + // check the prior + System.out.println("Check the prior"); + for (TIntObjectIterator it = this.nodes.iterator(); it.hasNext(); ) { + it.advance(); + if (it.value().getTransitionPrior().size() > 0) { + System.out.print("Node " + it.key()); + System.out.println(", Transition prior " + it.value().getTransitionPrior()); + } + } + System.out.println("**************************"); + */ + + } + + public int getMaxDepth() { + return this.maxDepth; + } + + public int getRoot() { + return this.root; + } + + public TIntObjectHashMap getNodes() { + return this.nodes; + } + + public TIntObjectHashMap> getWordPaths() { + return this.wordPaths; + } + + /** + * Load vocab + */ + public ArrayList readVocab(String vocabFile) { + + ArrayList vocab = new ArrayList (); + + try { + FileInputStream infstream = new FileInputStream(vocabFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + vocab.add(str[1]); + } else { + System.out.println("Error! 
" + strLine); + return null; + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No vocab file Found!"); + } + + return vocab; + } + + /** + * test main + */ + public static void main(String[] args) throws Exception{ + + //String treeFiles = "../toy/toy_set1.wn.*"; + //String hyperFile = "../toy/tree_hyperparams"; + //String inputFile = "../input/toy-topic-input.mallet"; + //String vocabFile = "../toy/toy.voc"; + + //String treeFiles = "../synthetic/synthetic_set1.wn.*"; + //String hyperFile = "../synthetic/tree_hyperparams"; + //String inputFile = "../input/synthetic-topic-input.mallet"; + //String vocabFile = "../synthetic/synthetic.voc"; + + String treeFiles = "input/denews.all.wn"; + String hyperFile = "input/denews.hyper"; + String inputFile = "input/denews-topic-input.mallet"; + String vocabFile = "input/denews.filter.voc"; + + PriorTree tree = new PriorTree(); + ArrayList vocab = tree.readVocab(vocabFile); + + InstanceList ilist = InstanceList.load (new File(inputFile)); + tree.initialize(treeFiles, hyperFile, vocab); + } + +} diff --git a/src/cc/mallet/topics/tree/TopicSampler.java b/src/cc/mallet/topics/tree/TopicSampler.java new file mode 100755 index 000000000..e7ffe00fe --- /dev/null +++ b/src/cc/mallet/topics/tree/TopicSampler.java @@ -0,0 +1,278 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIntIterator; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.types.Dirichlet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + +/** + * Abstract class for TopicSampler. + * Defines the basic functions for input, output, resume. + * Also defines the abstract functions for child class. + * + * @author Yuening Hu + */ + +public abstract class TopicSampler{ + + int numTopics; + int numIterations; + int startIter; + Randoms random; + double[] alpha; + double alphaSum; + TDoubleArrayList lhood; + TDoubleArrayList iterTime; + ArrayList vocab; + + TreeTopicModel topics; + TIntHashSet cons; + + public TopicSampler (int numberOfTopics, double alphaSum, int seed) { + this.numTopics = numberOfTopics; + this.random = new Randoms(seed); + + this.alphaSum = alphaSum; + this.alpha = new double[numTopics]; + Arrays.fill(alpha, alphaSum / numTopics); + + this.vocab = new ArrayList (); + this.cons = new TIntHashSet(); + + this.lhood = new TDoubleArrayList(); + this.iterTime = new TDoubleArrayList(); + this.startIter = 0; + + // notice: this.topics and this.data are not initialized in this abstract class, + // in each sub class, the topics variable is initialized differently. + } + + + + public void setNumIterations(int iters) { + this.numIterations = iters; + } + + public int getNumIterations() { + return this.numIterations; + } + + + + /** + * This function returns the likelihood. + */ + public double lhood() { + return this.docLHood() + this.topics.topicLHood(); + } + + /** + * Resume lhood and iterTime from the saved lhood file. 
+ */ + public void resumeLHood(String lhoodFile) throws IOException{ + FileInputStream lhoodfstream = new FileInputStream(lhoodFile); + DataInputStream lhooddstream = new DataInputStream(lhoodfstream); + BufferedReader brLHood = new BufferedReader(new InputStreamReader(lhooddstream)); + // the first line is the title + String strLine = brLHood.readLine(); + while ((strLine = brLHood.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + // iteration, likelihood, iter_time + myAssert(str.length == 3, "lhood file problem!"); + this.lhood.add(Double.parseDouble(str[1])); + this.iterTime.add(Double.parseDouble(str[2])); + } + this.startIter = this.lhood.size(); + if (this.startIter > this.numIterations) { + System.out.println("Have already sampled " + this.numIterations + " iterations!"); + System.exit(0); + } + System.out.println("Start sampling for iteration " + this.startIter); + brLHood.close(); + } + + /** + * Resumes from the saved files. + */ + public void resume(InstanceList training, String resumeDir) { + try { + String statesFile = resumeDir + ".states"; + resumeStates(training, statesFile); + + String lhoodFile = resumeDir + ".lhood"; + resumeLHood(lhoodFile); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + } + + /** + * This function prints the topic words of each topic. + */ + public void printTopWords(File file, int numWords) throws IOException { + PrintStream out = new PrintStream (file); + out.print(displayTopWords(numWords)); + out.close(); + } + + /** + * By implementing the comparable interface, this function ranks the words + * in each topic, and returns the top words for each topic. + */ + public String displayTopWords (int numWords) { + + class WordProb implements Comparable { + int wi; + double p; + public WordProb (int wi, double p) { this.wi = wi; this.p = p; } + public final int compareTo (Object o2) { + if (p > ((WordProb)o2).p) + return -1; + else if (p == ((WordProb)o2).p) + return 0; + else return 1; + } + } + + StringBuilder out = new StringBuilder(); + int numPaths = this.topics.getPathNum(); + //System.out.println(numPaths); + + for (int tt = 0; tt < this.numTopics; tt++){ + String tmp = "\n--------------\nTopic " + tt + "\n------------------------\n"; + //System.out.print(tmp); + out.append(tmp); + WordProb[] wp = new WordProb[numPaths]; + for (int pp = 0; pp < numPaths; pp++){ + int ww = this.topics.getWordFromPath(pp); + double val = this.topics.computeTopicPathProb(tt, ww, pp); + wp[pp] = new WordProb(pp, val); + } + Arrays.sort(wp); + for (int ii = 0; ii < wp.length; ii++){ + int pp = wp[ii].wi; + int ww = this.topics.getWordFromPath(pp); + //tmp = wp[ii].p + "\t" + this.vocab.lookupObject(ww) + "\n"; + tmp = wp[ii].p + "\t" + this.vocab.get(ww) + "\n"; + //System.out.print(tmp); + out.append(tmp); + if(ii > numWords) { + break; + } + } + } + return out.toString(); + } + + /** + * Prints likelihood and iter time. 
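+ * Writes one row per recorded iteration in the same tab-separated format that resumeLHood reads back.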
+ */ + public void printStats (File file) throws IOException { + PrintStream out = new PrintStream (file); + String tmp = "Iteration\t\tlikelihood\titer_time\n"; + out.print(tmp); + for (int iter = 0; iter < this.lhood.size(); iter++) { + tmp = iter + "\t" + this.lhood.get(iter) + "\t" + this.iterTime.get(iter); + out.println(tmp); + } + out.close(); + } + + public void loadVocab(String vocabFile) { + + try { + FileInputStream infstream = new FileInputStream(vocabFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + this.vocab.add(str[1]); + } else { + System.out.println("Error! " + strLine); + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No vocab file Found!"); + } + + } + + /** + * Load constraints + */ + public void loadConstraints(String consFile) { + try { + FileInputStream infstream = new FileInputStream(consFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + // str[0] is either "MERGE_" or "SPLIT_", not a word + for(int ii = 1; ii < str.length; ii++) { + int word = this.vocab.indexOf(str[ii]); + myAssert(word >= 0, "Constraint words not found in vocab: " + str[ii]); + cons.add(word); + } + this.vocab.add(str[1]); + } else { + System.out.println("Error! " + strLine); + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No vocab file Found!"); + } + + } + + /** + * For testing~~ + */ + public static void myAssert(boolean flag, String info) { + if(!flag) { + System.out.println(info); + System.exit(0); + } + } + + abstract void addInstances(InstanceList training); + abstract void resumeStates(InstanceList training, String statesFile) throws IOException; + abstract void clearTopicAssignments(String option, String consFile); + abstract void changeTopic(int doc, int index, int word, int new_topic, int new_path); + abstract double docLHood(); + abstract void printDocumentTopics (File file) throws IOException; + abstract void sampleDoc(int doc); +} diff --git a/src/cc/mallet/topics/tree/TopicTreeWalk.java b/src/cc/mallet/topics/tree/TopicTreeWalk.java new file mode 100755 index 000000000..ac2fd5fd5 --- /dev/null +++ b/src/cc/mallet/topics/tree/TopicTreeWalk.java @@ -0,0 +1,85 @@ +package cc.mallet.topics.tree; + +import java.io.Serializable; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; + +/** + * This class counts each node and each edge for a topic with tree structure. + * + * @author Yuening Hu + */ + +public class TopicTreeWalk implements Serializable { + + // *** To be sorted + HIntIntIntHashMap counts; + TIntIntHashMap nodeCounts; + + public TopicTreeWalk() { + this.counts = new HIntIntIntHashMap(); + this.nodeCounts = new TIntIntHashMap(); + } + + /** + * Given a path (a list of nodes), increase the nodes and edges counts by + * the specified amount. When a node count is changed from zero or changed + * to zero, return this node. (When this happens, the non-zero path of this + * node might need to be changed, that's why we need this list.) 
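+ * <p>A small worked example with hypothetical node ids: for the path [0, 5, 12]
+ * and increment +1, the edge counts (0,5) and (5,12) and the node counts of
+ * 0, 5 and 12 each grow by one; if node 5 or 12 thereby moves from count 0 to 1
+ * (or, with a negative increment, down to 0) it is returned as an affected node.
+ * The root at position 0 is never reported.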
+ */ + public int[] changeCount(TIntArrayList path_nodes, int increment) { + for (int nn = 0; nn < path_nodes.size()-1; nn++) { + int parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + this.counts.adjustOrPutValue(parent, child, increment, increment); + } + + // keep the nodes whose counts is changed from zero or changed to zero + TIntHashSet affected_nodes = new TIntHashSet(); + + for (int nn = 0; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + if (! this.nodeCounts.contains(node)) { + this.nodeCounts.put(node, 0); + } + + int old_count = this.nodeCounts.get(node); + this.nodeCounts.adjustValue(node, increment); + int new_count = this.nodeCounts.get(node); + + // keep the nodes whose counts is changed from zero or changed to zero + if (nn != 0 && (old_count == 0 || new_count == 0)) { + affected_nodes.add(node); + } + } + + if (affected_nodes.size() > 0) { + return affected_nodes.toArray(); + } else { + return null; + } + } + + /** + * Return an edge count. + */ + public int getCount(int key1, int key2) { + if (this.counts.contains(key1, key2)) { + return this.counts.get(key1, key2); + } + return 0; + } + + /** + * Return a node count. + */ + public int getNodeCount(int key) { + if (this.nodeCounts.contains(key)) { + return this.nodeCounts.get(key); + } + return 0; + } + +} diff --git a/src/cc/mallet/topics/tree/TreeMarginalProbEstimator.java b/src/cc/mallet/topics/tree/TreeMarginalProbEstimator.java new file mode 100644 index 000000000..f8ee3b059 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeMarginalProbEstimator.java @@ -0,0 +1,380 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntIntHashMap; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashSet; + +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + + +/** + * An implementation of left-to-right algorithm for tree-based topic model marginal probability estimators + * presented in Wallach et al., "Evaluation Methods for Topic Models", ICML (2009) + * Followed the example in "cc.mallet.topics.MarginalProbEstimator" by David Mimno + * + * @author Yuening Hu + */ + +public class TreeMarginalProbEstimator implements Serializable { + int TOPIC_BITS = TreeTopicModelFastSortW.TOPIC_BITS; + + int numTopics; + double[] alpha; + double alphasum; + ArrayList vocab; + HashSet removed; + TreeTopicModel topics; + String modelType; + + Randoms random; + boolean sorted; + + public TreeMarginalProbEstimator(TreeTopicModel topics, ArrayList vocab, HashSet removed, double[] alpha) { + this.numTopics = topics.numTopics; + this.vocab = vocab; + this.removed = removed; + this.alpha = alpha; + this.topics = topics; + this.random = new Randoms(); + + this.alphasum = 0.0; + for(int tt = 0; tt < numTopics; tt++) { + this.alphasum += this.alpha[tt]; + } + + if (this.topics.nonZeroPathsBubbleSorted.size() > 0) { + this.sorted = true; + } else if (this.topics.nonZeroPaths.size() > 0) { + this.sorted = false; + } + //System.out.println(this.sorted); + } + + public void setRandomSeed(int seed) { + this.random = new Randoms(seed); + } + + public void setModelType(String modeltype) { + this.modelType = modeltype; + } + + + public double evaluateLeftToRight (InstanceList testing, int numParticles, boolean 
usingResampling, + PrintStream docProbabilityStream) { + + if(this.modelType.indexOf("fast-est") < 0) { + System.out.println("%%%%%%%%%%%%%%%%%%%"); + System.out.println("Your current tree-model-type"); + System.out.println("\t " + this.modelType); + System.out.println("is not supported by inferencer. "); + System.out.println("Inferencer only supports the following tree-model-type: "); + System.out.println("\t fast-est \n\t fast-est-sortW \n\t fast-est-sortD \n\t fast-est-sortD-sortW"); + System.out.println("%%%%%%%%%%%%%%%%%%%"); + return -1; + } + + double logNumParticles = Math.log(numParticles); + double totalLogLikelihood = 0; + for (Instance instance : testing) { + + FeatureSequence tokenSequence = (FeatureSequence) instance.getData(); + + // read in type index in vocab (different from the alphabet) + // remove tokens not in vocab + ArrayList tokens = new ArrayList (); + for (int position = 0; position < tokenSequence.size(); position++) { + String word = (String) tokenSequence.getObjectAtPosition(position); + if(this.vocab.indexOf(word) >= 0 && !this.removed.contains(word)) { + int type = this.vocab.indexOf(word); + tokens.add(type); + } + } + + double docLogLikelihood = 0; + + double[][] particleProbabilities = new double[ numParticles ][]; + for (int particle = 0; particle < numParticles; particle++) { + particleProbabilities[particle] = + leftToRight(tokens, usingResampling); + } + + for (int position = 0; position < particleProbabilities[0].length; position++) { + double sum = 0; + for (int particle = 0; particle < numParticles; particle++) { + sum += particleProbabilities[particle][position]; + } + + if (sum > 0.0) { + docLogLikelihood += Math.log(sum) - logNumParticles; + } + } + + if (docProbabilityStream != null) { + docProbabilityStream.println(docLogLikelihood); + } + totalLogLikelihood += docLogLikelihood; + } + + return totalLogLikelihood; + } + + protected double[] leftToRight (ArrayList tokens, boolean usingResampling) { + + int docLength = tokens.size(); + double[] wordProbabilities = new double[docLength]; + + int[] localtopics = new int[docLength]; + int[] localpaths = new int[docLength]; + TIntIntHashMap localTopicCounts = new TIntIntHashMap(); + + int tokensSoFar = 0; + int type; + for (int limit = 0; limit < docLength; limit++) { + if (usingResampling) { + + // Iterate up to the current limit + for (int position = 0; position < limit; position++) { + type = tokens.get(position); + + // change topic counts + int old_topic = localtopics[position]; + localtopics[position] = -1; + localpaths[position] = -1; + localTopicCounts.adjustValue(old_topic, -1); + + double smoothing_mass_est = this.topics.smoothingEst.get(type); + + double topic_beta_mass = this.topics.computeTermTopicBeta(localTopicCounts, type); + + ArrayList topic_term_score = new ArrayList(); + double topic_term_mass = this.topics.computeTopicTerm(this.alpha, localTopicCounts, type, topic_term_score); + + double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm_est; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(type); + + // sample the smoothing bin + if (sample < smoothing_mass_est) { + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, type); + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + sample /= norm_est; + sample *= norm; + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for 
(int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= smoothing_mass; + } + } else { + sample -= smoothing_mass_est; + } + + // sample topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + for(int tt : localTopicCounts.keys()) { + for (int pp : paths) { + double val = localTopicCounts.get(tt) * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= topic_beta_mass; + } + + // sample topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + } + + // change topic counts + localtopics[position] = new_topic; + localpaths[position] = new_path; + localTopicCounts.adjustOrPutValue(new_topic, 1, 1); + } + } + + // sample current token at the current limit + type = tokens.get(limit); + + //double smoothing_mass_est = this.topics.smoothingEst.get(type); + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, type); + + double topic_beta_mass = this.topics.computeTermTopicBeta(localTopicCounts, type); + + ArrayList topic_term_score = new ArrayList(); + double topic_term_mass = this.topics.computeTopicTerm(this.alpha, localTopicCounts, type, topic_term_score); + + //double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + sample *= norm; + + wordProbabilities[limit] += (smoothing_mass + topic_beta_mass + topic_term_mass) / + (this.alphasum + tokensSoFar); + + tokensSoFar++; + + int new_topic = -1; + int new_path = -1; + int[] paths = this.topics.getWordPathIndexSet(type); + + // sample the smoothing bin + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= smoothing_mass; + } + + // sample topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + for(int tt : localTopicCounts.keys()) { + for (int pp : paths) { + double val = localTopicCounts.get(tt) * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= topic_beta_mass; + } + + // sample topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + } + + // change topic counts + localtopics[limit] = new_topic; + localpaths[limit] = new_path; + localTopicCounts.adjustOrPutValue(new_topic, 1, 1); + } + + return 
wordProbabilities; + } + + + // for serialize + private static final long serialVersionUID = 1L; + private static final int CURRENT_SERIAL_VERSION = 0; + private static final int NULL_INTEGER = -1; + + private void writeObject (ObjectOutputStream out) throws IOException { + out.writeInt (CURRENT_SERIAL_VERSION); + out.writeInt(this.numTopics); + out.writeInt(this.TOPIC_BITS); + out.writeBoolean(this.sorted); + out.writeObject(this.modelType); + out.writeObject(this.alpha); + out.writeDouble(this.alphasum); + out.writeObject(this.vocab); + out.writeObject(this.removed); + out.writeObject(this.topics); + } + + private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { + int version = in.readInt(); + this.numTopics = in.readInt(); + this.TOPIC_BITS = in.readInt(); + this.sorted = in.readBoolean(); + this.modelType = (String) in.readObject(); + this.alpha = (double[]) in.readObject(); + this.alphasum = in.readDouble(); + this.vocab = (ArrayList) in.readObject(); + this.removed = (HashSet) in.readObject(); + this.topics = (TreeTopicModel) in.readObject(); + } + + public static TreeMarginalProbEstimator read (File f) throws Exception { + + TreeMarginalProbEstimator estimator = null; + + ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f)); + estimator = (TreeMarginalProbEstimator) ois.readObject(); + ois.close(); + return estimator; + } + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicInferencer.java b/src/cc/mallet/topics/tree/TreeTopicInferencer.java new file mode 100755 index 000000000..5b6cde552 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicInferencer.java @@ -0,0 +1,384 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.PrintWriter; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; + +import cc.mallet.topics.TopicInferencer; +import cc.mallet.types.Alphabet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.IDSorter; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + + +/** + * An implementation of inferencer for tree-based topic model + * Followed the example in "cc.mallet.topics.TopicInferencer" + * + * @author Yuening Hu + */ + +public class TreeTopicInferencer implements Serializable { + + int TOPIC_BITS = TreeTopicModelFastSortW.TOPIC_BITS; + + int numTopics; + double[] alpha; + ArrayList vocab; + HashSet removed; + TreeTopicModel topics; + String modelType; + + Randoms random; + boolean sorted; + + public TreeTopicInferencer(TreeTopicModel topics, ArrayList vocab, HashSet removed, double[] alpha) { + this.numTopics = topics.numTopics; + this.vocab = vocab; + this.removed = removed; + this.alpha = alpha; + this.topics = topics; + this.random = new Randoms(); + + if (this.topics.nonZeroPathsBubbleSorted.size() > 0) { + this.sorted = true; + } else if (this.topics.nonZeroPaths.size() > 0) { + this.sorted = false; + } + //System.out.println(this.sorted); + } + + public void setRandomSeed(int seed) { + this.random = new Randoms(seed); + } + + public void setModelType(String modeltype) { + this.modelType = modeltype; + } + + public String toString() { + String tmp = ""; + tmp += "numTopics: " + numTopics + "\n"; + tmp += "sorted: " + this.sorted + "\n"; + tmp += "removed: " + 
this.removed.size() + "\n"; + tmp += "nonzero: " + this.topics.nonZeroPaths.size() + "\n"; + tmp += "nonzerobubblesorted: " + this.topics.nonZeroPathsBubbleSorted.size() + "\n"; + return tmp; + } + + /** + * Use Gibbs sampling to infer a topic distribution. + * Topics are initialized to the (or a) most probable topic + * for each token. + */ + public double[] getSampledDistribution(Instance instance, int numIterations, int interval) { + + FeatureSequence alltokens = (FeatureSequence) instance.getData(); + ArrayList tokens = new ArrayList (); + for (int position = 0; position < alltokens.size(); position++) { + String word = (String) alltokens.getObjectAtPosition(position); + if(this.vocab.indexOf(word) >= 0 && !this.removed.contains(word)) { + int type = this.vocab.indexOf(word); + tokens.add(type); + } + } + + int docLength = tokens.size(); + int[] localtopics = new int[docLength]; + int[] localpaths = new int[docLength]; + TIntIntHashMap localTopicCounts = new TIntIntHashMap(); + + // Initialize all positions to the most common topic for that type. + for (int position = 0; position < docLength; position++) { + int type = tokens.get(position); + + int tt = -1; + int pp = -1; + + if (this.sorted) { + ArrayList pairs = this.topics.nonZeroPathsBubbleSorted.get(type); + int[] pair = pairs.get(0); + int key = pair[0]; + tt = key >> TOPIC_BITS; + pp = key - (tt << TOPIC_BITS); + } else { + HIntIntIntHashMap pairs1 = this.topics.nonZeroPaths.get(type); + int maxcount = 0; + for(int topic : pairs1.getKey1Set()) { + int[] paths = pairs1.get(topic).keys(); + for (int jj = 0; jj < paths.length; jj++) { + int path = paths[jj]; + int count = pairs1.get(topic, path); + if (count > maxcount) { + maxcount = count; + tt = topic; + pp = path; + } + } + } + } + + localtopics[position] = tt; + localpaths[position] = pp; + localTopicCounts.adjustOrPutValue(tt, 1, 1); + } + +// String tmpout = ""; +// for(int tt : localTopicCounts.keys()) { +// tmpout += tt + " " + localTopicCounts.get(tt) + "; "; +// } +// System.out.println(tmpout); + + double[] result = new double[numTopics]; + double sum = 0.0; + + for (int iteration = 1; iteration <= numIterations; iteration++) { + for (int position = 0; position < docLength; position++) { + int type = tokens.get(position); + + // change topic counts + int old_topic = localtopics[position]; + localtopics[position] = -1; + localpaths[position] = -1; + localTopicCounts.adjustValue(old_topic, -1); + + double smoothing_mass_est = this.topics.smoothingEst.get(type); + + double topic_beta_mass = this.topics.computeTermTopicBeta(localTopicCounts, type); + + ArrayList topic_term_score = new ArrayList(); + double topic_term_mass = this.topics.computeTopicTerm(this.alpha, localTopicCounts, type, topic_term_score); + + double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm_est; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(type); + + // sample the smoothing bin + if (sample < smoothing_mass_est) { + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, type); + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + sample /= norm_est; + sample *= norm; + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + 
new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= smoothing_mass; + } + } else { + sample -= smoothing_mass_est; + } + + // sample topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + for(int tt : localTopicCounts.keys()) { + for (int pp : paths) { + double val = localTopicCounts.get(tt) * this.topics.getPathPrior(type, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + } else { + sample -= topic_beta_mass; + } + + // sample topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + } + + // change topic counts + localtopics[position] = new_topic; + localpaths[position] = new_path; + localTopicCounts.adjustOrPutValue(new_topic, 1, 1); + +// if (iteration % interval == 0) { +// // Save a sample +// for (int topic=0; topic < numTopics; topic++) { +// if (localTopicCounts.containsKey(topic)) { +// result[topic] = alpha[topic] + localTopicCounts.get(topic); +// } else { +// result[topic] = alpha[topic]; +// } +// sum += result[topic]; +// } +// } + } + } + + + // save at least once + if (sum == 0.0) { + for (int topic=0; topic < numTopics; topic++) { + if (localTopicCounts.containsKey(topic)) { + result[topic] = alpha[topic] + localTopicCounts.get(topic); + } else { + result[topic] = alpha[topic]; + } + sum += result[topic]; + } + } + + // Normalize + for (int topic=0; topic < numTopics; topic++) { + result[topic] /= sum; + } + + return result; + } + + /** + * Infer topics for the provided instances and + * write distributions to the provided file. + * + * @param instances + * @param distributionsFile + * @param numIterations The total number of iterations of sampling per document + * @param interval The number of iterations between saved samples + */ + public void writeInferredDistributions(InstanceList instances, + File distributionsFile, + int numIterations, int interval) throws IOException { + + if(this.modelType.indexOf("fast-est") < 0) { + System.out.println("%%%%%%%%%%%%%%%%%%%"); + System.out.println("Your current tree-model-type"); + System.out.println("\t " + this.modelType); + System.out.println("is not supported by inferencer. 
"); + System.out.println("Inferencer only supports the following tree-model-type: "); + System.out.println("\t fast-est \n\t fast-est-sortW \n\t fast-est-sortD \n\t fast-est-sortD-sortW"); + System.out.println("%%%%%%%%%%%%%%%%%%%"); + return; + } + + PrintWriter out = new PrintWriter(distributionsFile); + + out.print ("#doc source topic proportion ...\n"); + + IDSorter[] sortedTopics = new IDSorter[ numTopics ]; + for (int topic = 0; topic < numTopics; topic++) { + // Initialize the sorters with dummy values + sortedTopics[topic] = new IDSorter(topic, topic); + } + + int doc = 0; + + for (Instance instance: instances) { + + double[] topicDistribution = + getSampledDistribution(instance, numIterations, interval); + out.print (doc); out.print (' '); + + // Print the Source field of the instance + if (instance.getSource() != null) { + out.print (instance.getSource()); + } else { + out.print ("null-source"); + } + out.print (' '); + + for (int topic = 0; topic < numTopics; topic++) { + sortedTopics[topic].set(topic, topicDistribution[topic]); + } + Arrays.sort(sortedTopics); + + for (int i = 0; i < numTopics; i++) { + out.print (sortedTopics[i].getID() + " " + + sortedTopics[i].getWeight() + " "); + } + out.print (" \n"); + doc++; + } + out.close(); + } + + + // for serialize + private static final long serialVersionUID = 1L; + private static final int CURRENT_SERIAL_VERSION = 0; + private static final int NULL_INTEGER = -1; + + private void writeObject (ObjectOutputStream out) throws IOException { + out.writeInt (CURRENT_SERIAL_VERSION); + out.writeInt(this.numTopics); + out.writeInt(this.TOPIC_BITS); + out.writeBoolean(this.sorted); + out.writeObject(this.modelType); + out.writeObject(this.alpha); + out.writeObject(this.vocab); + out.writeObject(this.removed); + out.writeObject(this.topics); + } + + private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { + int version = in.readInt(); + this.numTopics = in.readInt(); + this.TOPIC_BITS = in.readInt(); + this.sorted = in.readBoolean(); + this.modelType = (String) in.readObject(); + this.alpha = (double[]) in.readObject(); + this.vocab = (ArrayList) in.readObject(); + this.removed = (HashSet) in.readObject(); + this.topics = (TreeTopicModel) in.readObject(); + } + + public static TreeTopicInferencer read (File f) throws Exception { + + TreeTopicInferencer inferencer = null; + + ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f)); + inferencer = (TreeTopicInferencer) ois.readObject(); + ois.close(); + return inferencer; + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModel.java b/src/cc/mallet/topics/tree/TreeTopicModel.java new file mode 100755 index 000000000..cae9ef3a2 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModel.java @@ -0,0 +1,382 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntDoubleIterator; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; +import gnu.trove.TIntObjectIterator; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Random; +import java.util.TreeMap; + +import cc.mallet.types.Dirichlet; + +/** + * This class defines the tree topic model. + * It implements most of the functions and leave four abstract methods, + * which might be various for different models. 
+ * + * @author Yuening Hu + */ + +public abstract class TreeTopicModel implements Serializable { + + int numTopics; + Random random; + int maxDepth; + int root; + HIntIntObjectHashMap wordPaths; + TIntArrayList pathToWord; + TIntArrayList pathToWordPath; + TIntObjectHashMap nodeToPath; + + TIntDoubleHashMap betaSum; + HIntIntDoubleHashMap beta; // 2 levels hash map + TIntDoubleHashMap priorSum; + HIntIntDoubleHashMap priorPath; + + TIntObjectHashMap nonZeroPaths; + TIntObjectHashMap> nonZeroPathsBubbleSorted; + TIntObjectHashMap traversals; + + HIntIntDoubleHashMap normalizer; + TIntDoubleHashMap rootNormalizer; + TIntDoubleHashMap smoothingEst; + + /*********************************************/ + double smoothingNocons; // for unconstrained words + double topicbetaNocons; // for unconstrained words + double betaNocons; // for unconstrained words + /*********************************************/ + + public TreeTopicModel(int numTopics, Random random) { + this.numTopics = numTopics; + this.random = random; + + this.betaSum = new TIntDoubleHashMap (); + this.beta = new HIntIntDoubleHashMap (); + this.priorSum = new TIntDoubleHashMap (); + this.priorPath = new HIntIntDoubleHashMap (); + + this.wordPaths = new HIntIntObjectHashMap (); + this.pathToWord = new TIntArrayList (); + this.pathToWordPath = new TIntArrayList(); + this.nodeToPath = new TIntObjectHashMap (); + + this.nonZeroPaths = new TIntObjectHashMap (); + this.nonZeroPathsBubbleSorted = new TIntObjectHashMap> (); + this.traversals = new TIntObjectHashMap (); + + this.smoothingNocons = 0.0; + this.topicbetaNocons = 0.0; + } + + /** + * Initialize the parameters, including: + * (1) loading the tree + * (2) initialize betaSum and beta + * (3) initialize priorSum, priorPath + * (4) initialize wordPaths, pathToWord, NodetoPath + * (5) initialize traversals + * (6) initialize nonZeroPaths + */ + protected void initializeParams(String treeFiles, String hyperFile, ArrayList vocab) { + + PriorTree tree = new PriorTree(); + tree.initialize(treeFiles, hyperFile, vocab); + + // get tree depth + this.maxDepth = tree.getMaxDepth(); + // get root index + this.root = tree.getRoot(); + // get tree nodes + TIntObjectHashMap nodes = tree.getNodes(); + // get tree paths + TIntObjectHashMap> word_paths = tree.getWordPaths(); + + // if one node contains multiple words, we need to change each word to a leaf node + // (assigning a leaf index for each word). + int leaf_index = nodes.size(); + HIntIntIntHashMap tmp_wordleaf = new HIntIntIntHashMap(); + + // initialize betaSum and beta + for (TIntObjectIterator it = nodes.iterator(); it.hasNext(); ) { + it.advance(); + int index = it.key(); + Node node = it.value(); + TDoubleArrayList transition_prior = node.getTransitionPrior(); + + // when node has children + if (node.getNumChildren() > 0) { + //assert node.getNumWords() == 0; + this.betaSum.put(index, node.getTransitionScalor()); + for (int ii = 0; ii < node.getNumChildren(); ii++) { + int child = node.getChild(ii); + this.beta.put(index, child, transition_prior.get(ii)); + } + } + + // when node contains multiple words. 
+ // we change a node containing multiple words to a node containing multiple + // leaf node and each leaf node containing one word + if (node.getNumWords() > 1) { + //assert node.getNumChildren() == 0; + this.betaSum.put(index, node.getTransitionScalor()); + for (int ii = 0; ii < node.getNumWords(); ii++) { + int word = node.getWord(ii); + leaf_index++; + this.beta.put(index, leaf_index, transition_prior.get(ii)); + + // one word might have multiple paths, + // so we keep the (word_index, word_parent) + // as the index for this leaf index, which is needed later + tmp_wordleaf.put(word, index, leaf_index); + } + } + } + + /*********************************************/ + // find beta for unconstrained words + Node rootnode = nodes.get(this.root); + for (int ii = 0; ii < rootnode.getNumChildren(); ii++) { + int child = rootnode.getChild(ii); + Node childnode = nodes.get(child); + double tmpbeta = this.beta.get(this.root, child); + //System.out.println("beta for root to " + child + ": " + tmpbeta); + if (childnode.getHypoCount() == 1.0) { + this.betaNocons = this.beta.get(this.root, child); + System.out.println("beta for unconstrained words from root to " + child + ": " + tmpbeta); + break; + } + } + /*********************************************/ + + // initialize priorSum, priorPath + // initialize wordPaths, pathToWord, NodetoPath + int path_index = -1; + TIntObjectHashMap tmp_nodeToPath = new TIntObjectHashMap(); + for (TIntObjectIterator> it = word_paths.iterator(); it.hasNext(); ) { + it.advance(); + + int word = it.key(); + ArrayList paths = it.value(); + this.priorSum.put(word, 0.0); + + int word_path_index = -1; + for (int ii = 0; ii < paths.size(); ii++) { + path_index++; + word_path_index++; + this.pathToWord.add(word); + this.pathToWordPath.add(word_path_index); + + double prob = 1.0; + Path p = paths.get(ii); + TIntArrayList path_nodes = p.getNodes(); + + // for a node that contains multiple words + // if yes, retrieve the leaf index for each word + // and that to nodes of path + int parent = path_nodes.get(path_nodes.size()-1); + if (tmp_wordleaf.contains(word, parent)) { + leaf_index = tmp_wordleaf.get(word, parent); + path_nodes.add(leaf_index); + } + + for (int nn = 0; nn < path_nodes.size() - 1; nn++) { + parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + prob *= this.beta.get(parent, child); + } + + for (int nn = 0; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + if (! 
tmp_nodeToPath.contains(node)) { + tmp_nodeToPath.put(node, new TIntHashSet()); + } + tmp_nodeToPath.get(node).add(path_index); + //tmp_nodeToPath.get(node).add(word_path_index); + } + + this.priorPath.put(word, path_index, prob); + this.priorSum.adjustValue(word, prob); + this.wordPaths.put(word, path_index, path_nodes); + } + } + + // change tmp_nodeToPath to this.nodeToPath + // this is because arraylist is much more efficient than hashset, when we + // need to go over the whole set multiple times + for(TIntObjectIterator it = tmp_nodeToPath.iterator(); it.hasNext(); ) { + it.advance(); + int node = it.key(); + TIntHashSet paths = (TIntHashSet)it.value(); + TIntArrayList tmp = new TIntArrayList(paths.toArray()); + +// System.out.println("Node" + node); +// for(int ii = 0; ii < tmp.size(); ii++) { +// System.out.print(tmp.get(ii) + " "); +// } +// System.out.println(""); + + this.nodeToPath.put(node, tmp); + } + + // initialize traversals + for (int tt = 0; tt < this.numTopics; tt++) { + TopicTreeWalk tw = new TopicTreeWalk(); + this.traversals.put(tt, tw); + } + + // initialize nonZeroPaths + int[] words = this.wordPaths.getKey1Set(); + for (int ww = 0; ww < words.length; ww++) { + int word = words[ww]; + this.nonZeroPaths.put(word, new HIntIntIntHashMap()); + } + } + + /** + * This function samples a path based on the prior + * and change the node and edge count for a topic. + */ + protected int initialize (int word, int topic) { + double sample = this.random.nextDouble(); + int path_index = this.samplePathFromPrior(word, sample); + this.changeCountOnly(topic, word, path_index, 1); + return path_index; + } + + /** + * This function changes the node and edge count for a topic. + */ + protected void changeCountOnly(int topic, int word, int path, int delta) { + TIntArrayList path_nodes = this.wordPaths.get(word, path); + TopicTreeWalk tw = this.traversals.get(topic); + tw.changeCount(path_nodes, delta); + } + + /** + * This function samples a path from the prior. + */ + protected int samplePathFromPrior(int term, double sample) { + int sampled_path = -1; + sample *= this.priorSum.get(term); + TIntDoubleHashMap paths = this.priorPath.get(term); + for(TIntDoubleIterator it = paths.iterator(); it.hasNext(); ) { + it.advance(); + sample -= it.value(); + if (sample <= 0.0) { + sampled_path = it.key(); + break; + } + } + + return sampled_path; + } + + /** + * This function computes a path probability in a topic. + */ + public double computeTopicPathProb(int topic, int word, int path_index) { + TIntArrayList path_nodes = this.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.traversals.get(topic); + double val = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + val *= this.beta.get(parent, child) + tw.getCount(parent, child); + val /= this.betaSum.get(parent) + tw.getNodeCount(parent); + } + return val; + } + + /** + * This function computes the topic likelihood (by node). 
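+ * <p>For each topic t and each internal node n with prior scale betaSum(n),
+ * per-edge priors beta(n,c) and edge counts N_t(n,c), the summand below is the
+ * standard Dirichlet-multinomial evidence of that node's child distribution:
+ * <pre>
+ *   logGamma(betaSum(n)) - sum_c logGamma(beta(n,c))
+ *   + sum_c logGamma(beta(n,c) + N_t(n,c)) - logGamma(betaSum(n) + N_t(n))
+ * </pre>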
+ */ + public double topicLHood() { + double val = 0.0; + for (int tt = 0; tt < this.numTopics; tt++) { + for (int nn : this.betaSum.keys()) { + double beta_sum = this.betaSum.get(nn); + //val += Dirichlet.logGamma(beta_sum) * this.beta.get(nn).size(); + val += Dirichlet.logGamma(beta_sum); + + double tmp = 0.0; + for (int cc : this.beta.get(nn).keys()) { + tmp += Dirichlet.logGamma(this.beta.get(nn, cc)); + } + //val -= tmp * this.beta.get(nn).size(); + val -= tmp; + + for (int cc : this.beta.get(nn).keys()) { + int count = this.traversals.get(tt).getCount(nn, cc); + val += Dirichlet.logGamma(this.beta.get(nn, cc) + count); + } + + int count = this.traversals.get(tt).getNodeCount(nn); + val -= Dirichlet.logGamma(beta_sum + count); + } + //System.out.println("likelihood " + val); + } + return val; + } + + public TIntObjectHashMap getPaths(int word) { + return this.wordPaths.get(word); + } + + public int[] getWordPathIndexSet(int word) { + return this.wordPaths.get(word).keys(); + } + + public int getPathNum() { + return this.pathToWord.size(); + } + + public int getWordFromPath(int pp) { + return this.pathToWord.get(pp); + } + + public double getPathPrior(int word, int path) { + return this.priorPath.get(word, path); + } + + // for TreeTopicSamplerFast + public double computeTermSmoothing(double[] alpha, int word) { + return 0; + } + + public double computeTermTopicBeta(TIntIntHashMap topic_counts, int word) { + return 0; + } + + public double computeTopicTermTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict){ + return 0; + } + + public double computeTermTopicBetaSortD(ArrayList topicCounts, int word) { + return 0; + } + + public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict){ + return 0; + } + + /*********************************************/ + public void computeSmoothingNocons(double[] alpha) {} + public void computeDocTopicBetaNocons(TIntIntHashMap topic_counts) {} + public void updateStatisticsNocons(double alpha, int topic, int topicCount, int delta){} + /*********************************************/ + + // shared methods + abstract double getNormalizer(int topic, int path); + abstract void updateParams(); + abstract void changeCount(int topic, int word, int path_index, int delta); + abstract double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict); + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModelFast.java b/src/cc/mallet/topics/tree/TreeTopicModelFast.java new file mode 100755 index 000000000..26a6fbf9f --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModelFast.java @@ -0,0 +1,388 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIterator; +import gnu.trove.TIntObjectIterator; + +import java.util.ArrayList; +import java.util.Random; + +/** + * This class extends the tree topic model + * It implemented the four abstract methods in a faster way: + * (1) normalizer is stored and updated accordingly + * (2) normalizer is split to two parts: root normalizer and normalizer to save computation + * (3) non-zero-paths are stored so when we compute the topic term score, we only compute + * the non-zero paths + * + * @author Yuening Hu + */ + +public class TreeTopicModelFast extends TreeTopicModel { + + int INTBITS = 31; + // use at most 10 bits to denote the mask + int MASKMAXBITS = 10; + + /** + * The 
normalizer is split to two parts: root normalizer and normalizer + * root normalizer is stored per topic, and normalizer is stored per path per topic + * both normalizers are updated when the count is changing. + */ + public TreeTopicModelFast(int numTopics, Random random) { + super(numTopics, random); + this.normalizer = new HIntIntDoubleHashMap (); + //this.normalizer = new HIntIntObjectHashMap (); + this.rootNormalizer = new TIntDoubleHashMap (); + } + + /** + * This function updates the real count with the path masked count. + * The format is: using the first Tree_depth number of bits of an integer + * to denote whether a node in path has count larger than zero, + * and plus the real count. + * If a node path is shorter than Tree_depth, use "1" to fill the remained part. + */ + protected void updatePathMaskedCount(int path, int topic) { + TopicTreeWalk tw = this.traversals.get(topic); + int ww = this.getWordFromPath(path); + TIntArrayList path_nodes = this.wordPaths.get(ww, path); + int leaf_node = path_nodes.get(path_nodes.size() - 1); + int original_count = tw.getNodeCount(leaf_node); + + int shift_count = this.INTBITS; + int count = this.maxDepth - 1; + if (count > this.MASKMAXBITS) count = this.MASKMAXBITS; + int val = 0; + boolean flag = false; + + // note root is not included here + // if count of a node in the path is larger than 0, denote as "1" + // else use "0" + // if path_nodes.size() > MASKMAXBITS, denote the first MASKMAXBITS-1 edges as usual + // then use the last bit to denote the sum of the remaining edges + int remain_sum = 0; + for(int nn = 1; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + if (nn < (this.MASKMAXBITS - 1)) { + count--; + shift_count--; + if (tw.getNodeCount(node) > 0) { + flag = true; + val += 1 << shift_count; + } + } else { + if (tw.getNodeCount(node) > 0) + remain_sum += 1; + } + } + + // use the last bit to denote the sum of the remaining edges + if (remain_sum > 0) { + count--; + shift_count--; + flag = true; + val += 1 << shift_count; + } + + // if a path is shorter than tree depth, fill in "1" + // should we fit in "0" ??? 
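+ // Illustrative example with hypothetical counts: for a word whose 3-edge path
+ // has non-zero counts on its first and third non-root nodes, bits 30, 29 and 28
+ // become 1, 0, 1; the loop below then pads the remaining depth bits with "1"s,
+ // and finally the leaf's raw count is added into the low-order bits.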
+ while (flag && count > 0) { + shift_count--; + val += 1 << shift_count; + count--; + } + + int maskedpath = val; + // plus the original count + val += original_count; + if (val > 0) { + this.nonZeroPaths.get(ww).put(topic, path, val); + } else if (val == 0) { + if (this.nonZeroPaths.get(ww).get(topic) != null) { + this.nonZeroPaths.get(ww).removeKey2(topic, path); + if (this.nonZeroPaths.get(ww).get(topic).size() == 0) { + this.nonZeroPaths.get(ww).removeKey1(topic); + } + } + } + +// int shift = this.INTBITS - this.maxDepth - 1; +// int testmaskedpath = val >> shift; +// maskedpath = maskedpath >> shift; +// int testcount = val - (testmaskedpath << shift); +// System.out.println(maskedpath + " " + testmaskedpath + " " + original_count + " " + testcount); + + //System.out.println(original_count + " " + this.nonZeroPaths.get(ww).get(topic, path)); + } + + /** + * Compute the root normalizer and the normalizer per topic per path + */ + protected void computeNormalizer(int topic) { + TopicTreeWalk tw = this.traversals.get(topic); + double val = this.betaSum.get(root) + tw.getNodeCount(root); + this.rootNormalizer.put(topic, val); + //System.out.println("Topic " + topic + " root normalizer " + this.rootNormalizer.get(topic)); + + for(int pp = 0; pp < this.getPathNum(); pp++) { + int ww = this.getWordFromPath(pp); + val = this.computeNormalizerPath(topic, ww, pp); + this.normalizer.put(topic, pp, val); + //System.out.println("Topic " + topic + " Path " + pp + " normalizer " + this.normalizer.get(topic, pp)); + } + } + + /** + * Compute the the normalizer given a path and a topic. + */ + private double computeNormalizerPath(int topic, int word, int path) { + TopicTreeWalk tw = this.traversals.get(topic); + TIntArrayList path_nodes = this.wordPaths.get(word, path); + + double norm = 1.0; + // do not consider the root + for (int nn = 1; nn < path_nodes.size() - 1; nn++) { + int node = path_nodes.get(nn); + norm *= this.betaSum.get(node) + tw.getNodeCount(node); + } + return norm; + } + + /** + * Compute the root normalizer and the normalizer per topic per path. + */ + protected int[] findAffectedPaths(int[] nodes) { + TIntHashSet affected = new TIntHashSet(); + for(int ii = 0; ii < nodes.length; ii++) { + int node = nodes[ii]; + TIntArrayList paths = this.nodeToPath.get(node); + for (int jj = 0; jj < paths.size(); jj++) { + int pp = paths.get(jj); + affected.add(pp); + } + } + return affected.toArray(); + } + + /** + * Updates a list of paths with the given amount. + */ + protected void updateNormalizer(int topic, TIntArrayList paths, double delta) { + for (int ii = 0; ii < paths.size(); ii++) { + int pp = paths.get(ii); + double val = this.normalizer.get(topic, pp); + val *= delta; + this.normalizer.put(topic, pp, val); + } + } + + /** + * Computes the observation part. + */ + protected double getObservation(int topic, int word, int path_index) { + TIntArrayList path_nodes = this.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.traversals.get(topic); + double val = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + val *= this.beta.get(parent, child) + tw.getCount(parent, child); + } + val -= this.priorPath.get(word, path_index); + return val; + } + + /** + * After adding instances, update the parameters. 
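+ * <p>Concretely, this recomputes the path-masked count for every (path, topic)
+ * pair and then the split normalizers (the per-topic root normalizer plus the
+ * per-topic, per-path normalizer), so the sampler can rely on cached values.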
+ */ + public void updateParams() { + for(int tt = 0; tt < this.numTopics; tt++) { + for(int pp = 0; pp < this.getPathNum(); pp++) { + this.updatePathMaskedCount(pp, tt); + } + this.computeNormalizer(tt); + } + } + + /** + * This function updates the count given the topic and path of a word. + */ + public void changeCount(int topic, int word, int path_index, int delta) { + + TopicTreeWalk tw = this.traversals.get(topic); + TIntArrayList path_nodes = this.wordPaths.get(word, path_index); + + // for affected paths, firstly remove the old values + // do not consider the root + for(int nn = 1; nn < path_nodes.size() - 1; nn++) { + int node = path_nodes.get(nn); + double tmp = this.betaSum.get(node) + tw.getNodeCount(node); + tmp = 1 / tmp; + TIntArrayList paths = this.nodeToPath.get(node); + updateNormalizer(topic, paths, tmp); + } + + // change the count for each edge per topic + // return the node index whose count is changed from 0 or to 0 + int[] affected_nodes = tw.changeCount(path_nodes, delta); + // change path count + if (delta > 0) { + this.nonZeroPaths.get(word).adjustOrPutValue(topic, path_index, delta, delta); + } else { + this.nonZeroPaths.get(word).adjustValue(topic, path_index, delta); + } + + // if necessary, change the path mask of the affected nodes + if (affected_nodes != null && affected_nodes.length > 0) { + int[] affected_paths = this.findAffectedPaths(affected_nodes); + for(int ii = 0; ii < affected_paths.length; ii++) { + this.updatePathMaskedCount(affected_paths[ii], topic); + } + } + + // for affected paths, update the normalizer + for(int nn = 1; nn < path_nodes.size() - 1; nn++) { + int node = path_nodes.get(nn); + double tmp = this.betaSum.get(node) + tw.getNodeCount(node); + TIntArrayList paths = this.nodeToPath.get(node); + updateNormalizer(topic, paths, tmp); + } + + // update the root normalizer + double val = this.betaSum.get(root) + tw.getNodeCount(root); + this.rootNormalizer.put(topic, val); + } + + /** + * This function returns the real normalizer. + */ + public double getNormalizer(int topic, int path) { + return this.normalizer.get(topic, path) * this.rootNormalizer.get(topic); + } + + /** + * This function computes the smoothing bucket for a word. + */ + public double computeTermSmoothing(double[] alpha, int word) { + double smoothing = 0.0; + int[] paths = this.getWordPathIndexSet(word); + + for(int tt = 0; tt < this.numTopics; tt++) { + for(int pp : paths) { + double val = alpha[tt] * this.getPathPrior(word, pp); + val /= this.getNormalizer(tt, pp); + smoothing += val; + } + } + //myAssert(smoothing > 0, "something wrong with smoothing!"); + return smoothing; + } + + /** + * This function computes the topic beta bucket. + */ + public double computeTermTopicBeta(TIntIntHashMap topic_counts, int word) { + double topic_beta = 0.0; + int[] paths = this.getWordPathIndexSet(word); + for(int tt : topic_counts.keys()) { + if (topic_counts.get(tt) > 0 ) { + for (int pp : paths) { + double val = topic_counts.get(tt) * this.getPathPrior(word, pp); + val /= this.getNormalizer(tt, pp); + topic_beta += val; + } + } + } + //myAssert(topic_beta > 0, "something wrong with topic_beta!"); + return topic_beta; + } + + /** + * This function computes the topic beta bucket. 
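+ * <p>Same quantity as the hash-map version above, i.e.
+ * <pre>
+ *   sum_t sum_{paths p of word}  n_d(t) * pathPrior(word, p) / normalizer(t, p)
+ * </pre>
+ * where n_d(t) is the document's count for topic t, except that the counts are
+ * supplied as a sorted list of {topic, count} pairs instead of a TIntIntHashMap.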
+ */ + public double computeTermTopicBetaSortD(ArrayList topic_counts, int word) { + double topic_beta = 0.0; + int[] paths = this.getWordPathIndexSet(word); + for(int ii = 0; ii < topic_counts.size(); ii++) { + int[] current = topic_counts.get(ii); + int tt = current[0]; + int count = current[1]; + if (count > 0 ) { + for (int pp : paths) { + double val = count * this.getPathPrior(word, pp); + val /= this.getNormalizer(tt, pp); + topic_beta += val; + } + } + } + //myAssert(topic_beta > 0, "something wrong with topic_beta!"); + return topic_beta; + } + + /** + * This function computes the topic term bucket. + */ + public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + HIntIntIntHashMap nonzeros = this.nonZeroPaths.get(word); + + // Notice only the nonzero paths are considered + //for(int tt = 0; tt < this.numTopics; tt++) { + for(int tt : nonzeros.getKey1Set()) { + double topic_alpha = alpha[tt]; + int topic_count = local_topic_counts.get(tt); + int[] paths = nonzeros.get(tt).keys(); + for (int pp = 0; pp < paths.length; pp++) { + int path = paths[pp]; + double val = this.getObservation(tt, word, path); + val *= (topic_alpha + topic_count); + val /= this.getNormalizer(tt, path); + double[] tmp = {tt, path, val}; + dict.add(tmp); + norm += val; + } + } + return norm; + } + + public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + HIntIntIntHashMap nonzeros = this.nonZeroPaths.get(word); + + int[] tmpTopics = new int[this.numTopics]; + for(int jj = 0; jj < this.numTopics; jj++) { + tmpTopics[jj] = 0; + } + for(int jj = 0; jj < local_topic_counts.size(); jj++) { + int[] current = local_topic_counts.get(jj); + int tt = current[0]; + tmpTopics[tt] = current[1]; + } + + // Notice only the nonzero paths are considered + //for(int tt = 0; tt < this.numTopics; tt++) { + for(int tt : nonzeros.getKey1Set()) { + double topic_alpha = alpha[tt]; + int topic_count = tmpTopics[tt]; + //local_topic_counts.get(ii); + int[] paths = nonzeros.get(tt).keys(); + for (int pp = 0; pp < paths.length; pp++) { + int path = paths[pp]; + double val = this.getObservation(tt, word, path); + val *= (topic_alpha + topic_count); + val /= this.getNormalizer(tt, path); + double[] tmp = {tt, path, val}; + dict.add(tmp); + norm += val; + } + } + return norm; + } + ////////////////////////////////////////////////////////// + + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModelFastEst.java b/src/cc/mallet/topics/tree/TreeTopicModelFastEst.java new file mode 100755 index 000000000..b90bfe971 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModelFastEst.java @@ -0,0 +1,44 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; + +import java.util.Random; + +/** + * This class extends the tree topic model fast class + * Only add one more function, it computes the smoothing for each word + * only based on the prior (treat the real count as zero), so it + * serves as the upper bound of smoothing. + * + * @author Yuening Hu + */ + +public class TreeTopicModelFastEst extends TreeTopicModelFast { + public TreeTopicModelFastEst(int numTopics, Random random) { + super(numTopics, random); + this.smoothingEst = new TIntDoubleHashMap(); + } + + /** + * This function computes the upper bound of smoothing bucket. 
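+ * <p>Treating every count as zero, the estimate for a word w reduces to
+ * <pre>
+ *   smoothingEst(w) = sum_t alpha[t] * sum_{paths p of w} prod_{(u,v) in p} beta(u,v) / betaSum(u)
+ * </pre>
+ * Since the real normalizer only grows as counts are added, this value never
+ * underestimates the true smoothing mass, which is what makes it safe to use
+ * as a bound when bucketing the sampling probabilities.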
+ */ + public void computeSmoothingEst(double[] alpha) { + for(int ww : this.wordPaths.getKey1Set()) { + this.smoothingEst.put(ww, 0.0); + for(int tt = 0; tt < this.numTopics; tt++) { + for(int pp : this.wordPaths.get(ww).keys()) { + TIntArrayList path_nodes = this.wordPaths.get(ww, pp); + double prob = 1.0; + for(int nn = 0; nn < path_nodes.size() - 1; nn++) { + int parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + prob *= this.beta.get(parent, child) / this.betaSum.get(parent); + } + prob *= alpha[tt]; + this.smoothingEst.adjustValue(ww, prob); + } + } + } + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModelFastEstSortW.java b/src/cc/mallet/topics/tree/TreeTopicModelFastEstSortW.java new file mode 100755 index 000000000..444e07348 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModelFastEstSortW.java @@ -0,0 +1,44 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; + +import java.util.Random; + +/** + * This class extends the tree topic model fast class + * Only add one more function, it computes the smoothing for each word + * only based on the prior (treat the real count as zero), so it + * serves as the upper bound of smoothing. + * + * @author Yuening Hu + */ + +public class TreeTopicModelFastEstSortW extends TreeTopicModelFastSortW { + public TreeTopicModelFastEstSortW(int numTopics, Random random) { + super(numTopics, random); + this.smoothingEst = new TIntDoubleHashMap(); + } + + /** + * This function computes the upper bound of smoothing bucket. + */ + public void computeSmoothingEst(double[] alpha) { + for(int ww : this.wordPaths.getKey1Set()) { + this.smoothingEst.put(ww, 0.0); + for(int tt = 0; tt < this.numTopics; tt++) { + for(int pp : this.wordPaths.get(ww).keys()) { + TIntArrayList path_nodes = this.wordPaths.get(ww, pp); + double prob = 1.0; + for(int nn = 0; nn < path_nodes.size() - 1; nn++) { + int parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + prob *= this.beta.get(parent, child) / this.betaSum.get(parent); + } + prob *= alpha[tt]; + this.smoothingEst.adjustValue(ww, prob); + } + } + } + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModelFastSortW.java b/src/cc/mallet/topics/tree/TreeTopicModelFastSortW.java new file mode 100755 index 000000000..c1658a317 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModelFastSortW.java @@ -0,0 +1,297 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +/** + * nonZeroPathsBubbleSorted: Arraylist sorted + * sorted[0] = (topic << TOPIC_BITS) + path + * sorted[1] = (masked_path) + real_count + * + * @author Yuening Hu + */ + +public class TreeTopicModelFastSortW extends TreeTopicModelFast { + + static int TOPIC_BITS = 16; + + public TreeTopicModelFastSortW(int numTopics, Random random) { + super(numTopics, random); + } + + /** + * After adding instances, update the parameters. 
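+ * <p>In addition to the base behaviour, an (initially empty) sorted list is set
+ * up here for every word, so that addOrUpdateValue() can keep that word's
+ * non-zero (topic, path) entries ordered by value; each key packs the topic and
+ * path together as (topic << TOPIC_BITS) + path.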
+ */ + public void updateParams() { + + for(int ww : this.nonZeroPaths.keys()) { + if (!this.nonZeroPathsBubbleSorted.containsKey(ww)) { + ArrayList sorted = new ArrayList (); + this.nonZeroPathsBubbleSorted.put(ww, sorted); + } + } + for(int tt = 0; tt < this.numTopics; tt++) { + for(int pp = 0; pp < this.getPathNum(); pp++) { + this.updatePathMaskedCount(pp, tt); + } + this.computeNormalizer(tt); + } + +// for(int ww : this.nonZeroPaths.keys()) { +// System.out.println("Word " + ww); +// ArrayList sorted = this.nonZeroPathsBubbleSorted.get(ww); +// for(int ii = 0; ii < sorted.size(); ii++) { +// int[] tmp = sorted.get(ii); +// System.out.println(tmp[0] + " " + tmp[1] + " " + tmp[2] + " " + tmp[3]); +// } +// } + } + + protected void updatePathMaskedCount(int path, int topic) { + TopicTreeWalk tw = this.traversals.get(topic); + int ww = this.getWordFromPath(path); + TIntArrayList path_nodes = this.wordPaths.get(ww, path); + int leaf_node = path_nodes.get(path_nodes.size() - 1); + int original_count = tw.getNodeCount(leaf_node); + + int shift_count = this.INTBITS; + int count = this.maxDepth - 1; + int val = 0; + boolean flag = false; + + // note root is not included here + // if count of a node in the path is larger than 0, denote as "1" + // else use "0" + for(int nn = 1; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + shift_count--; + count--; + if (tw.getNodeCount(node) > 0) { + flag = true; + val += 1 << shift_count; + } + } + + // if a path is shorter than tree depth, fill in "1" + // should we fit in "0" ??? + while (flag && count > 0) { + shift_count--; + val += 1 << shift_count; + count--; + } + + val += original_count; + this.addOrUpdateValue(topic, path, ww, val, false); + + } + + private void addOrUpdateValueold(int topic, int path, int word, int newvalue, boolean flag) { + ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); + int key = (topic << TOPIC_BITS) + path; + //remove the old value + int oldindex = sorted.size(); + int oldvalue = 0; + for(int ii = 0; ii < sorted.size(); ii++) { + int[] tmp = sorted.get(ii); + if(tmp[0] == key) { + oldvalue = tmp[1]; + sorted.remove(ii); + break; + } + } + if(oldindex > sorted.size()) { + oldindex--; + } + + // flag is true, increase value, else just update value + int value = 0; + if(flag) { + value = oldvalue + newvalue; + } else { + value = newvalue; + } + + //add the new value + if (value > 0) { + int index; + if (value > oldvalue) { + index = 0; + for(int ii = oldindex - 1; ii >= 0; ii--) { + //System.out.println(ii + " " + oldindex + " " + sorted.size()); + int[] tmp = sorted.get(ii); + if(value <= tmp[1]) { + index = ii; + break; + } + } + } else { + index = sorted.size(); + for(int ii = oldindex; ii < sorted.size(); ii++) { + int[] tmp = sorted.get(ii); + if(value >= tmp[1]) { + index = ii; + break; + } + } + } + + int[] newpair = {key, value}; + sorted.add(index, newpair); + } + } + + private void addOrUpdateValue(int topic, int path, int word, int newvalue, boolean flag) { + ArrayList sorted = this.nonZeroPathsBubbleSorted.get(word); + int key = (topic << TOPIC_BITS) + path; + //remove the old value + int value = 0; + for(int ii = 0; ii < sorted.size(); ii++) { + int[] tmp = sorted.get(ii); + if(tmp[0] == key) { + value = tmp[1]; + sorted.remove(ii); + break; + } + } + + // flag is true, increase value, else just update value + if(flag) { + value += newvalue; + } else { + value = newvalue; + } + + //add the new value + if (value > 0) { + int index = sorted.size(); + for(int ii = 0; ii < 
sorted.size(); ii++) { + int[] tmp = sorted.get(ii); + if(value >= tmp[1]) { + index = ii; + break; + } + } + int[] newpair = {key, value}; + sorted.add(index, newpair); + } + } + + public void changeCount(int topic, int word, int path_index, int delta) { + TopicTreeWalk tw = this.traversals.get(topic); + TIntArrayList path_nodes = this.wordPaths.get(word, path_index); + + // for affected paths, firstly remove the old values + // do not consider the root + for(int nn = 1; nn < path_nodes.size() - 1; nn++) { + int node = path_nodes.get(nn); + double tmp = this.betaSum.get(node) + tw.getNodeCount(node); + tmp = 1 / tmp; + TIntArrayList paths = this.nodeToPath.get(node); + updateNormalizer(topic, paths, tmp); + } + + // change the count for each edge per topic + // return the node index whose count is changed from 0 or to 0 + int[] affected_nodes = tw.changeCount(path_nodes, delta); + + // change path count + this.addOrUpdateValue(topic, path_index, word, delta, true); + + // if necessary, change the path mask of the affected nodes + if (affected_nodes != null && affected_nodes.length > 0) { + int[] affected_paths = this.findAffectedPaths(affected_nodes); + for(int ii = 0; ii < affected_paths.length; ii++) { + this.updatePathMaskedCount(affected_paths[ii], topic); + } + } + + // for affected paths, update the normalizer + for(int nn = 1; nn < path_nodes.size() - 1; nn++) { + int node = path_nodes.get(nn); + double tmp = this.betaSum.get(node) + tw.getNodeCount(node); + TIntArrayList paths = this.nodeToPath.get(node); + updateNormalizer(topic, paths, tmp); + } + + // update the root normalizer + double val = this.betaSum.get(root) + tw.getNodeCount(root); + this.rootNormalizer.put(topic, val); + } + + /** + * This function computes the topic term bucket. 
+ */ + public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); + + // Notice only the nonzero paths are considered + for(int ii = 0; ii < nonzeros.size(); ii++) { + int[] tmp = nonzeros.get(ii); + int key = tmp[0]; + int tt = key >> TOPIC_BITS; + int pp = key - (tt << TOPIC_BITS); + + double topic_alpha = alpha[tt]; + int topic_count = local_topic_counts.get(tt); + + double val = this.getObservation(tt, word, pp); + val *= (topic_alpha + topic_count); + val /= this.getNormalizer(tt, pp); + + //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); + + double[] result = {tt, pp, val}; + dict.add(result); + + norm += val; + } + + return norm; + } + + public double computeTopicTermSortD(double[] alpha, ArrayList local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + ArrayList nonzeros = this.nonZeroPathsBubbleSorted.get(word); + + + int[] tmpTopics = new int[this.numTopics]; + for(int jj = 0; jj < this.numTopics; jj++) { + tmpTopics[jj] = 0; + } + for(int jj = 0; jj < local_topic_counts.size(); jj++) { + int[] current = local_topic_counts.get(jj); + int tt = current[0]; + tmpTopics[tt] = current[1]; + } + + // Notice only the nonzero paths are considered + for(int ii = 0; ii < nonzeros.size(); ii++) { + int[] tmp = nonzeros.get(ii); + int key = tmp[0]; + int tt = key >> TOPIC_BITS; + int pp = key - (tt << TOPIC_BITS); + + double topic_alpha = alpha[tt]; + int topic_count = tmpTopics[tt]; + + double val = this.getObservation(tt, word, pp); + val *= (topic_alpha + topic_count); + val /= this.getNormalizer(tt, pp); + + //System.out.println(tt + " " + pp + " " + tmp[2] + " " + val); + + double[] result = {tt, pp, val}; + dict.add(result); + + norm += val; + } + return norm; + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicModelNaive.java b/src/cc/mallet/topics/tree/TreeTopicModelNaive.java new file mode 100755 index 000000000..c2df2380d --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicModelNaive.java @@ -0,0 +1,76 @@ +package cc.mallet.topics.tree; + +import java.util.ArrayList; +import java.util.Random; + +import cc.mallet.types.Dirichlet; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntDoubleIterator; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntObjectHashMap; +import gnu.trove.TIntObjectIterator; + +/** + * This class extends the tree topic model + * It implemented the four abstract methods in a naive way: given a word, + * (1) compute the probability for each topic every time directly + * + * @author Yuening Hu + */ + +public class TreeTopicModelNaive extends TreeTopicModel{ + + public TreeTopicModelNaive(int numTopics, Random random) { + super(numTopics, random); + } + + /** + * Just calls changeCountOnly(), nothing else. + */ + public void changeCount(int topic, int word, int path, int delta) { +// TIntArrayList path_nodes = this.wordPaths.get(word, path_index); +// TopicTreeWalk tw = this.traversals.get(topic); +// tw.changeCount(path_nodes, delta); + this.changeCountOnly(topic, word, path, delta); + } + + /** + * Given a word and the topic counts in the current document, + * this function computes the probability per path per topic directly + * according to the sampleing equation. 
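+ *
+ * A sketch of the unnormalized score for assigning topic tt via path pp to the
+ * current word, assuming the path probability multiplies
+ * (beta + edge count) / (betaSum + node count) along the path edges, as the
+ * test helpers in TreeTopicSamplerFast do:
+ * <pre>
+ * {@code
+ * TIntArrayList nodes = this.wordPaths.get(word, pp);
+ * TopicTreeWalk tw = this.traversals.get(tt);
+ * double score = topic_alpha + topic_count;    // document part
+ * for (int nn = 0; nn < nodes.size() - 1; nn++) {
+ *     int parent = nodes.get(nn);
+ *     int child  = nodes.get(nn + 1);
+ *     score *= this.beta.get(parent, child) + tw.getCount(parent, child);
+ *     score /= this.betaSum.get(parent) + tw.getNodeCount(parent);
+ * }
+ * }
+ * </pre>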
+ */ + public double computeTopicTerm(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + int[] paths = this.getWordPathIndexSet(word); + for(int tt = 0; tt < this.numTopics; tt++) { + double topic_alpha = alpha[tt]; + int topic_count = local_topic_counts.get(tt); + for (int pp = 0; pp < paths.length; pp++) { + int path_index = paths[pp]; + double val = this.computeTopicPathProb(tt, word, path_index); + val *= (topic_alpha + topic_count); + double[] tmp = {tt, path_index, val}; + dict.add(tmp); + norm += val; + } + } + return norm; + } + + /** + * No parameter needs to be updated. + */ + public void updateParams() { + } + + /** + * Not actually used. + */ + public double getNormalizer(int topic, int path) { + return 0; + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSampler.java b/src/cc/mallet/topics/tree/TreeTopicSampler.java new file mode 100755 index 000000000..638502d8d --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSampler.java @@ -0,0 +1,300 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntHashSet; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; + +import cc.mallet.topics.TopicInferencer; +import cc.mallet.topics.tree.TreeTopicSamplerSortD.DocData; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + + +/** + * This class defines the tree topic sampler. + * Defines the basic functions for input, output, resume. + * Also defines the abstract functions for child class. + * + * @author Yuening Hu + */ + +public abstract class TreeTopicSampler { + + int numTopics; + int numIterations; + int startIter; + Randoms random; + double[] alpha; + double alphaSum; + TDoubleArrayList lhood; + TDoubleArrayList iterTime; + ArrayList vocab; + ArrayList removedWords; + ArrayList removedWordsNew; + TIntHashSet cons; + HashMap topickeep; + + public TreeTopicSampler (int numberOfTopics, double alphaSum, int seed) { + this.numTopics = numberOfTopics; + this.random = new Randoms(seed); + + this.alphaSum = alphaSum; + this.alpha = new double[numTopics]; + Arrays.fill(alpha, alphaSum / numTopics); + + this.vocab = new ArrayList (); + this.removedWords = new ArrayList (); + this.removedWordsNew = new ArrayList (); + this.cons = new TIntHashSet(); + this.topickeep = new HashMap(); + + this.lhood = new TDoubleArrayList(); + this.iterTime = new TDoubleArrayList(); + this.startIter = 0; + } + + ///////////////////////////////////////////////////////////// + + public void setNumIterations(int iters) { + this.numIterations = iters; + } + + /** + * Resumes from the saved files. + */ + public void resume(InstanceList[] training, String resumeDir) { + try { + String statesFile = resumeDir + ".states"; + resumeStates(training, statesFile); + + String lhoodFile = resumeDir + ".lhood"; + resumeLHood(lhoodFile); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + } + + ////////////////////////////////////////////////////////// + + public int getNumIterations() { + return this.numIterations; + } + + /** + * Resume lhood and iterTime from the saved lhood file. 
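+ *
+ * The file is expected to match what printStats() writes: a header line followed
+ * by one line per iteration with three tab-separated fields (the numbers below
+ * are made up):
+ * <pre>
+ * Iteration    likelihood    iter_time
+ * 0    -1234567.89    12.3
+ * 1    -1230011.42    11.9
+ * </pre>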
+ */ + public void resumeLHood(String lhoodFile) throws IOException{ + FileInputStream lhoodfstream = new FileInputStream(lhoodFile); + DataInputStream lhooddstream = new DataInputStream(lhoodfstream); + BufferedReader brLHood = new BufferedReader(new InputStreamReader(lhooddstream)); + // the first line is the title + String strLine = brLHood.readLine(); + while ((strLine = brLHood.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + // iteration, likelihood, iter_time + myAssert(str.length == 3, "lhood file problem!"); + this.lhood.add(Double.parseDouble(str[1])); + this.iterTime.add(Double.parseDouble(str[2])); + } + this.startIter = this.lhood.size(); + +// if (this.startIter > this.numIterations) { +// System.out.println("Have already sampled " + this.numIterations + " iterations!"); +// System.exit(0); +// } +// System.out.println("Start sampling for iteration " + this.startIter); + + brLHood.close(); + } + + /** + * This function prints the topic words of each topic. + */ + public void printTopWords(File file, int numWords) throws IOException { + PrintStream out = new PrintStream (file); + out.print(displayTopWords(numWords)); + out.close(); + } + + /** + * Prints likelihood and iter time. + */ + public void printStats (File file) throws IOException { + PrintStream out = new PrintStream (file); + String tmp = "Iteration\t\tlikelihood\titer_time\n"; + out.print(tmp); + + for (int iter = 0; iter < this.lhood.size(); iter++) { + tmp = iter + "\t" + this.lhood.get(iter) + "\t" + this.iterTime.get(iter); + out.println(tmp); + } + out.close(); + } + + /** + * This function reports the detected topics, the documents topics, + * and saves states file and lhood file. + */ + public void report(String outputDir, int topWords) throws IOException { + + String topicKeysFile = outputDir + ".topics"; + this.printTopWords(new File(topicKeysFile), topWords); + + String docTopicsFile = outputDir + ".docs"; + this.printDocumentTopics(new File(docTopicsFile)); + + String stateFile = outputDir + ".states"; + this.printState (new File(stateFile)); + + String statsFile = outputDir + ".lhood"; + this.printStats (new File(statsFile)); + + String topicWordsFile = outputDir + ".topic-words"; + this.printTopicWords(new File(topicWordsFile)); + } + + public void loadVocab(String vocabFile) { + + try { + FileInputStream infstream = new FileInputStream(vocabFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + this.vocab.add(str[1]); + } else { + System.out.println("Vocab file error at line: " + strLine); + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No vocab file Found!"); + } + + } + + /** + * Load StopWords + */ + public void loadRemovedWords(String removedwordFile, ArrayList removed) { + + try { + + FileInputStream infstream = new FileInputStream(removedwordFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + removed.add(strLine); + } + in.close(); + + } catch (IOException e) { + System.out.println("No stop word file Found!"); + } + } + + /** + * Load constraints + */ + public void 
loadConstraints(String consFile) { + try { + FileInputStream infstream = new FileInputStream(consFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] str = strLine.split("\t"); + if (str.length > 1) { + // str[0] is either "MERGE_" or "SPLIT_", not a word + for(int ii = 1; ii < str.length; ii++) { + int word = this.vocab.indexOf(str[ii]); + myAssert(word >= 0, "Constraint words not found in vocab: " + str[ii]); + cons.add(word); + } + this.vocab.add(str[1]); + } else { + System.out.println("Error! " + strLine); + } + } + in.close(); + + } catch (IOException e) { + System.out.println("No constraint file Found!"); + } + + } + + /** + * For words on this list, topic assignments will not be cleared. + */ + public void loadKeepList(String keepFile) { + try { + FileInputStream infstream = new FileInputStream(keepFile); + DataInputStream in = new DataInputStream(infstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + + String strLine; + //Read File Line By Line + while ((strLine = br.readLine()) != null) { + strLine = strLine.trim(); + String[] words = strLine.split(" "); + int word = this.vocab.indexOf(words[0]); + int topic = Integer.parseInt(words[1]); + if (!this.topickeep.containsKey(word)) { + this.topickeep.put(word, new TIntHashSet()); + } + TIntHashSet tmp = this.topickeep.get(word); + tmp.add(topic); + } + in.close(); + + } catch (IOException e) { + System.out.println("No keep file Found!"); + } + + } + + /** + * For testing~~ + */ + public static void myAssert(boolean flag, String info) { + if(!flag) { + System.out.println(info); + System.exit(0); + } + } + + abstract public String displayTopWords (int numWords); + abstract public void printState (File file) throws IOException; + abstract public void printTopicWords (File file) throws IOException; + abstract public void sampleDoc(int doc); + abstract public double docLHood(); + abstract public void printDocumentTopics (File file) throws IOException; + abstract public void resumeStates(InstanceList[] training, String statesFile) throws IOException; + abstract public TreeTopicInferencer getInferencer(); + abstract public TreeMarginalProbEstimator getProbEstimator(); +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerFast.java b/src/cc/mallet/topics/tree/TreeTopicSamplerFast.java new file mode 100755 index 000000000..6fb209da8 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerFast.java @@ -0,0 +1,286 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.util.Randoms; + +/** + * This class defines a fast tree topic sampler, which calls the fast tree topic model. + * (1) It divides the sampling into three bins: smoothing, topic beta, topic term. + * as Yao and Mimno's paper, KDD, 2009. + * (2) Each time the smoothing, topic beta, and topic term are recomputed. + * It is faster, because, + * (1) For topic term, only compute the one with non-zero paths (see TreeTopicModelFast). + * (2) The normalizer is saved. + * (3) Topic counts for each documents are ranked. 
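+ *
+ * In outline, sampleDoc() below draws one uniform number against the three
+ * bucket masses (a sketch using the same variable names):
+ * <pre>
+ * {@code
+ * double norm   = smoothing_mass + topic_beta_mass + topic_term_mass;
+ * double sample = this.random.nextDouble() * norm;
+ * // sample < smoothing_mass                      -> smoothing bin
+ * // sample < smoothing_mass + topic_beta_mass    -> topic beta bin
+ * // otherwise                                    -> topic term bin
+ * }
+ * </pre>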
+ * + * @author Yuening Hu + */ + +public class TreeTopicSamplerFast extends TreeTopicSamplerHashD { + + public TreeTopicSamplerFast (int numberOfTopics, double alphaSum, int seed, boolean sort) { + super(numberOfTopics, alphaSum, seed); + + if (sort) { + this.topics = new TreeTopicModelFastSortW(this.numTopics, this.random); + //} else if (bubble == 1) { + // this.topics = new TreeTopicModelFastSortT1(this.numTopics, this.random); + //} else if (bubble == 2) { + // this.topics = new TreeTopicModelFastSortT2(this.numTopics, this.random); + } else { + this.topics = new TreeTopicModelFast(this.numTopics, this.random); + } + } + + /** + * For each word in a document, firstly covers its topic and path, then sample a + * topic and path, and update. + */ + public void sampleDoc(int doc_id){ + DocData doc = this.data.get(doc_id); + //System.out.println("doc " + doc_id); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + //int word = doc.tokens.getIndexAtPosition(ii); + int word = doc.tokens.get(ii); + + this.changeTopic(doc_id, ii, word, -1, -1); + + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double topic_beta_mass = this.topics.computeTermTopicBeta(doc.topicCounts, word); + + ArrayList topic_term_score = new ArrayList (); + double topic_term_mass = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score); + + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(word); + + // sample the smoothing bin + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); + } else { + sample -= smoothing_mass; + } + + // sample the topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + for(int tt : doc.topicCounts.keys()) { + for (int pp : paths) { + double val = doc.topicCounts.get(tt) * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); + } else { + sample -= topic_beta_mass; + } + + // sample the topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); + } + + this.changeTopic(doc_id, ii, word, new_topic, new_path); + } + } + + ///////////////////////////// + // The following methods are for testing only. 
+ + public double callComputeTermTopicBeta(TIntIntHashMap topic_counts, int word) { + return this.topics.computeTermTopicBeta(topic_counts, word); + } + + public double callComputeTermSmoothing(int word) { + return this.topics.computeTermSmoothing(this.alpha, word); + } + + public double computeTopicSmoothTest(int word) { + double smooth = 0.0; + int[] paths = this.topics.getWordPathIndexSet(word); + for(int tt = 0; tt < this.numTopics; tt++) { + double topic_alpha = alpha[tt]; + for (int pp = 0; pp < paths.length; pp++) { + int path_index = paths[pp]; + + TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.topics.traversals.get(tt); + + double tmp = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + tmp *= this.topics.beta.get(parent, child); + tmp /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + tmp *= topic_alpha; + smooth += tmp; + } + } + return smooth; + } + + public double computeTopicTermBetaTest(TIntIntHashMap local_topic_counts, int word) { + double topictermbeta = 0.0; + int[] paths = this.topics.getWordPathIndexSet(word); + for(int tt = 0; tt < this.numTopics; tt++) { + int topic_count = local_topic_counts.get(tt); + for (int pp = 0; pp < paths.length; pp++) { + int path_index = paths[pp]; + + TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.topics.traversals.get(tt); + + double tmp = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + tmp *= this.topics.beta.get(parent, child); + tmp /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + tmp *= topic_count; + + topictermbeta += tmp; + } + } + return topictermbeta; + } + + public double computeTopicTermScoreTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, HIntIntDoubleHashMap dict) { + double termscore = 0.0; + int[] paths = this.topics.getWordPathIndexSet(word); + for(int tt = 0; tt < this.numTopics; tt++) { + double topic_alpha = alpha[tt]; + int topic_count = local_topic_counts.get(tt); + for (int pp = 0; pp < paths.length; pp++) { + int path_index = paths[pp]; + + TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.topics.traversals.get(tt); + + double val = 1.0; + double tmp = 1.0; + double normalizer = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + val *= this.topics.beta.get(parent, child) + tw.getCount(parent, child); + tmp *= this.topics.beta.get(parent, child); + normalizer *= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + val -= tmp; + val *= (topic_alpha + topic_count); + val /= normalizer; + + dict.put(tt, path_index, val); + termscore += val; + } + } + return termscore; + } + + public double computeTopicTermTest(double[] alpha, TIntIntHashMap local_topic_counts, int word, ArrayList dict) { + double norm = 0.0; + int[] paths = this.topics.getWordPathIndexSet(word); + for(int tt = 0; tt < this.numTopics; tt++) { + double topic_alpha = alpha[tt]; + int topic_count = local_topic_counts.get(tt); + for (int pp = 0; pp < paths.length; pp++) { + int path_index = paths[pp]; + + TIntArrayList path_nodes = this.topics.wordPaths.get(word, path_index); + TopicTreeWalk tw = this.topics.traversals.get(tt); + + double smooth = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int 
parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + smooth *= this.topics.beta.get(parent, child); + smooth /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + smooth *= topic_alpha; + + double topicterm = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + topicterm *= this.topics.beta.get(parent, child); + topicterm /= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + topicterm *= topic_count; + + double termscore = 1.0; + double tmp = 1.0; + double normalizer = 1.0; + for(int ii = 0; ii < path_nodes.size()-1; ii++) { + int parent = path_nodes.get(ii); + int child = path_nodes.get(ii+1); + termscore *= this.topics.beta.get(parent, child) + tw.getCount(parent, child); + tmp *= this.topics.beta.get(parent, child); + normalizer *= this.topics.betaSum.get(parent) + tw.getNodeCount(parent); + } + termscore -= tmp; + termscore *= (topic_alpha + topic_count); + termscore /= normalizer; + + double val = smooth + topicterm + termscore; + double[] tmptmp = {tt, path_index, val}; + dict.add(tmptmp); + norm += val; + System.out.println("Fast Topic " + tt + " " + smooth + " " + topicterm + " " + termscore + " " + tmp + " " + topic_alpha + " " + topic_count + " " + termscore); + } + } + return norm; + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerFastEst.java b/src/cc/mallet/topics/tree/TreeTopicSamplerFastEst.java new file mode 100755 index 000000000..d0ee50c51 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerFastEst.java @@ -0,0 +1,157 @@ +package cc.mallet.topics.tree; + +import java.util.ArrayList; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; + +/** + * This class improves the fast sampler based on estimation of smoothing. + * Most of the time, the smoothing is very small and not worth to recompute since + * it will hardly be hit. So we use an upper bound for smoothing. + * Only if the smoothing bin is hit, the actual smoothing is computed and resampled. + * + * @author Yuening Hu + */ + +public class TreeTopicSamplerFastEst extends TreeTopicSamplerHashD{ + + public TreeTopicSamplerFastEst (int numberOfTopics, double alphaSum, int seed, boolean sort) { + super(numberOfTopics, alphaSum, seed); + + if (sort) { + this.topics = new TreeTopicModelFastEstSortW(this.numTopics, this.random); + } else { + this.topics = new TreeTopicModelFastEst(this.numTopics, this.random); + } + } + + /** + * Use an upper bound for smoothing. Only if the smoothing + * bin is hit, the actual smoothing is computed and resampled. 
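+ *
+ * A sketch of the rescaling step (variable names follow the method body below):
+ * <pre>
+ * {@code
+ * double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass;
+ * double sample   = this.random.nextDouble() * norm_est;
+ * if (sample < smoothing_mass_est) {
+ *     // only now pay for the exact smoothing mass, then map the sample from
+ *     // the estimated total mass onto the exact total mass
+ *     double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word);
+ *     double norm = smoothing_mass + topic_beta_mass + topic_term_mass;
+ *     sample = sample / norm_est * norm;
+ * }
+ * }
+ * </pre>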
+ */ + public void sampleDoc(int doc_id) { + DocData doc = this.data.get(doc_id); + //System.out.println("doc " + doc_id); + //int[] tmpstats = this.stats.get(this.stats.size()-1); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + //int word = doc.tokens.getIndexAtPosition(ii); + int word = doc.tokens.get(ii); + + this.changeTopic(doc_id, ii, word, -1, -1); + + //double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double smoothing_mass_est = this.topics.smoothingEst.get(word); + + double topic_beta_mass = this.topics.computeTermTopicBeta(doc.topicCounts, word); + + ArrayList topic_term_score = new ArrayList(); + double topic_term_mass = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score); + + double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm_est; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(word); + + // sample the smoothing bin + if (sample < smoothing_mass_est) { + //tmpstats[0] += 1; + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + sample /= norm_est; + sample *= norm; + if (sample < smoothing_mass) { + //tmpstats[1] += 1; + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); + } else { + sample -= smoothing_mass; + } + } else { + sample -= smoothing_mass_est; + } + + // sample topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + //tmpstats[2] += 1; + for(int tt : doc.topicCounts.keys()) { + for (int pp : paths) { + double val = doc.topicCounts.get(tt) * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); + } else { + sample -= topic_beta_mass; + } + + + // sample topic term bin + if (new_topic < 0) { + //tmpstats[3] += 1; + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); + } + + this.changeTopic(doc_id, ii, word, new_topic, new_path); + } + + } + + /** + * Before sampling start, compute smoothing upper bound for each word. 
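+ *
+ * The estimate from computeSmoothingEst() drops the topic counts from the
+ * denominators: per (topic, path) it uses, roughly,
+ * <pre>
+ *   alpha[tt] * product over path edges of beta(parent, child) / betaSum(parent)
+ * </pre>
+ * whereas the exact smoothing divides by (betaSum(parent) + node count). Since
+ * node counts are non-negative, the estimate can never fall below the exact
+ * value, which is what makes the resampling in sampleDoc() above safe.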
+ */ + public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) { + if(this.topics instanceof TreeTopicModelFastEst) { + TreeTopicModelFastEst tmp = (TreeTopicModelFastEst) this.topics; + tmp.computeSmoothingEst(this.alpha); + } else if (this.topics instanceof TreeTopicModelFastEstSortW) { + TreeTopicModelFastEstSortW tmp = (TreeTopicModelFastEstSortW) this.topics; + tmp.computeSmoothingEst(this.alpha); + } + + super.estimate(numIterations, outputFolder, outputInterval, topWords); + } + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerFastEstSortD.java b/src/cc/mallet/topics/tree/TreeTopicSamplerFastEstSortD.java new file mode 100755 index 000000000..3c41ed0cd --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerFastEstSortD.java @@ -0,0 +1,150 @@ +package cc.mallet.topics.tree; + +import java.util.ArrayList; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; + +/** + * This class improves the fast sampler based on estimation of smoothing. + * Most of the time, the smoothing is very small and not worth to recompute since + * it will hardly be hit. So we use an upper bound for smoothing. + * Only if the smoothing bin is hit, the actual smoothing is computed and resampled. + * + * @author Yuening Hu + */ + +public class TreeTopicSamplerFastEstSortD extends TreeTopicSamplerSortD{ + + public TreeTopicSamplerFastEstSortD (int numberOfTopics, double alphaSum, int seed, boolean sort) { + super(numberOfTopics, alphaSum, seed); + + if (sort) { + this.topics = new TreeTopicModelFastEstSortW(this.numTopics, this.random); + } else { + this.topics = new TreeTopicModelFastEst(this.numTopics, this.random); + } + } + + /** + * Use an upper bound for smoothing. Only if the smoothing + * bin is hit, the actual smoothing is computed and resampled. 
+ */ + public void sampleDoc(int doc_id) { + DocData doc = this.data.get(doc_id); + //System.out.println("doc " + doc_id); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + //int word = doc.tokens.getIndexAtPosition(ii); + int word = doc.tokens.get(ii); + + this.changeTopic(doc_id, ii, word, -1, -1); + + //double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double smoothing_mass_est = this.topics.smoothingEst.get(word); + double topic_beta_mass = this.topics.computeTermTopicBetaSortD(doc.topicCounts, word); + + ArrayList topic_term_score = new ArrayList(); + double topic_term_mass = this.topics.computeTopicTermSortD(this.alpha, doc.topicCounts, word, topic_term_score); + + double norm_est = smoothing_mass_est + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm_est; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(word); + + // sample the smoothing bin + if (sample < smoothing_mass_est) { + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + sample /= norm_est; + sample *= norm; + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); + } else { + sample -= smoothing_mass; + } + } else { + sample -= smoothing_mass_est; + } + + // sample topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + for(int jj = 0; jj < doc.topicCounts.size(); jj++) { + int[] current = doc.topicCounts.get(jj); + int tt = current[0]; + int count = current[1]; + for(int pp : paths) { + double val = count * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); + } else { + sample -= topic_beta_mass; + } + + // sample topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); + } + + this.changeTopic(doc_id, ii, word, new_topic, new_path); + } + + } + + /** + * Before sampling start, compute smoothing upper bound for each word. 
+ */ + public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) { + if(this.topics instanceof TreeTopicModelFastEst) { + TreeTopicModelFastEst tmp = (TreeTopicModelFastEst) this.topics; + tmp.computeSmoothingEst(this.alpha); + } else if (this.topics instanceof TreeTopicModelFastEstSortW) { + TreeTopicModelFastEstSortW tmp = (TreeTopicModelFastEstSortW) this.topics; + tmp.computeSmoothingEst(this.alpha); + } + + super.estimate(numIterations, outputFolder, outputInterval, topWords); + } +} \ No newline at end of file diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerFastSortD.java b/src/cc/mallet/topics/tree/TreeTopicSamplerFastSortD.java new file mode 100755 index 000000000..243e8d5fd --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerFastSortD.java @@ -0,0 +1,138 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.util.Randoms; + +/** + * This class defines a fast tree topic sampler, which calls the fast tree topic model. + * (1) It divides the sampling into three bins: smoothing, topic beta, topic term. + * as Yao and Mimno's paper, KDD, 2009. + * (2) Each time the smoothing, topic beta, and topic term are recomputed. + * It is faster, because, + * (1) For topic term, only compute the one with non-zero paths (see TreeTopicModelFast). + * (2) The normalizer is saved. + * (3) Topic counts for each documents are ranked. + * + * @author Yuening Hu + */ +public class TreeTopicSamplerFastSortD extends TreeTopicSamplerSortD { + + public TreeTopicSamplerFastSortD (int numberOfTopics, double alphaSum, int seed, boolean sort) { + super(numberOfTopics, alphaSum, seed); + this.topics = new TreeTopicModelFast(this.numTopics, this.random); + + if (sort) { + this.topics = new TreeTopicModelFastSortW(this.numTopics, this.random); + } else { + this.topics = new TreeTopicModelFast(this.numTopics, this.random); + } + } + + /** + * For each word in a document, firstly covers its topic and path, then sample a + * topic and path, and update. 
+ */ + public void sampleDoc(int doc_id){ + DocData doc = this.data.get(doc_id); + //System.out.println("doc " + doc_id); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + //int word = doc.tokens.getIndexAtPosition(ii); + int word = doc.tokens.get(ii); + + this.changeTopic(doc_id, ii, word, -1, -1); + + double smoothing_mass = this.topics.computeTermSmoothing(this.alpha, word); + double topic_beta_mass = this.topics.computeTermTopicBetaSortD(doc.topicCounts, word); + + ArrayList topic_term_score = new ArrayList (); + double topic_term_mass = this.topics.computeTopicTermSortD(this.alpha, doc.topicCounts, word, topic_term_score); + + double norm = smoothing_mass + topic_beta_mass + topic_term_mass; + double sample = this.random.nextDouble(); + //double sample = 0.5; + sample *= norm; + + int new_topic = -1; + int new_path = -1; + + int[] paths = this.topics.getWordPathIndexSet(word); + + // sample the smoothing bin + if (sample < smoothing_mass) { + for (int tt = 0; tt < this.numTopics; tt++) { + for (int pp : paths) { + double val = alpha[tt] * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling smoothing!"); + } else { + sample -= smoothing_mass; + } + + // sample the topic beta bin + if (new_topic < 0 && sample < topic_beta_mass) { + + for(int jj = 0; jj < doc.topicCounts.size(); jj++) { + int[] current = doc.topicCounts.get(jj); + int tt = current[0]; + int count = current[1]; + for(int pp : paths) { + double val = count * this.topics.getPathPrior(word, pp); + val /= this.topics.getNormalizer(tt, pp); + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + if (new_topic >= 0) { + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic beta!"); + } else { + sample -= topic_beta_mass; + } + + // sample the topic term bin + if (new_topic < 0) { + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling topic term!"); + } + + this.changeTopic(doc_id, ii, word, new_topic, new_path); + } + } + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerHashD.java b/src/cc/mallet/topics/tree/TreeTopicSamplerHashD.java new file mode 100755 index 000000000..d5afce768 --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerHashD.java @@ -0,0 +1,648 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIntIterator; +import gnu.trove.TIntObjectHashMap; + +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.zip.GZIPOutputStream; + +import cc.mallet.types.Dirichlet; 
+import cc.mallet.types.FeatureSequence; +import cc.mallet.types.IDSorter; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + +/** + * This class defines the tree topic sampler, which loads the instances, + * reports the topics, and leaves the sampler method as an abstract method, + * which might be various for different methods. + * + * @author Yuening Hu + */ + +public abstract class TreeTopicSamplerHashD extends TreeTopicSampler implements TreeTopicSamplerInterface{ + + /** + * This class defines the format of a document. + */ + public class DocData { + TIntArrayList tokens; + TIntArrayList topics; + TIntArrayList paths; + // sort + TIntIntHashMap topicCounts; + String docName; + + public DocData (String name, TIntArrayList tokens, TIntArrayList topics, + TIntArrayList paths, TIntIntHashMap topicCounts) { + this.docName = name; + this.tokens = tokens; + this.topics = topics; + this.paths = paths; + this.topicCounts = topicCounts; + } + + public String toString() { + String result = "***************\n"; + result += docName + "\n"; + + result += "tokens: "; + for (int jj = 0; jj < tokens.size(); jj++) { + int index = tokens.get(jj); + String word = vocab.get(index); + result += word + " " + index + ", "; + } + + result += "\ntopics: "; + result += topics.toString(); + + result += "\npaths: "; + result += paths.toString(); + + result += "\ntopicCounts: "; + + for(TIntIntIterator it = this.topicCounts.iterator(); it.hasNext(); ) { + it.advance(); + result += "Topic " + it.key() + ": " + it.value() + ", "; + } + result += "\n*****************\n"; + return result; + } + } + + public class WordProb implements Comparable { + int wi; + double p; + public WordProb (int wi, double p) { this.wi = wi; this.p = p; } + public final int compareTo (Object o2) { + if (p > ((WordProb)o2).p) + return -1; + else if (p == ((WordProb)o2).p) + return 0; + else return 1; + } + } + + ArrayList data; + TreeTopicModel topics; + + public TreeTopicSamplerHashD (int numberOfTopics, double alphaSum, int seed) { + super(numberOfTopics, alphaSum, seed); + this.data = new ArrayList (); + + // notice: this.topics is not initialized in this abstract class, + // in each sub class, the topics variable is initialized differently. + } + + /** + * This function adds instances given the training data in mallet input data format. + * For each token in a document, sample a topic and then sample a path based on prior. 
+ */ + public void addInstances(InstanceList[] training) { + boolean debug = false; + int count = 0; + for(int ll = 0; ll < training.length; ll++) { + int totalcount = 0; + for (Instance instance : training[ll]) { + count++; + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + //String name = "null-source"; + //if (instance.getSource() != null) { + // name = instance.getSource().toString(); + //} + + // *** remained problem: keep topicCounts sorted + TIntArrayList tokens = new TIntArrayList(original_tokens.getLength()); + TIntIntHashMap topicCounts = new TIntIntHashMap (); + TIntArrayList topics = new TIntArrayList(original_tokens.getLength()); + TIntArrayList paths = new TIntArrayList(original_tokens.getLength()); + + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int token = this.vocab.indexOf(word); + int removed = this.removedWordsNew.indexOf(word); + int removednew = this.removedWordsNew.indexOf(word); + if(token != -1 && removed == -1 && removednew == -1) { + int topic = random.nextInt(numTopics); + if(debug) { topic = count % numTopics; } + tokens.add(token); + topics.add(topic); + topicCounts.adjustOrPutValue(topic, 1, 1); + // sample a path for this topic + int path_index = this.topics.initialize(token, topic); + paths.add(path_index); + } + } + DocData doc = new DocData(name, tokens, topics, paths, topicCounts); + this.data.add(doc); + + totalcount += tokens.size(); + //if (totalcount > 200000) { + // System.out.println("total number of tokens: " + totalcount + " docs: " + count); + // break; + //} + } + System.out.println("total number of tokens: " + totalcount); + //System.out.println(doc); + } + } + + /** + * Resume instance states from the saved states file. 
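+ *
+ * One line per document is expected, holding one "topic:wordpath" pair per kept
+ * token, separated by tabs, matching what printState() writes (the values below
+ * are made up):
+ * <pre>
+ * 3:0    12:1    3:0
+ * 7:0    7:0
+ * </pre>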
+ */ + public void resumeStates(InstanceList[] training, String statesFile) throws IOException{ + FileInputStream statesfstream = new FileInputStream(statesFile); + DataInputStream statesdstream = new DataInputStream(statesfstream); + BufferedReader states = new BufferedReader(new InputStreamReader(statesdstream)); + + // reading topics, paths + for(int ll = 0; ll < training.length; ll++) { + for (Instance instance : training[ll]) { + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + + // *** remained problem: keep topicCounts sorted + TIntArrayList tokens = new TIntArrayList(original_tokens.getLength()); + TIntIntHashMap topicCounts = new TIntIntHashMap (); + TIntArrayList topics = new TIntArrayList(original_tokens.getLength()); + TIntArrayList paths = new TIntArrayList(original_tokens.getLength()); + + // + String statesLine = states.readLine(); + myAssert(statesLine != null, "statesFile doesn't match with the training data"); + statesLine = statesLine.trim(); + String[] str = statesLine.split("\t"); + + int count = -1; + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int token = this.vocab.indexOf(word); + int removed = this.removedWords.indexOf(word); + int removednew = this.removedWordsNew.indexOf(word); + if(token != -1 && removed == -1) { + count++; + if (removednew == -1) { + String[] tp = str[count].split(":"); + myAssert(tp.length == 2, "statesFile problem!"); + int topic = Integer.parseInt(tp[0]); + //int path = Integer.parseInt(tp[1]); + + int wordpath = Integer.parseInt(tp[1]); + int path = -1; + int backoffpath = -1; + // find the path for this wordpath + TIntObjectHashMap allpaths = this.topics.wordPaths.get(token); + for(int pp : allpaths.keys()) { + if(backoffpath == -1 && this.topics.pathToWordPath.get(pp) == 0){ + backoffpath = pp; + } + if(this.topics.pathToWordPath.get(pp) == wordpath){ + path = pp; + break; + } + } + + if(path == -1) { + // this path must be in a correlation, it will be cleared later + path = backoffpath; + myAssert(path != -1, "path problem"); + + //String tmp = ""; + //tmp += "file " + name + "\n"; + //tmp += "word " + word + "\n"; + //tmp += "token " + token + "\n"; + //tmp += "index " + count + "\n"; + //tmp += "topic " + topic + "\n"; + //tmp += "wordpath " + wordpath + "\n"; + //tmp += "allpaths"; + //for(int pp : allpaths.keys()) { + // tmp += " " + pp; + //} + //System.out.println(tmp); + } + + tokens.add(token); + topics.add(topic); + paths.add(path); + topicCounts.adjustOrPutValue(topic, 1, 1); + this.topics.changeCountOnly(topic, token, path, 1); + } + } + } + if(count != -1) { + count++; + myAssert(str.length == count, "resume problem!"); + } + + DocData doc = new DocData(name, tokens, topics, paths, topicCounts); + this.data.add(doc); + } + } + states.close(); + } + + /** + * This function clears the topic and path assignments for some words: + * (1) term option: only clears the topic and path for constraint words; + * (2) doc option: clears the topic and path for documents which contain + * at least one of the constraint words. 
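+ *
+ * The constraint file is read by loadConstraints(): one tab-separated line per
+ * constraint, whose first field is a tag such as "MERGE_" or "SPLIT_" followed
+ * by the constrained words; the keep file is read by loadKeepList(): one
+ * space-separated "word topic" pair per line. For example (the words and topic
+ * ids are made up):
+ * <pre>
+ * MERGE_    drug    medicine
+ * SPLIT_    bank    river
+ * </pre>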
+ */ + public void clearTopicAssignments(String option, String consFile, String keepFile) { + if (consFile != null) { + this.loadConstraints(consFile); + } + if (this.cons == null || this.cons.size() <= 0) { + return; + } + + if (keepFile != null) { + this.loadKeepList(keepFile); + } else { + this.topickeep = new HashMap(); + } + + for(int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + int word = doc.tokens.get(ii); + int topic = doc.topics.get(ii); + int path = doc.paths.get(ii); + + boolean keepTopicFlag = false; + if(this.topickeep.containsKey(word)) { + TIntHashSet keeptopics = this.topickeep.get(word); + if(keeptopics.contains(topic)) { + keepTopicFlag = true; + } + } + + if (option.equals("term")) { + if(this.cons.contains(word) && (!keepTopicFlag)) { + // change the count for count and node_count in TopicTreeWalk + this.topics.changeCountOnly(topic, word, path, -1); + doc.topics.set(ii, -1); + doc.paths.set(ii, -1); + //myAssert(doc.topicCounts.get(topic) >= 1, "clear topic assignments problem"); + doc.topicCounts.adjustValue(topic, -1); + } + } else { // option.equals("doc") + if(!keepTopicFlag) { + this.topics.changeCountOnly(topic, word, path, -1); + doc.topics.set(ii, -1); + doc.paths.set(ii, -1); + doc.topicCounts.adjustValue(topic, -1); + } + } + } + } + +// for(int dd = 0; dd < this.data.size(); dd++) { +// DocData doc = this.data.get(dd); +// Boolean flag = false; +// for(int ii = 0; ii < doc.tokens.size(); ii++) { +// int word = doc.tokens.get(ii); +// int topic = doc.topics.get(ii); +// +// boolean keepTopicFlag = false; +// if(this.topickeep.containsKey(word)) { +// TIntHashSet keeptopics = this.topickeep.get(word); +// if(keeptopics.contains(topic)) { +// keepTopicFlag = true; +// } +// } +// +// if(this.cons.contains(word) && (!keepTopicFlag)) { +// if (option.equals("term")) { +// int path = doc.paths.get(ii); +// // change the count for count and node_count in TopicTreeWalk +// this.topics.changeCountOnly(topic, word, path, -1); +// doc.topics.set(ii, -1); +// doc.paths.set(ii, -1); +// myAssert(doc.topicCounts.get(topic) >= 1, "clear topic assignments problem"); +// doc.topicCounts.adjustValue(topic, -1); +// } else if (option.equals("doc")) { +// flag = true; +// break; +// } +// } +// } +// if (flag) { +// for(int ii = 0; ii < doc.tokens.size(); ii++) { +// int word = doc.tokens.get(ii); +// int topic = doc.topics.get(ii); +// int path = doc.paths.get(ii); +// this.topics.changeCountOnly(topic, word, path, -1); +// doc.topics.set(ii, -1); +// doc.paths.set(ii, -1); +// } +// doc.topicCounts.clear(); +// } +// } + + } + + /** + * This function defines how to change a topic during the sampling process. + * It handles the case where both new_topic and old_topic are "-1" (empty topic). 
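+ *
+ * Typical use inside the samplers (a sketch; new_topic and new_path come from
+ * the sampling step):
+ * <pre>
+ * {@code
+ * this.changeTopic(doc_id, ii, word, -1, -1);              // retract the current assignment
+ * // ... sample new_topic and new_path for this token ...
+ * this.changeTopic(doc_id, ii, word, new_topic, new_path); // commit the new assignment
+ * }
+ * </pre>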
+ */ + public void changeTopic(int doc, int index, int word, int new_topic, int new_path) { + DocData current_doc = this.data.get(doc); + int old_topic = current_doc.topics.get(index); + int old_path = current_doc.paths.get(index); + + if (old_topic != -1) { + myAssert((new_topic == -1 && new_path == -1), "old_topic != -1 but new_topic != -1"); + this.topics.changeCount(old_topic, word, old_path, -1); + //myAssert(current_doc.topicCounts.get(old_topic) > 0, "Something wrong in changTopic"); + current_doc.topicCounts.adjustValue(old_topic, -1); + current_doc.topics.set(index, -1); + current_doc.paths.set(index, -1); + } + + if (new_topic != -1) { + myAssert((old_topic == -1 && old_path == -1), "new_topic != -1 but old_topic != -1"); + this.topics.changeCount(new_topic, word, new_path, 1); + current_doc.topicCounts.adjustOrPutValue(new_topic, 1, 1); + current_doc.topics.set(index, new_topic); + current_doc.paths.set(index, new_path); + } + } + + /** + * The function computes the document likelihood. + */ + public double docLHood() { + int docNum = this.data.size(); + + double val = 0.0; + val += Dirichlet.logGamma(this.alphaSum) * docNum; + double tmp = 0.0; + for (int tt = 0; tt < this.numTopics; tt++) { + tmp += Dirichlet.logGamma(this.alpha[tt]); + } + val -= tmp * docNum; + for (int dd = 0; dd < docNum; dd++) { + DocData doc = this.data.get(dd); + for (int tt = 0; tt < this.numTopics; tt++) { + val += Dirichlet.logGamma(this.alpha[tt] + doc.topicCounts.get(tt)); + } + val -= Dirichlet.logGamma(this.alphaSum + doc.topics.size()); + } + return val; + } + + /** + * Print the topic proportion for all documents. + */ + public void printDocumentTopics (File file) throws IOException { + PrintStream out = new PrintStream (file); + out.print ("#doc source topic proportion ...\n"); + + IDSorter[] sortedTopics = new IDSorter[ this.numTopics ]; + for (int topic = 0; topic < this.numTopics; topic++) { + // Initialize the sorters with dummy values + sortedTopics[topic] = new IDSorter(topic, topic); + } + + for (int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + + // compute topic proportion in one document + double sum = 0.0; + double[] prob = new double[this.numTopics]; + for (int topic=0; topic < this.numTopics; topic++) { + if (doc.topicCounts.containsKey(topic)) { + prob[topic] = this.alpha[topic] + doc.topicCounts.get(topic); + } else { + prob[topic] = this.alpha[topic]; + } + sum += prob[topic]; + } + + // normalize and sort + for (int topic=0; topic < this.numTopics; topic++) { + prob[topic] /= sum; + sortedTopics[topic].set(topic, prob[topic]); + } + Arrays.sort(sortedTopics); + + // print one document + out.print (dd); out.print (" "); + + if (doc.docName != null || !doc.docName.equals(" ")) { + out.print (doc.docName); + } else { + out.print ("null-source"); + } + out.print (" "); + for (int i = 0; i < numTopics; i++) { + out.print (sortedTopics[i].getID() + " " + + sortedTopics[i].getWeight() + " "); + } + out.print (" \n"); + } + out.close(); + } + + ////////////////////////////////////////////////////// + + /** + * This function loads vocab, loads tree, and initialize parameters. 
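+ *
+ * A minimal driver sketch; treeFiles, hyperFile, vocabFile, training and
+ * outputDir are assumed to be set up by the caller, and the numeric arguments
+ * are illustrative only:
+ * <pre>
+ * {@code
+ * TreeTopicSamplerFast sampler = new TreeTopicSamplerFast(50, 2.5, 0, true);
+ * sampler.initialize(treeFiles, hyperFile, vocabFile, null);
+ * sampler.addInstances(training);
+ * sampler.setNumIterations(1000);
+ * sampler.estimate(1000, outputDir, 100, 20);
+ * }
+ * </pre>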
+ */ + public void initialize(String treeFiles, String hyperFile, String vocabFile, String removedwordsFile) { + this.loadVocab(vocabFile); + if (removedwordsFile != null) { + this.loadRemovedWords(removedwordsFile + ".all", this.removedWords); + this.loadRemovedWords(removedwordsFile + ".new", this.removedWordsNew); + } + this.topics.initializeParams(treeFiles, hyperFile, this.vocab); + } + + /** + * This function defines the sampling process, computes the likelihood and running time, + * and specifies when to save the states files. + */ + public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) { + // update parameters + this.topics.updateParams(); + + if (this.startIter > this.numIterations) { + System.out.println("Have already sampled " + this.numIterations + " iterations!"); + System.exit(0); + } + System.out.println("Start sampling for iteration " + this.startIter); + + for (int ii = this.startIter; ii <= numIterations; ii++) { + //int[] tmpstats = {0, 0, 0, 0}; + //this.stats.add(tmpstats); + long starttime = System.currentTimeMillis(); + //System.out.println("Iter " + ii); + for (int dd = 0; dd < this.data.size(); dd++) { + this.sampleDoc(dd); + if (dd > 0 && dd % 10000 == 0) { + System.out.println("Sampled " + dd + " documents."); + } + } + + double totaltime = (double)(System.currentTimeMillis() - starttime) / 1000; + double lhood = 0; + if ((ii > 0 && ii % outputInterval == 0) || ii == numIterations) { + lhood = this.lhood(); + } + this.lhood.add(lhood); + this.iterTime.add(totaltime); + + if (ii % 10 == 0) { + String tmp = "Iteration " + ii; + tmp += " likelihood " + lhood; + tmp += " totaltime " + totaltime; + System.out.println(tmp); + } + + if ((ii > 0 && ii % outputInterval == 0) || ii == numIterations) { + try { + this.report(outputFolder, topWords); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + } + } + } + + ////////////////////////////////////////////////////// + + /** + * This function returns the likelihood. + */ + public double lhood() { + return this.docLHood() + this.topics.topicLHood(); + } + + /** + * By implementing the comparable interface, this function ranks the words + * in each topic, and returns the top words for each topic. + */ + public String displayTopWords (int numWords) { + + + + StringBuilder out = new StringBuilder(); + int numPaths = this.topics.getPathNum(); + //System.out.println(numPaths); + + for (int tt = 0; tt < this.numTopics; tt++){ + String tmp = "\n--------------\nTopic " + tt + "\n------------------------\n"; + //System.out.print(tmp); + out.append(tmp); + WordProb[] wp = new WordProb[numPaths]; + for (int pp = 0; pp < numPaths; pp++){ + int ww = this.topics.getWordFromPath(pp); + double val = this.topics.computeTopicPathProb(tt, ww, pp); + wp[pp] = new WordProb(pp, val); + } + Arrays.sort(wp); + for (int ii = 0; ii < wp.length; ii++){ + if(ii >= numWords) { + break; + } + int pp = wp[ii].wi; + int ww = this.topics.getWordFromPath(pp); + String word = this.vocab.get(ww); + if (this.removedWords.indexOf(word) == -1 && this.removedWordsNew.indexOf(word) == -1) { + tmp = wp[ii].p + "\t" + word + "\n"; + out.append(tmp); + } + } + } + return out.toString(); + } + + /** + * Prints the topic word distributions. 
+ */ + public void printTopicWords (File file) throws IOException { + + PrintStream out = new PrintStream (file); + int numPaths = this.topics.getPathNum(); + String tmp; + for (int tt = 0; tt < this.numTopics; tt++){ + WordProb[] wp = new WordProb[numPaths]; + for (int pp = 0; pp < numPaths; pp++){ + int ww = this.topics.getWordFromPath(pp); + double val = this.topics.computeTopicPathProb(tt, ww, pp); + wp[pp] = new WordProb(pp, val); + } + Arrays.sort(wp); + for (int ii = 0; ii < wp.length; ii++){ + int pp = wp[ii].wi; + int ww = this.topics.getWordFromPath(pp); + String word = this.vocab.get(ww); + if (this.removedWords.indexOf(word) == -1 && this.removedWordsNew.indexOf(word) == -1) { + tmp = tt + "\t" + word + "\t" + wp[ii].p; + out.println(tmp); + } + } + } + + out.close(); + } + + /** + * Prints the topic and path of each word for all documents. + */ + public void printState (File file) throws IOException { + //PrintStream out = + // new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file)))); + PrintStream out = new PrintStream(file); + + for (int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + String tmp = ""; + for (int ww = 0; ww < doc.topics.size(); ww++) { + int topic = doc.topics.get(ww); + int path = doc.paths.get(ww); + int wordpath = this.topics.pathToWordPath.get(path); + tmp += topic + ":" + wordpath + "\t"; + } + out.println(tmp); + } + out.close(); + } + + public TreeTopicInferencer getInferencer() { + //this.topics.updateParams(); + HashSet removedall = new HashSet (); + removedall.addAll(this.removedWords); + removedall.addAll(this.removedWordsNew); + TreeTopicInferencer inferencer = new TreeTopicInferencer(topics, vocab, removedall, alpha); + return inferencer; + } + + public TreeMarginalProbEstimator getProbEstimator() { + HashSet removedall = new HashSet (); + removedall.addAll(this.removedWords); + removedall.addAll(this.removedWordsNew); + TreeMarginalProbEstimator estimator = new TreeMarginalProbEstimator(topics, vocab, removedall, alpha); + return estimator; + } +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerInterface.java b/src/cc/mallet/topics/tree/TreeTopicSamplerInterface.java new file mode 100755 index 000000000..2ed1373ec --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerInterface.java @@ -0,0 +1,80 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIntIterator; + +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.zip.GZIPOutputStream; + +import cc.mallet.types.Dirichlet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + +/** + * This class defines the interface of a tree topic sampler. + * + * @author Yuening Hu + */ + +public interface TreeTopicSamplerInterface { + + /* Implemented in TreeTopicSampler.java. 
+ Shared code by TreeTopicSamplerSortD.java and TreeTopicSamplerHashD.java + */ + public void setNumIterations(int iters); + public void resume(InstanceList[] training, String resumeDir); + + // Also implemented in TreeTopicSampler.java, but do not need to be defined in interface. + //public int getNumIterations(); + //public void resumeLHood(String lhoodFile) throws IOException; + //public void report(String outputDir, int topWords) throws IOException; + //public void printTopWords(File file, int numWords) throws IOException; + //public void printStats (File file) throws IOException; + //public void loadVocab(String vocabFile); + //public void loadStopWords(String stopwordFile); + //public void loadConstraints(String consFile); + //abstract public void sampleDoc(int doc); + + + /* Same code for TreeTopicSamplerSortD.java and TreeTopicSamplerHashD.java + But related with this.topics, so not the the shared parent class. + */ + public void initialize(String treeFiles, String hyperFile, String vocabFile, String removedwordsFile); + public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords); + public TreeTopicInferencer getInferencer(); + public TreeMarginalProbEstimator getProbEstimator(); + // do not need to be defined in interface. + //public double lhood(); + //public String displayTopWords (int numWords); + //public void printState (File file) throws IOException; + + + + /* Different code for TreeTopicSamplerSortD.java and TreeTopicSamplerHashD.java + Stay in these two java files separately. + */ + public void addInstances(InstanceList[] training); + public void clearTopicAssignments(String option, String consFile, String keepFile); + // Do not need to be defined in interface. + //public void resumeStates(InstanceList training, String statesFile) throws IOException; + //public void changeTopic(int doc, int index, int word, int new_topic, int new_path); + //public double docLHood(); + //public void printDocumentTopics (File file) throws IOException; + +} \ No newline at end of file diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerNaive.java b/src/cc/mallet/topics/tree/TreeTopicSamplerNaive.java new file mode 100755 index 000000000..e9fff4e1a --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerNaive.java @@ -0,0 +1,98 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntDoubleIterator; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIntIterator; +import gnu.trove.TIntObjectHashMap; +import gnu.trove.TIntObjectIterator; +import gnu.trove.TObjectIntHashMap; + +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.io.PrintWriter; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.zip.GZIPOutputStream; + +import cc.mallet.types.Alphabet; +import cc.mallet.types.Dirichlet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.types.LabelAlphabet; +import cc.mallet.util.Randoms; + +/** + * This class defines a naive tree topic sampler. + * It calls the naive tree topic model. 
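+ * For each token, sampleDoc clears the current topic/path assignment, asks the naive model
+ * for the joint (topic, path) scores, and draws a new pair from the normalized cumulative sum
+ * (added summary of the sampling method below).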
+ * + * @author Yuening Hu + */ + +public class TreeTopicSamplerNaive extends TreeTopicSamplerHashD { + + public TreeTopicSamplerNaive (int numberOfTopics, double alphaSum) { + this (numberOfTopics, alphaSum, 0); + } + + public TreeTopicSamplerNaive (int numberOfTopics, double alphaSum, int seed) { + super (numberOfTopics, alphaSum, seed); + this.topics = new TreeTopicModelNaive(this.numTopics, this.random); + } + + /** + * For each word in a document, firstly covers its topic and path, then sample a + * topic and path, and update. + */ + public void sampleDoc(int doc_id){ + DocData doc = this.data.get(doc_id); + //System.out.println("doc " + doc_id); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + //int word = doc.tokens.getIndexAtPosition(ii); + int word = doc.tokens.get(ii); + + this.changeTopic(doc_id, ii, word, -1, -1); + ArrayList topic_term_score = new ArrayList(); + double norm = this.topics.computeTopicTerm(this.alpha, doc.topicCounts, word, topic_term_score); + //System.out.println(norm); + + int new_topic = -1; + int new_path = -1; + + double sample = this.random.nextDouble(); + //double sample = 0.8; + sample *= norm; + + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + sample -= val; + if (sample <= 0.0) { + new_topic = tt; + new_path = pp; + break; + } + } + + myAssert((new_topic >= 0 && new_topic < numTopics), "something wrong in sampling!"); + + this.changeTopic(doc_id, ii, word, new_topic, new_path); + } + } + +} diff --git a/src/cc/mallet/topics/tree/TreeTopicSamplerSortD.java b/src/cc/mallet/topics/tree/TreeTopicSamplerSortD.java new file mode 100755 index 000000000..bdedd506b --- /dev/null +++ b/src/cc/mallet/topics/tree/TreeTopicSamplerSortD.java @@ -0,0 +1,684 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TIntIntIterator; +import gnu.trove.TIntObjectHashMap; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.types.Dirichlet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.IDSorter; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Randoms; + +/** + * This class defines the tree topic sampler, which loads the instances, + * reports the topics, and leaves the sampler method as an abstract method, + * which might be various for different methods. + * Rathan than a HashMap for topicCounts in TreeTopicSamplerHashD, + * this class uses a sorted ArrayList for topicCounts. + * + * @author Yuening Hu + */ + +public abstract class TreeTopicSamplerSortD extends TreeTopicSampler implements TreeTopicSamplerInterface { + + /** + * This class defines the format of a document. 
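+ * A document keeps its name, the token ids, the per-token topic and path assignments,
+ * and a topicCounts list of {topic, count} pairs kept sorted by count (descriptive note added).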
+ */ + public class DocData { + TIntArrayList tokens; + TIntArrayList topics; + TIntArrayList paths; + // sort + ArrayList topicCounts; + String docName; + + public DocData (String name, TIntArrayList tokens, TIntArrayList topics, + TIntArrayList paths, ArrayList topicCounts) { + this.docName = name; + this.tokens = tokens; + this.topics = topics; + this.paths = paths; + this.topicCounts = topicCounts; + } + + public String toString() { + String result = "***************\n"; + result += docName + "\n"; + + result += "tokens: "; + for (int jj = 0; jj < tokens.size(); jj++) { + int index = tokens.get(jj); + String word = vocab.get(index); + result += word + " " + index + ", "; + } + + result += "\ntopics: "; + result += topics.toString(); + + result += "\npaths: "; + result += paths.toString(); + + result += "\ntopicCounts: "; + + for(int ii = 0; ii < this.topicCounts.size(); ii++) { + int[] tmp = this.topicCounts.get(ii); + result += "Topic " + tmp[0] + ": " + tmp[1] + ", "; + } + + result += "\n*****************\n"; + return result; + } + } + + public class WordProb implements Comparable { + int wi; + double p; + public WordProb (int wi, double p) { this.wi = wi; this.p = p; } + public final int compareTo (Object o2) { + if (p > ((WordProb)o2).p) + return -1; + else if (p == ((WordProb)o2).p) + return 0; + else return 1; + } + } + + TreeTopicModel topics; + ArrayList data; + + public TreeTopicSamplerSortD (int numberOfTopics, double alphaSum, int seed) { + super(numberOfTopics, alphaSum, seed); + this.data = new ArrayList (); + + // notice: this.topics is not initialized in this abstract class, + // in each sub class, the topics variable is initialized differently. + } + + /** + * This function adds instances given the training data in mallet input data format. + * For each token in a document, sample a topic and then sample a path based on prior. + */ + public void addInstances(InstanceList[] training) { + boolean debug = false; + int count = 0; + for(int ll = 0; ll < training.length; ll++) { + for (Instance instance : training[ll]) { + count++; + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + //String name = "null-source"; + //if (instance.getSource() != null) { + // name = instance.getSource().toString(); + //} + + // *** remained problem: keep topicCounts sorted + TIntArrayList tokens = new TIntArrayList(original_tokens.getLength()); + ArrayList topicCounts = new ArrayList (); + TIntArrayList topics = new TIntArrayList(original_tokens.getLength()); + TIntArrayList paths = new TIntArrayList(original_tokens.getLength()); + + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int token = this.vocab.indexOf(word); + int removed = this.removedWordsNew.indexOf(word); + int removednew = this.removedWordsNew.indexOf(word); + if(token != -1 && removed == -1 && removednew == -1) { + int topic = random.nextInt(numTopics); + if(debug) { topic = count % numTopics; } + tokens.add(token); + topics.add(topic); + //topicCounts.adjustOrPutValue(topic, 1, 1); + this.updateTopicCounts(topicCounts, topic, 1, 1); + // sample a path for this topic + int path_index = this.topics.initialize(token, topic); + paths.add(path_index); + } + } + + DocData doc = new DocData(name, tokens, topics, paths, topicCounts); + this.data.add(doc); + } + + //System.out.println(doc); + } + + } + + /** + * This function keeps the topicCounts in order by bubble sort. 
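+ * Hypothetical example (added for illustration): with topicCounts = [{2,5},{0,3}], calling
+ * updateTopicCounts(topicCounts, 0, 1, 1) removes {0,3} and re-inserts {0,4} after {2,5},
+ * so the list stays sorted by count in descending order; entries that drop to zero are not re-inserted.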
+ */ + private void updateTopicCounts(ArrayList topicCounts, int topic, int adjustvalue, int putvalue) { + + // remove old value + int value = -1; + for(int ii = 0; ii < topicCounts.size(); ii++) { + int[] tmp = topicCounts.get(ii); + if(tmp[0] == topic) { + value = tmp[1]; + topicCounts.remove(ii); + break; + } + } + + // adjust the value and update or insert + if (value == -1) { + value = putvalue; + } else { + value += adjustvalue; + } + + if (value > 0) { + int index = topicCounts.size(); + for(int ii = 0; ii < topicCounts.size(); ii++) { + int[] tmp = topicCounts.get(ii); + if(value >= tmp[1]) { + index = ii; + break; + } + } + int[] newpair = {topic, value}; + topicCounts.add(index, newpair); + } + + } + + /** + * Resume instance states from the saved states file. + */ + public void resumeStates(InstanceList[] training, String statesFile) throws IOException{ + FileInputStream statesfstream = new FileInputStream(statesFile); + DataInputStream statesdstream = new DataInputStream(statesfstream); + BufferedReader states = new BufferedReader(new InputStreamReader(statesdstream)); + + // reading topics, paths + for(int ll = 0; ll < training.length; ll++) { + for (Instance instance : training[ll]) { + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + String name = instance.getName().toString(); + + // *** remained problem: keep topicCounts sorted + TIntArrayList tokens = new TIntArrayList(original_tokens.getLength()); + ArrayList topicCounts = new ArrayList (); + TIntArrayList topics = new TIntArrayList(original_tokens.getLength()); + TIntArrayList paths = new TIntArrayList(original_tokens.getLength()); + + // + String statesLine = states.readLine(); + myAssert(statesLine != null, "statesFile doesn't match with the training data"); + statesLine = statesLine.trim(); + String[] str = statesLine.split("\t"); + + int count = -1; + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + int token = this.vocab.indexOf(word); + int removed = this.removedWords.indexOf(word); + int removednew = this.removedWordsNew.indexOf(word); + if(token != -1 && removed == -1) { + count++; + if (removednew == -1) { + String[] tp = str[count].split(":"); + myAssert(tp.length == 2, "statesFile problem!"); + int topic = Integer.parseInt(tp[0]); + int wordpath = Integer.parseInt(tp[1]); + int path = -1; + int backoffpath = -1; + // find the path for this wordpath + TIntObjectHashMap allpaths = this.topics.wordPaths.get(token); + for(int pp : allpaths.keys()) { + if(backoffpath == -1 && this.topics.pathToWordPath.get(pp) == 0){ + backoffpath = pp; + } + if(this.topics.pathToWordPath.get(pp) == wordpath){ + path = pp; + break; + } + } + + if(path == -1) { + // this path must be in a correlation, it will be cleared later + path = backoffpath; + myAssert(path != -1, "path problem"); + } + tokens.add(token); + topics.add(topic); + paths.add(path); + //topicCounts.adjustOrPutValue(topic, 1, 1); + this.updateTopicCounts(topicCounts, topic, 1, 1); + this.topics.changeCountOnly(topic, token, path, 1); + } + } + } + if(count != -1) { + count++; + myAssert(str.length == count, "resume problem!"); + } + + DocData doc = new DocData(name, tokens, topics, paths, topicCounts); + this.data.add(doc); + } + } + states.close(); + } + + /** + * This function clears the topic and path assignments for some words: + * (1) term option: only clears the topic and path for constraint words; + * (2) doc option: clears the topic and path for documents 
which contain + * at least one of the constraint words. + */ + public void clearTopicAssignments(String option, String consFile, String keepFile) { + this.loadConstraints(consFile); + if (this.cons == null || this.cons.size() <= 0) { + return; + } + + if (keepFile != null) { + this.loadKeepList(keepFile); + } else { + this.topickeep = new HashMap(); + } + + for(int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + + for(int ii = 0; ii < doc.tokens.size(); ii++) { + int word = doc.tokens.get(ii); + int topic = doc.topics.get(ii); + int path = doc.paths.get(ii); + + boolean keepTopicFlag = false; + if(this.topickeep.containsKey(word)) { + TIntHashSet keeptopics = this.topickeep.get(word); + if(keeptopics.contains(topic)) { + keepTopicFlag = true; + } + } + + if (option.equals("term")) { + if(this.cons.contains(word) && (!keepTopicFlag)) { + // change the count for count and node_count in TopicTreeWalk + this.topics.changeCountOnly(topic, word, path, -1); + doc.topics.set(ii, -1); + doc.paths.set(ii, -1); + this.updateTopicCounts(doc.topicCounts, topic, -1, 0); + } + } else { // option.equals("doc") + if(!keepTopicFlag) { + this.topics.changeCountOnly(topic, word, path, -1); + doc.topics.set(ii, -1); + doc.paths.set(ii, -1); + this.updateTopicCounts(doc.topicCounts, topic, -1, 0); + } + } + } + } + +// for(int dd = 0; dd < this.data.size(); dd++) { +// DocData doc = this.data.get(dd); +// Boolean flag = false; +// for(int ii = 0; ii < doc.tokens.size(); ii++) { +// int word = doc.tokens.get(ii); +// int topic = doc.topics.get(ii); +// +// boolean keepTopicFlag = false; +// if(this.topickeep.containsKey(word)) { +// TIntHashSet keeptopics = this.topickeep.get(word); +// if(keeptopics.contains(topic)) { +// keepTopicFlag = true; +// } +// } +// +// if(this.cons.contains(word) && (!keepTopicFlag)) { +// if (option.equals("term")) { +// // change the count for count and node_count in TopicTreeWalk +// int path = doc.paths.get(ii); +// this.topics.changeCountOnly(topic, word, path, -1); +// doc.topics.set(ii, -1); +// doc.paths.set(ii, -1); +// //myAssert(doc.topicCounts.get(topic) >= 1, "clear topic assignments problem"); +// //doc.topicCounts.adjustValue(topic, -1); +// this.updateTopicCounts(doc.topicCounts, topic, -1, 0); +// } else if (option.equals("doc")) { +// flag = true; +// break; +// } +// } +// } +// if (flag) { +// for(int ii = 0; ii < doc.tokens.size(); ii++) { +// int word = doc.tokens.get(ii); +// int topic = doc.topics.get(ii); +// int path = doc.paths.get(ii); +// this.topics.changeCountOnly(topic, word, path, -1); +// doc.topics.set(ii, -1); +// doc.paths.set(ii, -1); +// } +// doc.topicCounts.clear(); +// } +// } + } + + /** + * This function defines how to change a topic during the sampling process. + * It handles the case where both new_topic and old_topic are "-1" (empty topic). 
+ */ + public void changeTopic(int doc, int index, int word, int new_topic, int new_path) { + DocData current_doc = this.data.get(doc); + int old_topic = current_doc.topics.get(index); + int old_path = current_doc.paths.get(index); + + if (old_topic != -1) { + myAssert((new_topic == -1 && new_path == -1), "old_topic != -1 but new_topic != -1"); + this.topics.changeCount(old_topic, word, old_path, -1); + //myAssert(current_doc.topicCounts.get(old_topic) > 0, "Something wrong in changTopic"); + this.updateTopicCounts(current_doc.topicCounts, old_topic, -1, 0); + current_doc.topics.set(index, -1); + current_doc.paths.set(index, -1); + } + + if (new_topic != -1) { + myAssert((old_topic == -1 && old_path == -1), "new_topic != -1 but old_topic != -1"); + this.topics.changeCount(new_topic, word, new_path, 1); + this.updateTopicCounts(current_doc.topicCounts, new_topic, 1, 1); + current_doc.topics.set(index, new_topic); + current_doc.paths.set(index, new_path); + } + } + + /** + * The function computes the document likelihood. + */ + public double docLHood() { + int docNum = this.data.size(); + + double val = 0.0; + val += Dirichlet.logGamma(this.alphaSum) * docNum; + double tmp = 0.0; + for (int tt = 0; tt < this.numTopics; tt++) { + tmp += Dirichlet.logGamma(this.alpha[tt]); + } + val -= tmp * docNum; + for (int dd = 0; dd < docNum; dd++) { + DocData doc = this.data.get(dd); + + int[] tmpTopics = new int[this.numTopics]; + for(int ii = 0; ii < this.numTopics; ii++) { + tmpTopics[ii] = 0; + } + for(int ii = 0; ii < doc.topicCounts.size(); ii++) { + int[] current = doc.topicCounts.get(ii); + int tt = current[0]; + tmpTopics[tt] = current[1]; + } + for(int tt = 0; tt < tmpTopics.length; tt++) { + val += Dirichlet.logGamma(this.alpha[tt] + tmpTopics[tt]); + } + + val -= Dirichlet.logGamma(this.alphaSum + doc.topics.size()); + } + return val; + } + + /** + * Print the topic proportion for all documents. 
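+ * Each line lists the document index, its name, and (topic, proportion) pairs sorted by
+ * proportion in descending order; topic counts are smoothed with alpha before normalizing
+ * (descriptive note added, based on the method body below).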
+ */ + public void printDocumentTopics (File file) throws IOException { + PrintStream out = new PrintStream (file); + out.print ("#doc source topic proportion ...\n"); + + IDSorter[] sortedTopics = new IDSorter[ this.numTopics ]; + for (int topic = 0; topic < this.numTopics; topic++) { + // Initialize the sorters with dummy values + sortedTopics[topic] = new IDSorter(topic, topic); + } + + for (int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + + // compute topic proportion in one document + double sum = 0.0; + double[] prob = new double[this.numTopics]; + + // initialize + for (int topic=0; topic < this.numTopics; topic++) { + prob[topic] = -1; + } + + // topic counts + for (int ii = 0; ii < doc.topicCounts.size(); ii++) { + int[] current = doc.topicCounts.get(ii); + int topic = current[0]; + prob[topic] = this.alpha[topic] + current[1]; + } + + for (int topic=0; topic < this.numTopics; topic++) { + if (prob[topic] == -1) { + prob[topic] = this.alpha[topic]; + } + sum += prob[topic]; + } + + // normalize and sort + for (int topic=0; topic < this.numTopics; topic++) { + prob[topic] /= sum; + sortedTopics[topic].set(topic, prob[topic]); + } + Arrays.sort(sortedTopics); + + // print one document + out.print (dd); out.print (" "); + + if (doc.docName != null || !doc.docName.equals(" ")) { + out.print (doc.docName); + } else { + out.print ("null-source"); + } + out.print (" "); + for (int i = 0; i < numTopics; i++) { + out.print (sortedTopics[i].getID() + " " + + sortedTopics[i].getWeight() + " "); + } + out.print (" \n"); + } + out.close(); + } + + + + ///////////////////////////////////////////////////////////// + /** + * This function loads vocab, loads tree, and initialize parameters. + */ + public void initialize(String treeFiles, String hyperFile, String vocabFile, String removedwordsFile) { + this.loadVocab(vocabFile); + if (removedwordsFile != null) { + this.loadRemovedWords(removedwordsFile + ".all", this.removedWords); + this.loadRemovedWords(removedwordsFile + ".new", this.removedWordsNew); + } + this.topics.initializeParams(treeFiles, hyperFile, this.vocab); + } + + /** + * This function defines the sampling process, computes the likelihood and running time, + * and specifies when to save the states files. 
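+ * The likelihood is only evaluated every outputInterval iterations (and at the final iteration),
+ * progress is printed every 10 iterations, and report() is called on the same interval schedule.
+ * A typical call might look like {@code sampler.estimate(1000, "output", 50, 20)} for 1000
+ * iterations, reporting every 50 iterations with 20 top words (hypothetical usage, added).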
+ */ + public void estimate(int numIterations, String outputFolder, int outputInterval, int topWords) { + // update parameters + this.topics.updateParams(); + + if (this.startIter > this.numIterations) { + System.out.println("Have already sampled " + this.numIterations + " iterations!"); + System.exit(0); + } + System.out.println("Start sampling for iteration " + this.startIter); + + for (int ii = this.startIter; ii <= numIterations; ii++) { + long starttime = System.currentTimeMillis(); + //System.out.println("Iter " + ii); + for (int dd = 0; dd < this.data.size(); dd++) { + this.sampleDoc(dd); + if (dd > 0 && dd % 10000 == 0) { + System.out.println("Sampled " + dd + " documents."); + } + } + double totaltime = (double)(System.currentTimeMillis() - starttime) / 1000; + double lhood = 0; + if ((ii > 0 && ii % outputInterval == 0) || ii == numIterations) { + lhood = this.lhood(); + } + this.lhood.add(lhood); + this.iterTime.add(totaltime); + + if (ii % 10 == 0) { + String tmp = "Iteration " + ii; + tmp += " likelihood " + lhood; + tmp += " totaltime " + totaltime; + System.out.println(tmp); + } + + if ((ii > 0 && ii % outputInterval == 0) || ii == numIterations) { + try { + this.report(outputFolder, topWords); + } catch (IOException e) { + System.out.println(e.getMessage()); + } + } + } + } + + ///////////////////////////////////////////////////////////// + + /** + * This function returns the likelihood. + */ + public double lhood() { + return this.docLHood() + this.topics.topicLHood(); + } + + /** + * By implementing the comparable interface, this function ranks the words + * in each topic, and returns the top words for each topic. + */ + public String displayTopWords (int numWords) { + + StringBuilder out = new StringBuilder(); + int numPaths = this.topics.getPathNum(); + //System.out.println(numPaths); + + for (int tt = 0; tt < this.numTopics; tt++){ + String tmp = "\n--------------\nTopic " + tt + "\n------------------------\n"; + //System.out.print(tmp); + out.append(tmp); + WordProb[] wp = new WordProb[numPaths]; + for (int pp = 0; pp < numPaths; pp++){ + int ww = this.topics.getWordFromPath(pp); + double val = this.topics.computeTopicPathProb(tt, ww, pp); + wp[pp] = new WordProb(pp, val); + } + Arrays.sort(wp); + for (int ii = 0; ii < wp.length; ii++){ + if(ii >= numWords) { + break; + } + int pp = wp[ii].wi; + int ww = this.topics.getWordFromPath(pp); + String word = this.vocab.get(ww); + if (this.removedWords.indexOf(word) == -1 && this.removedWordsNew.indexOf(word) == -1) { + tmp = wp[ii].p + "\t" + word + "\n"; + out.append(tmp); + } + } + } + return out.toString(); + } + + /** + * Prints the topic word distributions. + */ + public void printTopicWords (File file) throws IOException { + + PrintStream out = new PrintStream (file); + int numPaths = this.topics.getPathNum(); + String tmp; + + for (int tt = 0; tt < this.numTopics; tt++){ + + WordProb[] wp = new WordProb[numPaths]; + for (int pp = 0; pp < numPaths; pp++){ + int ww = this.topics.getWordFromPath(pp); + double val = this.topics.computeTopicPathProb(tt, ww, pp); + wp[pp] = new WordProb(pp, val); + } + Arrays.sort(wp); + for (int ii = 0; ii < wp.length; ii++){ + int pp = wp[ii].wi; + int ww = this.topics.getWordFromPath(pp); + String word = this.vocab.get(ww); + if (this.removedWords.indexOf(word) == -1 && this.removedWordsNew.indexOf(word) == -1) { + tmp = tt + "\t" + word + "\t" + wp[ii].p; + out.println(tmp); + } + } + } + out.close(); + } + + /** + * Prints the topic and path of each word for all documents. 
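+ * One line per document; every token is written as topic:wordpath with tab separators
+ * (descriptive note added, based on the method body below).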
+ */ + public void printState (File file) throws IOException { + //PrintStream out = + // new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file)))); + PrintStream out = new PrintStream(file); + + for (int dd = 0; dd < this.data.size(); dd++) { + DocData doc = this.data.get(dd); + String tmp = ""; + for (int ww = 0; ww < doc.topics.size(); ww++) { + int topic = doc.topics.get(ww); + int path = doc.paths.get(ww); + int wordpath = this.topics.pathToWordPath.get(path); + tmp += topic + ":" + wordpath + "\t"; + } + out.println(tmp); + } + out.close(); + } + + public TreeTopicInferencer getInferencer() { + //this.topics.updateParams(); + HashSet removedall = new HashSet (); + removedall.addAll(this.removedWords); + removedall.addAll(this.removedWordsNew); + TreeTopicInferencer inferencer = new TreeTopicInferencer(topics, vocab, removedall, alpha); + return inferencer; + } + + public TreeMarginalProbEstimator getProbEstimator() { + HashSet removedall = new HashSet (); + removedall.addAll(this.removedWords); + removedall.addAll(this.removedWordsNew); + TreeMarginalProbEstimator estimator = new TreeMarginalProbEstimator(topics, vocab, removedall, alpha); + return estimator; + } +} diff --git a/src/cc/mallet/topics/tree/Utils.java b/src/cc/mallet/topics/tree/Utils.java new file mode 100755 index 000000000..5a5c1db68 --- /dev/null +++ b/src/cc/mallet/topics/tree/Utils.java @@ -0,0 +1,82 @@ +package cc.mallet.topics.tree; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + + +public class Utils { + /** + * Add an item to a map of counts. + */ + public static void addToMap(Map map, String word) { + int count = 0; + if (map.containsKey(word)) + count = map.get(word); + map.put(word, count+1); + } + + /** + * Sort a map by value and return a list of the sorted keys. + * + * adapted from: + * http://www.programmersheaven.com/download/49349/download.aspx + * + */ + public static List sortByValue(Map map) { + List> list = new LinkedList>( + map.entrySet()); + Collections.sort(list, new Comparator() { + public int compare(Object o1, Object o2) { + return ((Comparable) ((Map.Entry) (o2)).getValue()) + .compareTo(((Map.Entry) (o1)).getValue()); + } + }); + // logger.info(list); + List result = new ArrayList(); + for (Iterator> it = list.iterator(); it.hasNext();) { + Map.Entry entry = (Map.Entry) it.next(); + result.add(entry.getKey()); + } + return result; + } + + /** + * Read all the lines in a file and return them in a list. + */ + public static List readAll(String filename) throws Exception { + List lines = new ArrayList(); + BufferedReader reader = new BufferedReader(new FileReader(filename)); + String line = ""; + while ((line = reader.readLine()) != null) + lines.add(line); + reader.close(); + return lines; + } + + /** + * Converts a list of strings into a single space-separated string. + */ + public static String listToString(List words) { + String str = ""; + for (String word : words) { + str += " " + word; + } + return str.substring(1); + } + + /** + * Converts a space-separated string of words into list form. 
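+ * The input is lower-cased and split on runs of whitespace, so "New York City" becomes
+ * [new, york, city] (illustrative example added).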
+ */ + public static List stringToList(String str) { + String[] parts = str.toLowerCase().split("\\s+"); + return Arrays.asList(parts); + } +} diff --git a/src/cc/mallet/topics/tree/VocabGenerator.java b/src/cc/mallet/topics/tree/VocabGenerator.java new file mode 100755 index 000000000..8a6ea7360 --- /dev/null +++ b/src/cc/mallet/topics/tree/VocabGenerator.java @@ -0,0 +1,238 @@ +package cc.mallet.topics.tree; + +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntHashSet; +import gnu.trove.TIntIntHashMap; +import gnu.trove.TObjectDoubleHashMap; +import gnu.trove.TObjectIntHashMap; +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.types.Alphabet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import cc.mallet.util.Maths; + + +/** + * This class generates the vocab file from mallet input. + * This generated vocab can be filtered by either frequency or tfidf. + * Tree-based topic model need this vocab for: + * (1) filter words more flexible + * (2) generate tree structure + * (3) allow removing words + * Main entrance: genVocab() + * + * @author Yuening Hu + */ + +public class VocabGenerator { + public static TObjectDoubleHashMap getIdf(InstanceList data) { + // get idf + TObjectDoubleHashMap idf = new TObjectDoubleHashMap (); + + for (Instance instance : data) { + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + HashSet words = new HashSet(); + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + words.add(word); + } + for(String word : words) { + idf.adjustOrPutValue(word, 1, 1); + } + } + + int D = data.size(); + for(Object ob : idf.keys()){ + String word = (String) ob; + double value = D / (1 + idf.get(word)); + value = Math.log(value) - idf.get(word); + idf.adjustValue(word, value); + } + + System.out.println("Idf size: " + idf.size()); + return idf; + } + + public static TObjectDoubleHashMap computeTfidf(InstanceList data) { + // get idf + TObjectDoubleHashMap idf = getIdf(data); + + // compute tf-idf for each word + HashMap> tfidf = new HashMap> (); + + for (Instance instance : data) { + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + TObjectIntHashMap tf = new TObjectIntHashMap(); + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + tf.adjustOrPutValue(word, 1, 1); + } + for(Object ob : tf.keys()) { + String word = (String) ob; + HashSet values; + if (tfidf.containsKey(word)) { + values = tfidf.get(word); + } else { + values = new HashSet (); + tfidf.put(word, values); + } + double value = tf.get(word) * idf.get(word); + values.add(value); + } + } + + // averaged tfidf + TObjectDoubleHashMap vocabtfidf = new TObjectDoubleHashMap(); + for(String word : tfidf.keySet()) { + double sum = 0; + int count = tfidf.get(word).size(); + for(double value : tfidf.get(word)) { + sum += value; + } + sum = sum / count; + vocabtfidf.put(word, sum); + } + + System.out.println("vocab tfidf size: " + vocabtfidf.size()); + return vocabtfidf; + } + + public static TObjectDoubleHashMap getFrequency (InstanceList data) { + TObjectDoubleHashMap freq = new TObjectDoubleHashMap (); + Alphabet alphabet = data.getAlphabet(); + for(int ii = 0; ii < 
alphabet.size(); ii++) { + String word = alphabet.lookupObject(ii).toString(); + freq.put(word, 0); + } + + for (Instance instance : data) { + FeatureSequence original_tokens = (FeatureSequence) instance.getData(); + for (int jj = 0; jj < original_tokens.getLength(); jj++) { + String word = (String) original_tokens.getObjectAtPosition(jj); + freq.adjustValue(word, 1); + } + } + + System.out.println("Alphabet size: " + alphabet.size()); + System.out.println("Frequency size: " + freq.size()); + return freq; + } + + public static void genVocab_all(InstanceList data, String vocab, Boolean tfidfRank, double tfidfthresh, double freqthresh, double wordlength) { + //public static void genVocab(InstanceList data, String vocab) { + try{ + File file = new File(vocab); + PrintStream out = new PrintStream (file); + + int language_id = 0; + Alphabet alphabet = data.getAlphabet(); + for(int ii = 0; ii < alphabet.size(); ii++) { + String word = alphabet.lookupObject(ii).toString(); + System.out.println(word); + out.println(language_id + "\t" + word); + } + out.close(); + } catch (IOException e) { + e.getMessage(); + } + + } + + /** + * After the preprocessing of mallet, a vocab is needed to generate + * the prior tree. So this function simply read in the alphabet + * of the training data, filter the words either by frequency or tfidf, + * then output the vocab. + * Currently, the language_id is fixed. + */ + public static void genVocab(InstanceList[] data, String vocab, Boolean tfidfRank, double tfidfthresh, double freqthresh, double wordlength) { + + class WordCount implements Comparable { + String word; + double value; + public WordCount (String word, double value) { this.word = word; this.value = value; } + public final int compareTo (Object o2) { + if (value > ((WordCount)o2).value) + return -1; + else if (value == ((WordCount)o2).value) + return 0; + else return 1; + } + } + + + try{ + File file = new File(vocab); + PrintStream out = new PrintStream (file, "UTF8"); + + HashSet allwords = new HashSet (); + for (int ll = 0; ll < data.length; ll++) { + System.out.println("Language " + ll); + TObjectDoubleHashMap freq = getFrequency(data[ll]); + TObjectDoubleHashMap tfidf = computeTfidf(data[ll]); + TObjectDoubleHashMap selected; + if (tfidfRank) { + selected = tfidf; + } else { + selected = freq; + } + + WordCount[] array = new WordCount[selected.keys().length]; + int index = -1; + for(Object o : selected.keys()) { + String word = (String)o; + double count = selected.get(word); + index++; + array[index] = new WordCount(word, count); + } + System.out.println("Array size: " + array.length); + Arrays.sort(array); + System.out.println("After sort array size: " + array.length); + + int language_id = ll; + int count = 0; + for(int ii = 0; ii < array.length; ii++) { + String word = array[ii].word; + if (word.length() >= wordlength && tfidf.get(word) > tfidfthresh && freq.get(word) > freqthresh) { + if (allwords.contains(word)) { + continue; + } + allwords.add(word); + out.println(language_id + "\t" + array[ii].word + "\t" + tfidf.get(word) + "\t" + (int)freq.get(word)); + count++; + } + } + System.out.println("Filtered vocab size: " + count); + System.out.println("*******************"); + } + out.close(); + + } catch (IOException e) { + e.getMessage(); + } + } + + public static void main(String[] args) { + //String input = "input/synthetic-topic-input.mallet"; + //String vocab = "input/synthetic.voc"; + + String input = "../../itm-evaluation/results/fbis-itm/input/fbis-itm-topic-input.mallet"; + String vocab = 
"../../itm-evaluation/results/fbis-itm/input/fbis-itm.voc"; + + InstanceList[] instances = new InstanceList[ 2 ]; + InstanceList data = InstanceList.load (new File(input)); + instances[0] = data; + InstanceList data1 = InstanceList.load (new File(input)); + instances[1] = data1; + genVocab(instances, vocab, true, 1, 10, 3); + System.out.println("Done!"); + } +} diff --git a/src/cc/mallet/topics/tree/testFast.java b/src/cc/mallet/topics/tree/testFast.java new file mode 100755 index 000000000..072c96563 --- /dev/null +++ b/src/cc/mallet/topics/tree/testFast.java @@ -0,0 +1,249 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntDoubleHashMap; +import gnu.trove.TIntIntHashMap; + +import java.io.File; +import java.util.ArrayList; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.types.InstanceList; +import junit.framework.TestCase; + +/** + * This class tests the fast sampler. + * @author Yuening Hu + */ + +public class testFast extends TestCase{ + + public TreeTopicSamplerFast Initialize() { + + String inputFile = "input/toy/toy-topic-input.mallet"; + String treeFiles = "input/toy/toy.wn.*"; + String hyperFile = "input/toy/tree_hyperparams"; + String vocabFile = "input/toy/toy.voc"; + String removedFile = "input/toy/removed"; + int numTopics = 3; + double alpha_sum = 0.3; + int randomSeed = 0; + int numIterations = 10; + +// String inputFile = "../input/synthetic-topic-input.mallet"; +// String treeFiles = "../synthetic/synthetic_empty.wn.*"; +// String hyperFile = "../synthetic/tree_hyperparams"; +// String vocabFile = "../synthetic/synthetic.voc"; +// int numTopics = 5; +// double alpha_sum = 0.5; +// int randomSeed = 0; +// int numIterations = 10; + + InstanceList[] instances = new InstanceList[1]; + InstanceList ilist = InstanceList.load (new File(inputFile)); + System.out.println ("Data loaded."); + instances[0] = ilist; + + TreeTopicSamplerFast topicModel = null; + topicModel = new TreeTopicSamplerFast(numTopics, alpha_sum, randomSeed, false); + + topicModel.initialize(treeFiles, hyperFile, vocabFile, removedFile); + topicModel.addInstances(instances); + + topicModel.setNumIterations(numIterations); + + return topicModel; + } + + public void testUpdateParams() { + TreeTopicSamplerFast topicModel = this.Initialize(); + topicModel.topics.updateParams(); + + for(int dd = 0; dd < topicModel.data.size(); dd++) { + System.out.println(topicModel.data.get(dd)); + } + + System.out.println("**************\nNormalizer"); + int numPaths = topicModel.topics.pathToWord.size(); + for(int tt = 0; tt < topicModel.numTopics; tt++) { + for(int pp = 0; pp < numPaths; pp++) { + System.out.println("topic " + tt + " path " + pp + " normalizer " + topicModel.topics.normalizer.get(tt, pp)); + } + } + + System.out.println("**************\nNon zero paths"); + for(int ww : topicModel.topics.nonZeroPaths.keys()) { + for(int tt : topicModel.topics.nonZeroPaths.get(ww).getKey1Set()) { + for(int pp : topicModel.topics.nonZeroPaths.get(ww).get(tt).keys()) { + System.out.println("word " + ww + " topic " + tt + " path " + pp + " " + topicModel.topics.nonZeroPaths.get(ww).get(tt, pp)); + } + } + } + } + + public void testUpdatePathmaskedCount() { + TreeTopicSamplerFast topicModel = this.Initialize(); + topicModel.topics.updateParams(); + int numPaths = topicModel.topics.pathToWord.size(); + + TreeTopicModelFast topics = (TreeTopicModelFast)topicModel.topics; + + for (int ww : topics.nonZeroPaths.keys()) { + for(int tt : 
topics.nonZeroPaths.get(ww).getKey1Set()) { + for(int pp : topicModel.topics.nonZeroPaths.get(ww).get(tt).keys()) { + TIntArrayList path_nodes = topics.wordPaths.get(ww, pp); + int parent = path_nodes.get(path_nodes.size() - 2); + int child = path_nodes.get(path_nodes.size() - 1); + + int mask = topics.nonZeroPaths.get(ww).get(tt, pp) - topics.traversals.get(tt).getCount(parent, child); + + System.out.println("*************************"); + System.out.println("Topic " + tt + " Word " + ww + " path " + pp); + String tmp = "["; + for (int ii : path_nodes.toNativeArray()) { + tmp += " " + ii; + } + System.out.println("Real path " + tmp + " ]"); + System.out.println("Real count " + topics.traversals.get(tt).getCount(parent, child)); + System.out.println("Masked count " + topics.nonZeroPaths.get(ww).get(tt, pp)); + System.out.println("Masekd count " + Integer.toBinaryString(topics.nonZeroPaths.get(ww).get(tt, pp))); + System.out.println("*************************"); + } + } + } + } + + public void testChangeTopic() { + TreeTopicSamplerFast topicModel = this.Initialize(); + topicModel.topics.updateParams(); + TreeTopicModelFast topics = (TreeTopicModelFast)topicModel.topics; + //for(int dd = 0; dd < topicModel.data.size(); dd++){ + for(int dd = 0; dd < 1; dd++){ + DocData doc = topicModel.data.get(dd); + for(int ii = 0; ii < doc.tokens.size(); ii++) { + int word = doc.tokens.get(ii); + int old_topic = doc.topics.get(ii); + int old_path = doc.paths.get(ii); + TIntArrayList path_nodes = topicModel.topics.wordPaths.get(word, old_path); + int node = path_nodes.get(0); + int leaf = path_nodes.get(path_nodes.size() - 1); + int total = 0; + for(int nn : topics.traversals.get(word).counts.get(node).keys()){ + total += topics.traversals.get(word).getCount(node, nn); + } + + assertTrue(topics.traversals.get(word).getNodeCount(node) == total); + + System.out.println("*************************"); + System.out.println("old topic " + old_topic + " word " + word); + System.out.println("old normalizer " + topics.normalizer.get(old_topic, old_path)); + System.out.println("old root count " + topics.traversals.get(old_topic).getNodeCount(node) + " " + total); + System.out.println("old non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(old_topic, old_path))); + System.out.println("old leaf count " + topics.traversals.get(old_topic).getNodeCount(leaf)); + + topicModel.changeTopic(dd, ii, word, -1, -1); + + total = 0; + for(int nn : topics.traversals.get(old_topic).counts.get(node).keys()){ + total += topics.traversals.get(old_topic).getCount(node, nn); + } + assertTrue(topics.traversals.get(old_topic).getNodeCount(node) == total); + System.out.println("*************************"); + System.out.println("updated old topic " + old_topic + " word " + word); + System.out.println("updated old normalizer " + topics.normalizer.get(old_topic, old_path)); + System.out.println("updated old root count " + topics.traversals.get(old_topic).getNodeCount(node) + " " + total); + System.out.println("updated old non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(old_topic, old_path))); + System.out.println("updated old leaf count " + topics.traversals.get(old_topic).getNodeCount(leaf)); + + + int new_topic = topicModel.numTopics - old_topic - 1; + int new_path = old_path; + + total = 0; + for(int nn : topics.traversals.get(new_topic).counts.get(node).keys()){ + total += topics.traversals.get(new_topic).getCount(node, nn); + } + 
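// Consistency check (comment added): the node count stored for the path root must equal
// the sum of the edge counts leaving that node.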
assertTrue(topics.traversals.get(new_topic).getNodeCount(node) == total); + + System.out.println("*************************"); + System.out.println("new topic " + new_topic + " word " + word); + System.out.println("new normalizer " + topics.normalizer.get(new_topic, new_path)); + System.out.println("new root count " + topics.traversals.get(new_topic).getNodeCount(node) + " " + total); + System.out.println("new non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(new_topic, new_path))); + System.out.println("new leaf count " + topics.traversals.get(new_topic).getNodeCount(leaf)); + + topicModel.changeTopic(dd, ii, word, new_topic, new_path); + + + total = 0; + for(int nn : topics.traversals.get(new_topic).counts.get(node).keys()){ + total += topics.traversals.get(new_topic).getCount(node, nn); + } + assertTrue(topics.traversals.get(new_topic).getNodeCount(node) == total); + System.out.println("*************************"); + System.out.println("updated new topic " + new_topic + " word " + word); + System.out.println("updated new normalizer " + topics.normalizer.get(new_topic, new_path)); + System.out.println("updated new root count " + topics.traversals.get(new_topic).getNodeCount(node) + " " + total); + System.out.println("updated new non zero count " + Integer.toBinaryString(topics.nonZeroPaths.get(word).get(new_topic, new_path))); + System.out.println("updated new leaf count " + topics.traversals.get(new_topic).getNodeCount(leaf)); + + System.out.println("*************************\n"); + } + } + } + + public void testBinValues() { + TreeTopicSamplerFast topicModelFast = this.Initialize(); + topicModelFast.topics.updateParams(); + + TreeTopicSamplerNaive topicModelNaive = testNaive.Initialize(); + topicModelNaive.topics.updateParams(); + + //for(int dd = 0; dd < topicModelFast.data.size(); dd++){ + for(int dd = 0; dd < 1; dd++){ + DocData doc = topicModelFast.data.get(dd); + DocData doc1 = topicModelNaive.data.get(dd); + + //for(int ii = 0; ii < doc.tokens.size(); ii++) { + for(int ii = 4; ii < 5; ii++) { + int word = doc.tokens.get(ii); + int topic = doc.topics.get(ii); + int path = doc.paths.get(ii); + + double smoothing = topicModelFast.callComputeTermSmoothing(word); + double topicbeta = topicModelFast.callComputeTermTopicBeta(doc.topicCounts, word); + ArrayList dict = new ArrayList(); + double topictermscore = topicModelFast.topics.computeTopicTerm(topicModelFast.alpha, + doc.topicCounts, word, dict); + double norm = smoothing + topicbeta + topictermscore; + + double smoothing1 = topicModelFast.computeTopicSmoothTest(word); + double topicbeta1 = topicModelFast.computeTopicTermBetaTest(doc.topicCounts, word); + HIntIntDoubleHashMap dict1 = new HIntIntDoubleHashMap(); + double topictermscore1 = topicModelFast.computeTopicTermScoreTest(topicModelFast.alpha, + doc.topicCounts, word, dict1); + double norm1 = smoothing1 + topicbeta1 + topictermscore1; + + System.out.println("*************"); + System.out.println("Index " + ii); + System.out.println(smoothing + " " + smoothing1); + System.out.println(topicbeta + " " + topicbeta1); + System.out.println(topictermscore + " " + topictermscore1); + + ArrayList dict2 = new ArrayList(); + double norm2 = topicModelFast.computeTopicTermTest(topicModelNaive.alpha, doc.topicCounts, word, dict2); + + ArrayList dict3 = new ArrayList(); + double norm3 = topicModelNaive.topics.computeTopicTerm(topicModelNaive.alpha, doc.topicCounts, word, dict3); + + System.out.println(norm + " " + norm1 + " " + norm2 + " " + norm3); +// if (norm1 != 
norm2) { +// System.out.println(norm + " " + norm1 + " " + norm2 + " " + norm3 ); +// } + System.out.println("*************"); + assert(norm == norm1); + assert(1 == 0); + } + } + } +} diff --git a/src/cc/mallet/topics/tree/testNaive.java b/src/cc/mallet/topics/tree/testNaive.java new file mode 100755 index 000000000..84143b62c --- /dev/null +++ b/src/cc/mallet/topics/tree/testNaive.java @@ -0,0 +1,157 @@ +package cc.mallet.topics.tree; + +import gnu.trove.TIntArrayList; +import gnu.trove.TIntIntHashMap; + +import java.io.File; +import java.util.ArrayList; + +import cc.mallet.topics.tree.TreeTopicSamplerHashD.DocData; +import cc.mallet.types.Alphabet; +import cc.mallet.types.FeatureSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.InstanceList; +import junit.framework.TestCase; + +/** + * This class tests the naive sampler. + * @author Yuening Hu + */ + +public class testNaive extends TestCase{ + + public static TreeTopicSamplerNaive Initialize() { + + String inputFile = "input/toy/toy-topic-input.mallet"; + String treeFiles = "input/toy/toy.wn.*"; + String hyperFile = "input/toy/tree_hyperparams"; + String vocabFile = "input/toy/toy.voc"; + String removedFile = "input/toy/removed"; + int numTopics = 3; + double alpha_sum = 0.3; + int randomSeed = 0; + int numIterations = 10; + +// String inputFile = "../input/synthetic-topic-input.mallet"; +// String treeFiles = "../synthetic/synthetic.wn.*"; +// String hyperFile = "../synthetic/tree_hyperparams"; +// String vocabFile = "../synthetic/synthetic.voc"; +// int numTopics = 5; +// double alpha_sum = 0.5; +// int randomSeed = 0; +// int numIterations = 10; + + InstanceList[] instances = new InstanceList[1]; + InstanceList ilist = InstanceList.load (new File(inputFile)); + System.out.println ("Data loaded."); + instances[0] = ilist; + + TreeTopicSamplerNaive topicModel = null; + topicModel = new TreeTopicSamplerNaive(numTopics, alpha_sum, randomSeed); + + topicModel.initialize(treeFiles, hyperFile, vocabFile, removedFile); + topicModel.addInstances(instances); + + topicModel.setNumIterations(numIterations); + + return topicModel; + } + + public void testChangeTopic() { + TreeTopicSamplerNaive topicModel = this.Initialize(); + for (int dd = 0; dd < topicModel.data.size(); dd++ ) { + DocData doc = topicModel.data.get(dd); + for (int index = 0; index < doc.tokens.size(); index++) { + int word = doc.tokens.get(index); + int old_topic = doc.topics.get(index); + int old_path = doc.paths.get(index); + int old_count = doc.topicCounts.get(old_topic); + + topicModel.changeTopic(dd, index, word, -1, -1); + assertTrue(doc.topics.get(index) == -1); + assertTrue(doc.paths.get(index) == -1); + assertTrue(doc.topicCounts.get(old_topic) == old_count-1); + + int new_topic = topicModel.numTopics - old_topic - 1; + int new_path = old_path; + int new_count = doc.topicCounts.get(new_topic); + topicModel.changeTopic(dd, index, word, new_topic, new_path); + + assertTrue(doc.topics.get(index) == new_topic); + assertTrue(doc.paths.get(index) == new_path); + assertTrue(doc.topicCounts.get(new_topic) == new_count+1); + } + } + } + + public void testChangCount() { + TreeTopicSamplerNaive topicModel = this.Initialize(); + for (int dd = 0; dd < topicModel.data.size(); dd++ ) { + DocData doc = topicModel.data.get(dd); + + for (int index = 0; index < doc.tokens.size(); index++) { + int word = doc.tokens.get(index); + int old_topic = doc.topics.get(index); + int old_path = doc.paths.get(index); + + TopicTreeWalk tw = topicModel.topics.traversals.get(old_topic); 
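+ // Comment added for clarity: the test records the edge and node counts along the token's path,
+ // applies changeCount with an increment, and then verifies that every count grew by exactly inc.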
+ TIntArrayList path_nodes = topicModel.topics.wordPaths.get(word, old_path); + + int[] old_count = new int[path_nodes.size() - 1]; + for(int nn = 0; nn < path_nodes.size() - 1; nn++) { + int parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + old_count[nn] = tw.getCount(parent, child); + } + + int[] old_node_count = new int[path_nodes.size()]; + for(int nn = 0; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + old_node_count[nn] = tw.getNodeCount(node); + } + + int inc = 1; + tw.changeCount(path_nodes, inc); + + for(int nn = 0; nn < path_nodes.size() - 1; nn++) { + int parent = path_nodes.get(nn); + int child = path_nodes.get(nn+1); + assertTrue(old_count[nn] == tw.getCount(parent, child) - inc); + } + + for(int nn = 0; nn < path_nodes.size(); nn++) { + int node = path_nodes.get(nn); + assertTrue(old_node_count[nn] == tw.getNodeCount(node) - inc); + } + + } + } + + } + + public void testComputeTermScore() { + TreeTopicSamplerNaive topicModel = this.Initialize(); + for (int dd = 0; dd < topicModel.data.size(); dd++ ) { + DocData doc = topicModel.data.get(dd); + System.out.println("------------" + dd + "------------"); + for (int index = 0; index < doc.tokens.size(); index++) { + int word = doc.tokens.get(index); + + //topicModel.changeTopic(dd, index, word, -1, -1); + + ArrayList topic_term_score = new ArrayList(); + double norm = topicModel.topics.computeTopicTerm(topicModel.alpha, doc.topicCounts, word, topic_term_score); + System.out.println(norm); + + for(int jj = 0; jj < topic_term_score.size(); jj++) { + double[] tmp = topic_term_score.get(jj); + int tt = (int) tmp[0]; + int pp = (int) tmp[1]; + double val = tmp[2]; + System.out.println(tt + " " + pp + " " + val); + } + } + } + } + +} diff --git a/src/cc/mallet/topics/tui/EvaluateTreeTopics.java b/src/cc/mallet/topics/tui/EvaluateTreeTopics.java new file mode 100644 index 000000000..ce17fdf5b --- /dev/null +++ b/src/cc/mallet/topics/tui/EvaluateTreeTopics.java @@ -0,0 +1,96 @@ +package cc.mallet.topics.tui; + +import java.io.File; +import java.io.PrintStream; + +import cc.mallet.topics.MarginalProbEstimator; +import cc.mallet.topics.tree.TreeMarginalProbEstimator; +import cc.mallet.topics.tree.TreeTopicInferencer; +import cc.mallet.types.InstanceList; +import cc.mallet.util.CommandOption; + +public class EvaluateTreeTopics { + // common options in mallet + + static CommandOption.String inputFile = new CommandOption.String + (EvaluateTreeTopics.class, "input", "FILENAME", true, null, + "The filename from which to read the list of testing instances. Use - for stdin. " + + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); + + static CommandOption.Integer randomSeed = new CommandOption.Integer + (EvaluateTreeTopics.class, "random-seed", "INTEGER", true, 0, + "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null); + + static CommandOption.String evaluatorFilename = new CommandOption.String + (EvaluateTreeTopics.class, "evaluator", "FILENAME", true, null, + "A serialized topic evaluator from a trained topic model.\n" + + "By default this is null, indicating that no file will be read.", null); + + static CommandOption.String docProbabilityFile = new CommandOption.String + (EvaluateTreeTopics.class, "output-doc-probs", "FILENAME", true, null, + "The filename in which to write the inferred log probabilities\n" + + "per document. 
" + + "By default this is null, indicating that no file will be written.", null); + + static CommandOption.String probabilityFile = new CommandOption.String + (EvaluateTreeTopics.class, "output-prob", "FILENAME", true, "-", + "The filename in which to write the inferred log probability of the testing set\n" + + "Use - for stdout, which is the default.", null); + + static CommandOption.Integer numParticles = new CommandOption.Integer + (EvaluateTreeTopics.class, "num-particles", "INTEGER", true, 10, + "The number of particles to use in left-to-right evaluation.", null); + + static CommandOption.Boolean usingResampling = new CommandOption.Boolean + (EvaluateTreeTopics.class, "use-resampling", "TRUE|FALSE", false, false, + "Whether to resample topics in left-to-right evaluation. Resampling is more accurate, but leads to quadratic scaling in the lenght of documents.", null); + + + public static void main (String[] args) throws java.io.IOException { + // Process the command-line options + CommandOption.setSummary (EvaluateTreeTopics.class, + "Estimate the marginal probability of new documents."); + CommandOption.process (EvaluateTreeTopics.class, args); + + if (evaluatorFilename.value == null) { + System.err.println("You must specify a serialized topic evaluator. Use --help to list options."); + System.exit(0); + } + + if (inputFile.value == null) { + System.err.println("You must specify a serialized instance list. Use --help to list options."); + System.exit(0); + } + + try { + + PrintStream docProbabilityStream = null; + if (docProbabilityFile.value != null) { + docProbabilityStream = new PrintStream(docProbabilityFile.value); + } + + PrintStream outputStream = System.out; + if (probabilityFile.value != null && + ! probabilityFile.value.equals("-")) { + outputStream = new PrintStream(probabilityFile.value); + } + + TreeMarginalProbEstimator evaluator = + TreeMarginalProbEstimator.read(new File(evaluatorFilename.value)); + + InstanceList instances = InstanceList.load (new File(inputFile.value)); + + evaluator.setRandomSeed(randomSeed.value); + + outputStream.println(evaluator.evaluateLeftToRight(instances, numParticles.value, + usingResampling.value, + docProbabilityStream)); + + + } catch (Exception e) { + e.printStackTrace(); + System.err.println(e.getMessage()); + } + + } +} diff --git a/src/cc/mallet/topics/tui/GenerateTree.java b/src/cc/mallet/topics/tui/GenerateTree.java new file mode 100644 index 000000000..d87df11e0 --- /dev/null +++ b/src/cc/mallet/topics/tui/GenerateTree.java @@ -0,0 +1,42 @@ +package cc.mallet.topics.tui; + +import java.io.File; +import cc.mallet.topics.tree.OntologyWriter; +import cc.mallet.util.CommandOption; + +public class GenerateTree { + + static CommandOption.String vocabFile = new CommandOption.String + (GenerateTree.class, "vocab", "FILENAME", true, null, + "The vocabulary file.", null); + + static CommandOption.String treeFiles = new CommandOption.String + (GenerateTree.class, "tree", "FILENAME", true, null, + "The files for tree structure.", null); + + static CommandOption.String consFile = new CommandOption.String + (GenerateTree.class, "constraint", "FILENAME", true, null, + "The constraint file.", null); + + static CommandOption.Boolean mergeCons = new CommandOption.Boolean + (GenerateTree.class, "merge-constraints", "true|false", false, true, + "Merge constraints or not. 
For example, if you want to merge A and B, " + + "and merge B and C and set merge-constraints as true, the new constraint" + + "will be merge A, B and C.", null); + + public static void main (String[] args) throws java.io.IOException { + // Process the command-line options + CommandOption.setSummary (GenerateTree.class, + "Generate a prior tree structure for LDA, in proto buffer format."); + CommandOption.process (GenerateTree.class, args); + + try { + OntologyWriter.createOntology(consFile.value, vocabFile.value, + treeFiles.value, mergeCons.value); + } catch (Exception e) { + e.printStackTrace(); + } + + } + +} diff --git a/src/cc/mallet/topics/tui/GenerateVocab.java b/src/cc/mallet/topics/tui/GenerateVocab.java new file mode 100644 index 000000000..1ab186349 --- /dev/null +++ b/src/cc/mallet/topics/tui/GenerateVocab.java @@ -0,0 +1,58 @@ +package cc.mallet.topics.tui; + +import java.io.File; + +import cc.mallet.topics.tree.VocabGenerator; +import cc.mallet.types.InstanceList; +import cc.mallet.util.CommandOption; + +public class GenerateVocab { + + // common options in mallet + static CommandOption.SpacedStrings inputFile = new CommandOption.SpacedStrings + (GenerateVocab.class, "input", "FILENAME [FILENAME ...]", true, null, + "The filename from which to read the list of training instances. " + + "Support multiple languages, each language should have its own file. " + + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); + + static CommandOption.String vocabFile = new CommandOption.String + (GenerateVocab.class, "vocab", "FILENAME", true, null, + "The vocabulary file.", null); + + static CommandOption.Boolean tfidfRank = new CommandOption.Boolean + (GenerateVocab.class, "tfidf-rank", "true|false", false, true, + "Rank vocab by the averaged tfidf of words, or by frequency.", null); + + static CommandOption.Double tfidfThresh = new CommandOption.Double + (GenerateVocab.class, "tfidf-thresh", "DECIMAL", true, 1.0, + "The thresh for tfidf to filter out words.",null); + + static CommandOption.Double freqThresh = new CommandOption.Double + (GenerateVocab.class, "freq-thresh", "DECIMAL", true, 1.0, + "The thresh for frequency to filter out words.",null); + + static CommandOption.Double wordLength = new CommandOption.Double + (GenerateVocab.class, "word-length", "DECIMAL", true, 3.0, + "Keep words with length equal or large than the thresh.",null); + + + public static void main (String[] args) throws java.io.IOException { + // Process the command-line options + CommandOption.setSummary (GenerateVocab.class, + "Filtering words by tfidf, frequency, word-length, and generate the vocab."); + CommandOption.process (GenerateVocab.class, args); + + int numLanguages = inputFile.value.length; + InstanceList[] instances = new InstanceList[ numLanguages ]; + for (int i=0; i < instances.length; i++) { + instances[i] = InstanceList.load(new File(inputFile.value[i])); + System.out.println ("Data " + i + " loaded. 
Total number of documents: " + instances[i].size()); + } + + + VocabGenerator.genVocab(instances, vocabFile.value, tfidfRank.value, tfidfThresh.value, + freqThresh.value, wordLength.value); + + } + +} diff --git a/src/cc/mallet/topics/tui/InferTreeTopics.java b/src/cc/mallet/topics/tui/InferTreeTopics.java new file mode 100755 index 000000000..ea0efa4bc --- /dev/null +++ b/src/cc/mallet/topics/tui/InferTreeTopics.java @@ -0,0 +1,91 @@ +package cc.mallet.topics.tui; + +import java.io.File; + +import cc.mallet.topics.TopicInferencer; +import cc.mallet.topics.tree.VocabGenerator; +import cc.mallet.topics.tree.OntologyWriter; +import cc.mallet.topics.tree.TreeTopicInferencer; +import cc.mallet.topics.tree.TreeTopicSamplerFast; +import cc.mallet.topics.tree.TreeTopicSamplerFastEst; +import cc.mallet.topics.tree.TreeTopicSamplerFastEstSortD; +import cc.mallet.topics.tree.TreeTopicSamplerFastSortD; +import cc.mallet.topics.tree.TreeTopicSamplerInterface; +import cc.mallet.topics.tree.TreeTopicSamplerNaive; +import cc.mallet.types.InstanceList; +import cc.mallet.util.CommandOption; + +public class InferTreeTopics { + + // common options in mallet + + static CommandOption.String inputFile = new CommandOption.String + (InferTreeTopics.class, "input", "FILENAME", true, null, + "The filename from which to read the list of testing instances. Use - for stdin. " + + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); + + static CommandOption.Integer numIterations = new CommandOption.Integer + (InferTreeTopics.class, "num-iterations", "INTEGER", true, 1000, + "The number of iterations of Gibbs sampling.", null); + + static CommandOption.String inferencerFilename = new CommandOption.String + (InferTreeTopics.class, "inferencer", "FILENAME", true, null, + "A topic inferencer applies a previously trained topic model to new documents. " + + "This option is required.", null); + + static CommandOption.String docTopicsFile = new CommandOption.String + (InferTreeTopics.class, "output-doc-topics", "FILENAME", true, null, + "The filename in which to write the inferred topic\n" + + "proportions per document. " + + "This option is required.", null); + + static CommandOption.Integer randomSeed = new CommandOption.Integer + (InferTreeTopics.class, "random-seed", "INTEGER", true, 0, + "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null); + + static CommandOption.Integer outputInteval = new CommandOption.Integer + (InferTreeTopics.class, "output-interval", "INTEGER", true, 20, + "For each interval, the result files are output to the outputFolder.", null); + + + public static void main (String[] args) throws java.io.IOException { + // Process the command-line options + CommandOption.setSummary (InferTreeTopics.class, + "Use a previously trained tree-based topic model to infer topic proportions for new documents."); + CommandOption.process (InferTreeTopics.class, args); + + if (inferencerFilename.value == null) { + System.err.println("You must specify a serialized topic inferencer. Use --help to list options."); + System.exit(0); + } + + if (inputFile.value == null) { + System.err.println("You must specify a serialized instance list. 
Use --help to list options."); + System.exit(0); + } + + if (docTopicsFile.value == null) { + System.err.println("You must specify an output file for the inferred topic proportions. Use --help to list options."); + System.exit(0); + } + + try { + InstanceList testlist = InstanceList.load (new File(inputFile.value)); + System.out.println ("Test data loaded."); + + TreeTopicInferencer inferencer = TreeTopicInferencer.read(new File(inferencerFilename.value)); + System.out.println("Inferencer loaded."); + + inferencer.setRandomSeed(randomSeed.value); + + inferencer.writeInferredDistributions(testlist, new File(docTopicsFile.value), + numIterations.value, outputInteval.value); + + } catch (Exception e) { + e.printStackTrace(); + System.err.println(e.getMessage()); + } + + } + +} diff --git a/src/cc/mallet/topics/tui/Vectors2TreeTopics.java b/src/cc/mallet/topics/tui/Vectors2TreeTopics.java new file mode 100755 index 000000000..a3a5d6764 --- /dev/null +++ b/src/cc/mallet/topics/tui/Vectors2TreeTopics.java @@ -0,0 +1,260 @@ +/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept. + This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). + http://www.cs.umass.edu/~mccallum/mallet + This software is provided under the terms of the Common Public License, + version 1.0, as published by http://www.opensource.org. For further + information, see the file `LICENSE' included with this distribution. */ + +package cc.mallet.topics.tui; + +import cc.mallet.util.CommandOption; +import cc.mallet.types.InstanceList; +import cc.mallet.topics.tree.TreeMarginalProbEstimator; +import cc.mallet.topics.tree.TreeTopicInferencer; +import cc.mallet.topics.tree.TreeTopicSamplerInterface; +import cc.mallet.topics.tree.TreeTopicSamplerFast; +import cc.mallet.topics.tree.TreeTopicSamplerFastEst; +import cc.mallet.topics.tree.TreeTopicSamplerFastEstSortD; +import cc.mallet.topics.tree.TreeTopicSamplerFastSortD; +import cc.mallet.topics.tree.TreeTopicSamplerNaive; + +import java.io.*; + +/** Perform topic analysis in the style of LDA and its variants. + * @author Andrew McCallum + */ + +public class Vectors2TreeTopics { + + // common options in mallet + static CommandOption.SpacedStrings inputFile = new CommandOption.SpacedStrings + (Vectors2TreeTopics.class, "input", "FILENAME [FILENAME ...]", true, null, + "The filename from which to read the list of training instances. " + + "Supports multiple languages; each language should have its own file. " + + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); + + static CommandOption.Integer numTopics = new CommandOption.Integer + (Vectors2TreeTopics.class, "num-topics", "INTEGER", true, 10, + "The number of topics to fit.", null); + + static CommandOption.Integer numIterations = new CommandOption.Integer + (Vectors2TreeTopics.class, "num-iterations", "INTEGER", true, 1000, + "The number of iterations of Gibbs sampling.", null); + + static CommandOption.Integer randomSeed = new CommandOption.Integer + (Vectors2TreeTopics.class, "random-seed", "INTEGER", true, 0, + "The random seed for the Gibbs sampler. 
Default is 0, which will use the clock.", null); + + static CommandOption.Integer topWords = new CommandOption.Integer + (Vectors2TreeTopics.class, "num-top-words", "INTEGER", true, 20, + "The number of most probable words to print for each topic after model estimation.", null); + + static CommandOption.Double alpha = new CommandOption.Double + (Vectors2TreeTopics.class, "alpha", "DECIMAL", true, 50.0, + "Alpha parameter: smoothing over topic distribution.",null); + + static CommandOption.String inferencerFilename = new CommandOption.String + (Vectors2TreeTopics.class, "inferencer-filename", "FILENAME", true, null, + "A topic inferencer applies a previously trained topic model to new documents. " + + "By default this is null, indicating that no file will be written.", null); + + static CommandOption.String evaluatorFilename = new CommandOption.String + (Vectors2TreeTopics.class, "evaluator-filename", "FILENAME", true, null, + "A held-out likelihood evaluator for new documents. " + + "By default this is null, indicating that no file will be written.", null); + + //////////////////////////////////// + // new options + + static CommandOption.Integer outputInteval = new CommandOption.Integer + (Vectors2TreeTopics.class, "output-interval", "INTEGER", true, 20, + "For each interval, the result files are output to the outputFolder.", null); + + static CommandOption.String outputDir= new CommandOption.String + (Vectors2TreeTopics.class, "output-dir", "FOLDERNAME", true, null, + "The output folder.", null); + + static CommandOption.String vocabFile = new CommandOption.String + (Vectors2TreeTopics.class, "vocab", "FILENAME", true, null, + "The vocabulary file.", null); + + static CommandOption.String treeFiles = new CommandOption.String + (Vectors2TreeTopics.class, "tree", "FILENAME", true, null, + "The files for tree structure.", null); + + static CommandOption.String hyperFile = new CommandOption.String + (Vectors2TreeTopics.class, "tree-hyperparameters", "FILENAME", true, null, + "The hyperparameters for tree structure.", null); + + static CommandOption.Boolean resume = new CommandOption.Boolean + (Vectors2TreeTopics.class, "resume", "true|false", false, false, + "Resume from the previous output states.", null); + + static CommandOption.String resumeDir = new CommandOption.String + (Vectors2TreeTopics.class, "resume-dir", "FOLDERNAME", true, null, + "The resume folder.", null); + + static CommandOption.String consFile = new CommandOption.String + (Vectors2TreeTopics.class, "constraint", "FILENAME", true, null, + "The file that contains the constrained words.", null); + + static CommandOption.String forgetTopics = new CommandOption.String + (Vectors2TreeTopics.class, "forget-topics", "TYPENAME", true, null, + "Three options: term, doc, null. " + + "Forget the previously sampled topic assignments of the constrained words only (term), " + + "of the documents containing constrained words (doc), " + + "or do not forget at all (keep everything). " + + "This option is for adding interaction.", null); + + static CommandOption.String removedFile = new CommandOption.String + (Vectors2TreeTopics.class, "remove-words", "FILENAME", true, null, + "The file that contains the words you want to be ignored in topic modeling. 
" + + "You need to have removed.all file, which is the removed words before this round of interaction," + + "and a removed.new file, which is the removed words that users just defined in this round of interaction" + + "This option is for adding interaction.", null); + + static CommandOption.String keepFile = new CommandOption.String + (Vectors2TreeTopics.class, "keep", "FILENAME", true, null, + "The topic assignments of words on this list will be kept instead of cleared," + + "even though it is on the list of constrained words." + + "This option is for adding interaction.", null); + + static CommandOption.String modelType = new CommandOption.String + (Vectors2TreeTopics.class, "tree-model-type", "TYPENAME", true, "fast-est", + "Possible types: naive, fast, fast-est, fast-sortD, fast-sortW, fast-sortD-sortW, " + + "fast-est-sortD, fast-est-sortW, fast-est-sortD-sortW.", null); + + public static void main (String[] args) throws java.io.IOException { + // Process the command-line options + CommandOption.setSummary (Vectors2TreeTopics.class, + "A tool for estimating, saving and printing diagnostics for topic models, such as LDA."); + CommandOption.process (Vectors2TreeTopics.class, args); + + int numLanguages = inputFile.value.length; + InstanceList[] instances = new InstanceList[ numLanguages ]; + for (int i=0; i < instances.length; i++) { + instances[i] = InstanceList.load(new File(inputFile.value[i])); + System.out.println ("Data " + i + " loaded. Total number of documents: " + instances[i].size()); + } + + TreeTopicSamplerInterface topicModel = null; + + // notice there are more inference methods available in this pacakge: + // naive, fast, fast-est, fast-sortD, fast-sortW, + // fast-sortD-sortW, fast-est-sortD, fast-est-sortW, fast-est-sortD-sortW + // by default, we set it as fast-est-sortD-sortW + // but you can change the modelType to any of them by exploring the source code + // also notice the inferencer and evaluator only support fast-est, fast-sortD-sortW, + // fast-est-sortD, fast-est-sortW, fast-est-sortD-sortW + boolean sortW = false; + String modeltype = "fast-est"; + //System.out.println("model type:" + modeltype); + modeltype = modelType.value; + + if (modeltype.equals("naive")) { + topicModel = new TreeTopicSamplerNaive( + numTopics.value, alpha.value, randomSeed.value); + } else if (modeltype.equals("fast")){ + topicModel = new TreeTopicSamplerFast( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-sortD")){ + topicModel = new TreeTopicSamplerFastSortD( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-sortW")){ + sortW = true; + topicModel = new TreeTopicSamplerFast( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-sortD-sortW")){ + sortW = true; + topicModel = new TreeTopicSamplerFastSortD( + numTopics.value, alpha.value, randomSeed.value, sortW); + + } else if (modeltype.equals("fast-est")) { + topicModel = new TreeTopicSamplerFastEst( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-est-sortD")) { + topicModel = new TreeTopicSamplerFastEstSortD( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-est-sortW")) { + sortW = true; + topicModel = new TreeTopicSamplerFastEst( + numTopics.value, alpha.value, randomSeed.value, sortW); + } else if (modeltype.equals("fast-est-sortD-sortW")) { + sortW = true; + topicModel = new TreeTopicSamplerFastEstSortD( + 
numTopics.value, alpha.value, randomSeed.value, sortW); + //} else if (modeltype.equals("fast-est-try")) { + // topicModel = new TreeTopicSamplerFastEstTry( + // numTopics.value, alpha.value, randomSeed.value, sortW); + } else { + System.out.println("Unrecognized model type! Please use " + + "'naive', 'fast', 'fast-est', " + + "'fast-sortD', 'fast-sortW', 'fast-sortD-sortW', " + + "'fast-est-sortD', 'fast-est-sortW', 'fast-est-sortD-sortW'!"); + System.exit(0); + } + + // load tree and vocab + topicModel.initialize(treeFiles.value, hyperFile.value, vocabFile.value, removedFile.value); + topicModel.setNumIterations(numIterations.value); + System.out.println("Prior tree loaded!"); + + if (resume.value == true) { + // resume instances from the saved states + topicModel.resume(instances, resumeDir.value); + } else { + // add instances + topicModel.addInstances(instances); + } + System.out.println("Model initialized!"); + + // if forgetTopics is not null, clear the topic assignments of the + // constrained words + if (forgetTopics.value != null) { + if (forgetTopics.value.equals("term") || forgetTopics.value.equals("doc")) { + topicModel.clearTopicAssignments(forgetTopics.value, consFile.value, keepFile.value); + } else { + System.out.println("Unrecognized forget-topics type! Please use either 'doc' or 'term'!"); + System.exit(0); + } + } + + // sampling and save states + topicModel.estimate(numIterations.value, outputDir.value, + outputInteval.value, topWords.value); + + // topic report + //System.out.println(topicModel.displayTopWords(topWords.value)); + + if (inferencerFilename.value != null) { + try { + ObjectOutputStream oos = + new ObjectOutputStream(new FileOutputStream(inferencerFilename.value)); + TreeTopicInferencer infer = topicModel.getInferencer(); + infer.setModelType(modeltype); + oos.writeObject(infer); + oos.close(); + } catch (Exception e) { + System.err.println(e.getMessage()); + } + + } + + if (evaluatorFilename.value != null) { + try { + ObjectOutputStream oos = + new ObjectOutputStream(new FileOutputStream(evaluatorFilename.value)); + TreeMarginalProbEstimator estimator = topicModel.getProbEstimator(); + estimator.setModelType(modeltype); + oos.writeObject(estimator); + oos.close(); + + } catch (Exception e) { + System.err.println(e.getMessage()); + } + + } + + } + +}
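Usage sketch (illustrative, not part of the patch): the new tui classes above form a pipeline, and the commands below show one way to drive it. The class names and option flags are taken from this patch; the classpath variable, file paths, and hyperparameter file name are assumptions, and EvaluateTreeTopics is omitted because its input flags are declared outside this excerpt.

  # MALLET_CP is assumed to contain the compiled classes and the bundled jars.
  # 1. Rank and filter the vocabulary from an imported instance list.
  java -cp "$MALLET_CP" cc.mallet.topics.tui.GenerateVocab \
      --input input/synthetic/synthetic-topic-input.mallet --vocab synthetic.voc \
      --tfidf-rank true --tfidf-thresh 1.0 --freq-thresh 1.0 --word-length 3.0
  # 2. Build the prior tree (proto buffer format) from the vocab and a constraint file.
  java -cp "$MALLET_CP" cc.mallet.topics.tui.GenerateTree \
      --vocab synthetic.voc --constraint constraints.txt --tree synthetic.wn \
      --merge-constraints true
  # 3. Train the tree-based sampler and serialize an inferencer and an evaluator.
  java -cp "$MALLET_CP" cc.mallet.topics.tui.Vectors2TreeTopics \
      --input input/synthetic/synthetic-topic-input.mallet --vocab synthetic.voc \
      --tree synthetic.wn --tree-hyperparameters tree.hyper \
      --num-topics 10 --num-iterations 1000 --output-interval 20 --output-dir output \
      --tree-model-type fast-est \
      --inferencer-filename output/inferencer.ser --evaluator-filename output/evaluator.ser
  # 4. Infer topic proportions for held-out documents with the saved inferencer.
  java -cp "$MALLET_CP" cc.mallet.topics.tui.InferTreeTopics \
      --input input/synthetic/synthetic-test-input.mallet --inferencer output/inferencer.ser \
      --output-doc-topics output/doc-topics.txt --num-iterations 100 --output-interval 20

Here --tree appears to name the tree file that GenerateTree writes and Vectors2TreeTopics then reads, and --tree-model-type may be any of the sampler variants listed in Vectors2TreeTopics, keeping in mind the source comment that the serialized inferencer and evaluator support only a subset of those variants.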