diff --git a/example/hurricane2.blog b/example/hurricane2.blog
new file mode 100644
index 00000000..6b5cdd9b
--- /dev/null
+++ b/example/hurricane2.blog
@@ -0,0 +1,32 @@
+/**
+ * Hurricane
+ * Figure 4.2 in Milch's thesis
+ */
+
+type City;
+type PrepLevel;
+type DamageLevel;
+
+random City First ~ UniformChoice({c for City c});
+
+random City NotFirst ~ UniformChoice({c for City c: c != First});
+
+random PrepLevel Prep(City c) ~
+ if (First == c) then Categorical({High -> 0.5, Low -> 0.5})
+ else case Damage(First) in
+ {Severe -> Categorical({High -> 0.9, Low -> 0.1}),
+ Mild -> Categorical({High -> 0.1, Low -> 0.9})}
+ ;
+
+random DamageLevel Damage(City c) ~
+ case Prep(c) in {High -> Categorical({Severe -> 0.2, Mild -> 0.8}),
+ Low -> Categorical({Severe -> 0.8, Mild -> 0.2})}
+ ;
+
+distinct City A, B;
+distinct PrepLevel Low, High;
+distinct DamageLevel Severe, Mild;
+
+obs Damage(First) = Severe;
+
+query First;
diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java
index 450bb2a5..1f002759 100644
--- a/src/main/java/blog/sample/GenericProposer.java
+++ b/src/main/java/blog/sample/GenericProposer.java
@@ -37,11 +37,12 @@
import java.util.HashSet;
import java.util.Iterator;
-import java.util.LinkedList;
import java.util.Properties;
import java.util.Set;
+import blog.bn.BasicVar;
import blog.bn.BayesNetVar;
+import blog.bn.DerivedVar;
import blog.bn.VarWithDistrib;
import blog.common.Util;
import blog.distrib.CondProbDistrib;
@@ -147,46 +148,91 @@ public double proposeNextState(PartialWorldDiff world) {
System.out.println(world);
}
- // Remove barren variables
- LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars());
- while (!newlyBarren.isEmpty()) {
- BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst();
- if (!evidenceVars.contains(var) && !queryVars.contains(var)) {
-
- // Remember its parents.
- Set parentSet = world.getCBN().getParents(var);
-
- if (var instanceof VarWithDistrib) {
- // Multiply in the probability of sampling this
- // variable again. Since the parent value may have
- // changed, must use the old world.
- logProbBackward += world.getSaved().getLogProbOfValue(var);
-
- // Uninstantiate
- world.setValue((VarWithDistrib) var, null);
+ // Remove unnecessary variables
+
+ Set evidenceAndQueries = new HashSet();
+ for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) {
+ BayesNetVar var = (BayesNetVar) iter.next();
+ if (var instanceof BasicVar) {
+ evidenceAndQueries.add(var);
+ } else if (var instanceof DerivedVar) {
+ TraceParentRecEvalContext context = new TraceParentRecEvalContext(
+ new PartialWorldDiff(world));
+ var.ensureDetAndSupported(context);
+ if (context.getDependentVar() != null) {
+ evidenceAndQueries.addAll(context.getDependentVar());
}
-
- // Check to see if its parents are now barren.
- for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) {
-
- // If parent is barren, add to the end of this
- // linked list. Note that if a parent has two
- // barren children, it will only be added to the
- // end of the list once, when the last child is
- // considered.
- BayesNetVar parent = (BayesNetVar) parentIter.next();
- if (world.getCBN().getChildren(parent).isEmpty())
- newlyBarren.addLast(parent);
+ }
+ }
+ for (Iterator iter = queryVars.iterator(); iter.hasNext();) {
+ BayesNetVar var = (BayesNetVar) iter.next();
+ if (var instanceof BasicVar) {
+ evidenceAndQueries.add(var);
+ } else if (var instanceof DerivedVar) {
+ TraceParentRecEvalContext context = new TraceParentRecEvalContext(
+ new PartialWorldDiff(world));
+ var.ensureDetAndSupported(context);
+ if (context.getDependentVar() != null) {
+ evidenceAndQueries.addAll(context.getDependentVar());
}
}
}
-
- // Uniform sampling from new world.
+ boolean OK;
+ do {
+ OK = true;
+ for (Iterator iter = world.getInstantiatedVars().iterator(); iter
+ .hasNext();) {
+ BasicVar var = (BasicVar) iter.next();
+ Object value = world.getValue(var);
+ world.setValue(var, null);
+ if (!evidenceAndQueriesSupported(evidenceAndQueries, world)) {
+ world.setValue(var, value);
+ } else {
+ OK = false;
+ logProbBackward += world.getSaved().getLogProbOfValue(var);
+ }
+ }
+ } while (!OK);
logProbBackward += (-Math.log(world.getInstantiatedVars().size()
- numBasicEvidenceVars));
return (logProbBackward - logProbForward);
}
+ protected boolean evidenceAndQueriesSupported(Set evidenceAndQueries,
+ PartialWorld world) {
+ PartialWorldDiff tmpWorld = new PartialWorldDiff(world);
+ for (Iterator iter = evidenceAndQueries.iterator(); iter.hasNext();) {
+ BasicVar var = (BasicVar) iter.next();
+ if (!tmpWorld.isInstantiated(var)) {
+ return false;
+ }
+ TraceParentRecEvalContext context = new TraceParentRecEvalContext(
+ tmpWorld);
+ if (var instanceof VarWithDistrib) {
+ ((VarWithDistrib) var).getDistrib(context);
+ if (context.getNumCalculateNewVars() > 0) {
+ return false;
+ }
+ }
+ }
+ for (Iterator iter = tmpWorld.getInstantiatedVars().iterator(); iter
+ .hasNext();) {
+ BasicVar var = (BasicVar) iter.next();
+ if (!tmpWorld.isInstantiated(var)) {
+ return false;
+ }
+ TraceParentRecEvalContext context = new TraceParentRecEvalContext(
+ tmpWorld);
+ if (var instanceof VarWithDistrib) {
+ ((VarWithDistrib) var).getDistrib(context);
+ if (context.getNumCalculateNewVars() > 0) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
// Samples a new value for the given variable (which must be
// supported in world
) and sets this new value as the
// value of the variable in world
. Then ensures that
@@ -195,7 +241,10 @@ public double proposeNextState(PartialWorldDiff world) {
// updates the logProbForward and logProbBackward variables.
protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) {
// Save child set before graph becomes out of date
- Set children = world.getCBN().getChildren(varToSample);
+ Set children = new HashSet();
+ children.addAll(world.getCBN().getChildren(varToSample));
+ children.addAll(evidenceVars);
+ children.addAll(queryVars);
DependencyModel.Distrib distrib = varToSample
.getDistrib(new DefaultEvalContext(world, true));
@@ -217,9 +266,10 @@ protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) {
for (Iterator childrenIter = children.iterator(); childrenIter.hasNext();) {
BayesNetVar child = (BayesNetVar) childrenIter.next();
- if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT THING
- // TO DO! CHECKING WITH BRIAN.
- continue;
+ // if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT
+ // THING
+ // TO DO! CHECKING WITH BRIAN.
+ // continue;
child.ensureDetAndSupported(instantiator);
}
@@ -251,6 +301,5 @@ public double latestLogProbBackward() {
public double latestLogProbForward() {
return logProbForward;
}
-
// End of debugger-only members.
}
diff --git a/src/main/java/blog/sample/ParentRecEvalContext.java b/src/main/java/blog/sample/ParentRecEvalContext.java
index fc2fdd21..f4285a1e 100644
--- a/src/main/java/blog/sample/ParentRecEvalContext.java
+++ b/src/main/java/blog/sample/ParentRecEvalContext.java
@@ -35,7 +35,9 @@
package blog.sample;
-import java.util.*;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.Set;
import blog.ObjectIdentifier;
import blog.bn.BasicVar;
@@ -53,80 +55,80 @@
* {@link #getOrComputeValue(BasicVar)} instead.
*/
public class ParentRecEvalContext extends DefaultEvalContext {
- /**
- * Creates a new ParentRecEvalContext using the given world.
- */
- public ParentRecEvalContext(PartialWorld world) {
- super(world);
- }
+ /**
+ * Creates a new ParentRecEvalContext using the given world.
+ */
+ public ParentRecEvalContext(PartialWorld world) {
+ super(world);
+ }
- /**
- * Creates a new ParentRecEvalContext using the given world. If the
- * errorIfUndet
flag is true, the access methods on this instance
- * will print error messages and exit the program if the world is not complete
- * enough to determine the correct return value. Otherwise they will just
- * return null in such cases.
- */
- public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) {
- super(world, errorIfUndet);
- }
+ /**
+ * Creates a new ParentRecEvalContext using the given world. If the
+ * errorIfUndet
flag is true, the access methods on this instance
+ * will print error messages and exit the program if the world is not complete
+ * enough to determine the correct return value. Otherwise they will just
+ * return null in such cases.
+ */
+ public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) {
+ super(world, errorIfUndet);
+ }
- final public Object getValue(BasicVar var) {
- Object value = getOrComputeValue(var);
- if (value == null) {
- latestUninstParent = var;
- var.ensureStable();
- handleMissingVar(var);
- } else {
- if (parents.add(var)) {
- var.ensureStable();
- }
- }
- return value;
- }
+ public Object getValue(BasicVar var) {
+ Object value = getOrComputeValue(var);
+ if (value == null) {
+ latestUninstParent = var;
+ var.ensureStable();
+ handleMissingVar(var);
+ } else {
+ if (parents.add(var)) {
+ var.ensureStable();
+ }
+ }
+ return value;
+ }
- protected Object getOrComputeValue(BasicVar var) {
- return world.getValue(var);
- }
+ protected Object getOrComputeValue(BasicVar var) {
+ return world.getValue(var);
+ }
- // Note that we don't have to override getSatisfiers, because the
- // DefaultEvalContext implementation of getSatisfiers calls getValue
- // on the number variable
+ // Note that we don't have to override getSatisfiers, because the
+ // DefaultEvalContext implementation of getSatisfiers calls getValue
+ // on the number variable
- public NumberVar getPOPAppSatisfied(Object obj) {
- if (obj instanceof NonGuaranteedObject) {
- return ((NonGuaranteedObject) obj).getNumberVar();
- }
+ public NumberVar getPOPAppSatisfied(Object obj) {
+ if (obj instanceof NonGuaranteedObject) {
+ return ((NonGuaranteedObject) obj).getNumberVar();
+ }
- if (obj instanceof ObjectIdentifier) {
- parents.add(new OriginVar((ObjectIdentifier) obj));
- return world.getPOPAppSatisfied(obj);
- }
+ if (obj instanceof ObjectIdentifier) {
+ parents.add(new OriginVar((ObjectIdentifier) obj));
+ return world.getPOPAppSatisfied(obj);
+ }
- // Must be guaranteed object, so not generated by any number var
- return null;
- }
+ // Must be guaranteed object, so not generated by any number var
+ return null;
+ }
- /**
- * Returns the set of basic random variables that are instantiated and whose
- * values have been used in calls to the access methods. This set is backed by
- * the ParentRecEvalContext and will change as more random variables are used.
- *
- * @return unmodifiable Set of BasicVar
- */
- public Set getParents() {
- return Collections.unmodifiableSet(parents);
- }
+ /**
+ * Returns the set of basic random variables that are instantiated and whose
+ * values have been used in calls to the access methods. This set is backed by
+ * the ParentRecEvalContext and will change as more random variables are used.
+ *
+ * @return unmodifiable Set of BasicVar
+ */
+ public Set getParents() {
+ return Collections.unmodifiableSet(parents);
+ }
- /**
- * Returns the variable whose value was most recently needed by an access
- * method, but which is not instantiated. This method returns null if no such
- * variable exists.
- */
- public BasicVar getLatestUninstParent() {
- return latestUninstParent;
- }
+ /**
+ * Returns the variable whose value was most recently needed by an access
+ * method, but which is not instantiated. This method returns null if no such
+ * variable exists.
+ */
+ public BasicVar getLatestUninstParent() {
+ return latestUninstParent;
+ }
- protected Set parents = new LinkedHashSet(); // of BasicVar
- protected BasicVar latestUninstParent = null;
+ protected Set parents = new LinkedHashSet(); // of BasicVar
+ protected BasicVar latestUninstParent = null;
}
diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java
new file mode 100644
index 00000000..f8344b71
--- /dev/null
+++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java
@@ -0,0 +1,189 @@
+/**
+ *
+ */
+package blog.sample;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+
+import blog.ObjectIdentifier;
+import blog.bn.BasicVar;
+import blog.bn.NumberVar;
+import blog.bn.OriginVar;
+import blog.bn.VarWithDistrib;
+import blog.common.HashMapWithPreimages;
+import blog.distrib.CondProbDistrib;
+import blog.model.DependencyModel;
+import blog.model.NonGuaranteedObject;
+import blog.world.PartialWorld;
+
+/**
+ * A Evaluation Context class that could store the trace of visited BasicVars
+ * (when you try to get the distributions or values for some Basic or Derived
+ * variables).
+ * This class could also store the case levels in order to get the CBN decision
+ * tree structure (useful for the Open Universe Gibbs Sampling).
+ *
+ * @author Da Tang
+ * @since August 21, 2014
+ */
+public class TraceParentRecEvalContext extends ClassicInstantiatingEvalContext {
+
+ public TraceParentRecEvalContext(PartialWorld world) {
+ super(world);
+ }
+
+ protected TraceParentRecEvalContext(
+ PartialWorld world,
+ LinkedHashMap respVarsAndContexts) {
+ super(world, respVarsAndContexts);
+ }
+
+ protected Object getOrComputeValue(BasicVar var) {
+ Object value = world.getValue(var);
+ if (value == null) {
+ numCalculateNewVars++;
+ if (var instanceof VarWithDistrib) {
+ value = instantiate((VarWithDistrib) var);
+ } else {
+ throw new IllegalArgumentException("Don't know how to instantiate: "
+ + var);
+ }
+ }
+ return value;
+ }
+
+ public Object getValue(BasicVar var) {
+ Object value = getOrComputeValue(var);
+ if (value == null) {
+ latestUninstParent = var;
+ var.ensureStable();
+ handleMissingVar(var);
+ } else {
+ if (parents.add(var)) {
+ var.ensureStable();
+ parentTrace.addLast(var);
+ caseLevelTrace.addLast(caseLevel);
+ }
+ }
+ dependentVar.add(var);
+ return value;
+ }
+
+ public NumberVar getPOPAppSatisfied(Object obj) {
+ if (obj instanceof NonGuaranteedObject) {
+ return ((NonGuaranteedObject) obj).getNumberVar();
+ }
+
+ if (obj instanceof ObjectIdentifier) {
+ parents.add(new OriginVar((ObjectIdentifier) obj));
+ parentTrace.addLast(new OriginVar((ObjectIdentifier) obj));
+ caseLevelTrace.addLast(caseLevelTrace);
+ return world.getPOPAppSatisfied(obj);
+ }
+
+ // Must be guaranteed object, so not generated by any number var
+ return null;
+ }
+
+ protected Object instantiate(VarWithDistrib var) {
+ var.ensureStable();
+
+ /*
+ * if (Util.verbose()) { System.out.println("Need to instantiate: " + var);
+ * }
+ */
+
+ if (respVarsAndContexts.containsKey(var)) {
+ cycleError(var);
+ }
+
+ // Create a new "child" context and get the distribution for
+ // var in that context.
+ respVarsAndContexts.put(var, this);
+ TraceParentRecEvalContext spawn = new TraceParentRecEvalContext(world,
+ respVarsAndContexts);
+ spawn.afterSamplingListener = afterSamplingListener;
+ DependencyModel.Distrib distrib = var.getDistrib(spawn);
+ logProb += spawn.getLogProbability();
+ respVarsAndContexts.remove(var);
+ List parentTrace = spawn.getParentTrace(), caseLevelTrace = spawn
+ .getCaseLevelTrace();
+ for (int i = 0; i < parentTrace.size(); i++) {
+ this.parentTrace.addLast(parentTrace.get(i));
+ this.caseLevelTrace.addLast((Integer) caseLevelTrace.get(i)
+ + this.caseLevel);
+ }
+ this.caseLevel += spawn.getCaseLevel();
+
+ // Sample new value for var
+ CondProbDistrib cpd = distrib.getCPD();
+ cpd.setParams(distrib.getArgValues());
+ Object newValue = cpd.sampleVal();
+ double logProbForThisValue = cpd.getLogProb(newValue);
+ logProb += logProbForThisValue;
+
+ // Assert any identifiers that are used by var
+ Object[] args = var.args();
+ for (int i = 0; i < args.length; ++i) {
+ if (args[i] instanceof ObjectIdentifier) {
+ world.assertIdentifier((ObjectIdentifier) args[i]);
+ }
+ }
+ if (newValue instanceof ObjectIdentifier) {
+ world.assertIdentifier((ObjectIdentifier) newValue);
+ }
+
+ // Actually set value
+ world.setValue(var, newValue);
+
+ if (afterSamplingListener != null) {
+ afterSamplingListener.evaluate(var, newValue, logProbForThisValue);
+ }
+
+ if (staticAfterSamplingListener != null) {
+ staticAfterSamplingListener.evaluate(var, newValue, logProbForThisValue);
+ }
+
+ /*
+ * if (Util.verbose()) { System.out.println("Instantiated: " + var); }
+ */
+
+ return newValue;
+ }
+
+ public List getParentTrace() {
+ return Collections.unmodifiableList(parentTrace);
+ }
+
+ public void increaseCaseLevel() {
+ caseLevel++;
+ }
+
+ public List getCaseLevelTrace() {
+ return Collections.unmodifiableList(caseLevelTrace);
+ }
+
+ public int getCaseLevel() {
+ return caseLevel;
+ }
+
+ public int getNumCalculateNewVars() {
+ return numCalculateNewVars;
+ }
+
+ public Set getDependentVar() {
+ return Collections.unmodifiableSet(dependentVar);
+ }
+
+ private LinkedList parentTrace = new LinkedList();
+ private LinkedList caseLevelTrace = new LinkedList();
+ private int caseLevel = 0;
+ private HashMapWithPreimages assignment;
+ private int numCalculateNewVars = 0;
+ private Set dependentVar = new HashSet();
+}
diff --git a/tools/testing/sample-all b/tools/testing/sample-all
new file mode 100755
index 00000000..126cc021
--- /dev/null
+++ b/tools/testing/sample-all
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+samples=$1
+trials=$2
+output=$3
+
+if [ "$samples" = "-h" ]; then
+ echo "Usage of sample-all"
+ echo "sample_all.sh "
+ exit 0
+fi
+
+rm -f $output
+
+for f in $(find example -name '*.blog'); do
+ echo "Running $f"
+ bash tools/testing/sample.sh $samples $f $trials $output 0
+done
diff --git a/tools/testing/sample. b/tools/testing/sample.
new file mode 100644
index 00000000..e69de29b
diff --git a/tools/testing/sample.sh b/tools/testing/sample.sh
new file mode 100755
index 00000000..91302d5b
--- /dev/null
+++ b/tools/testing/sample.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+samples=$1
+blogfile_fullpath=$2
+trials=$3
+# The output file for the results
+output=$4
+id=$5
+
+if [ "$1" = "-h" ]; then
+ echo "Usage of sample.sh"
+ echo "sample.sh "
+ exit 0
+fi
+
+echo "File: $blogfile_fullpath"
+
+mkdir -p tools/testing/output
+LW=tools/testing/output/tmpLW$id
+MH=tools/testing/output/tmpMH$id
+rm -f $LW
+rm -f $MH
+
+for i in `seq 1 $trials`; do
+ echo "Trial $i..."
+ ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.LWSampler $blogfile_fullpath >> $LW
+ ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.MHSampler $blogfile_fullpath >> $MH
+done
+
+echo "File: $blogfile_fullpath" >> $output
+echo "Trials: $trials" >> $output
+echo "Samples: $samples" >> $output
+echo "" >> $output
+python tools/testing/sampling.py $id >> $output
+echo "---------------------------------" >> $output
+
+# Remove the temporary files
+rm -f $LW
+rm -f $MH
diff --git a/tools/testing/sample.txt b/tools/testing/sample.txt
new file mode 100644
index 00000000..eeee1c7f
--- /dev/null
+++ b/tools/testing/sample.txt
@@ -0,0 +1,31 @@
+--- Instructions for sample.sh on Ubuntu Machine ---
+
+sample.sh calculates the symmetric KL divergence for every queried set of distributions of a particular BLOG example.
+It estimates the true distribution of a queried variable by taking the expectation of the BLOG example ran with
+the LW Sampler over a specified number of trials.
+It then computes the average symmetric KL divergence between the LW Sampler and the "true estimate", and
+the average symmetric KL divergence between the MH Sampler and the "true estimate"
+
+1. Install GNU parallel
+>>> sudo apt-get install parallel
+
+2. checkout latest commit of fix-bug-MHSampler
+>>> git checkout fix-bug-MHSampler
+
+3. compile
+>>> sbt/sbt compile
+
+4. run script from top-level directory of blog
+>>> pwd
+~/blog
+>>> find example -name "*.blog" | parallel -k -j 4 --eta "bash ./tools/testing/sample.sh 10000 {} 10 tools/testing/output/KL.txt {#}"
+
+- uses 10,000 samples
+- runs 10 trials
+- outputs the result to tools/testing/output/KL.txt
+- uses 4 processes (you can adjust with "-j" option)
+- prints out information (--eta)
+- this script is buggy.
+ - doesn't work for vectors (throws exception on regex matching, and even if I did get it working, the state space would be too large). - doesn't work on continuous random variables.
+ - the KL divergence is misleading (sometimes Infinity) for discrete distributions where there values in the support with near zero probability (e.g. Poisson).
+ - the script crashes when the MH Sampler doesn't work. see "Blog Examples Status" spreadsheet for list of BLOG examples that are known to fail
diff --git a/tools/testing/sampling.py b/tools/testing/sampling.py
new file mode 100644
index 00000000..3ca2103b
--- /dev/null
+++ b/tools/testing/sampling.py
@@ -0,0 +1,140 @@
+import re
+import math
+import sys
+
+class Distribution(object):
+ def __init__(self):
+ self.probs = {}
+
+ def add(self, key, value):
+ self.probs[key] = value
+
+ def get(self, key):
+ return self.probs.get(key, 0.0)
+
+ def __str__(self):
+ s = ""
+ for key in self.probs:
+ s += "\t" + key + " : " + str(self.probs[key]) + "\n"
+ return s
+
+"""
+ Returns the list of distribution
+ given a list of all the lines in a file
+"""
+def getListDistributions(lines):
+ distributions = {}
+ for index, line in enumerate(lines):
+ if len(line) >= 28 and line[:12] == "Distribution":
+ name = line[27:-1]
+ dsn = getDistribution(lines, index)
+
+ dsn_list = distributions.get(name, 0)
+ if dsn_list == 0:
+ distributions[name] = [dsn]
+ else:
+ dsn_list.append(dsn)
+
+ return distributions
+
+"""
+ startIndex = Line Number (0-indexed) where the Distribution starts
+
+ Returns the distribution (assuming it is discrete),
+ which is represented as a dictionary where
+ keys = element of the support
+ values = corresponding probabilities of the distribution taking
+ on that element of the support
+"""
+def getDistribution(lines, startIndex):
+ distrib = Distribution()
+ count = startIndex + 1
+ while (lines[count] != "======== Done ========\n"
+ and lines[count][:12] != "Distribution"):
+ line = lines[count]
+ p = re.compile(r'\s+([a-zA-Z0-9_.]+)\s+([0-9.]+)')
+ probs = p.match(line)
+ if probs:
+ distrib.add(probs.group(1), float(probs.group(2)))
+ else:
+ raise Exception("expecting a regex match for line: '" + str(line[:-1]) + "'")
+ count += 1
+ return distrib
+
+"""
+ Given a list of distributions that all have the same support and
+ all are sampled from the same origin distribution,
+ returns a distribution that is the average over all these distributions.
+"""
+def getEmpiricalDistribution(distributions):
+ dictionary = {}
+ for distribution in distributions:
+ for key in distribution.probs:
+ dictionary[key] = dictionary.get(key, 0.0) + distribution.probs[key]
+ N = len(distributions)
+ for key in dictionary:
+ dictionary[key] = dictionary[key] / N
+ dist = Distribution()
+ dist.probs = dictionary
+ return dist
+
+"""
+ Returns the symmetric KL Divergence between dsnA and dsnB
+"""
+def getSymmetricKL(dsnA, dsnB):
+ return (getKL(dsnA, dsnB) + getKL(dsnB, dsnA)) / 2
+
+def getKL(dsnA, dsnB):
+ val = 0.0
+ for key in dsnA.probs:
+ p = dsnA.get(key)
+ q = dsnB.get(key)
+ if q == 0.0:
+ return float("inf")
+ val += (math.log(p/q) * p)
+ return val
+
+def getAverageKL(distributions, actual_distribution):
+ KL = 0.0
+ for distribution in distributions:
+ KL += getSymmetricKL(distribution, actual_distribution)
+ return KL / len(distributions)
+
+# Read in argument
+# The first argument is used for parallel processes
+# to distinguish the different temporary files open at one time
+if len(sys.argv) != 2:
+ print "must provide exactly 2 arguments"
+
+TMP_LW="tools/testing/output/tmpLW" + sys.argv[1]
+TMP_MH="tools/testing/output/tmpMH" + sys.argv[1]
+
+# Read in the LW Sampling
+f = open(TMP_LW, "r")
+dsnsLW = getListDistributions(f.readlines())
+distribution_estimates = {}
+for name in dsnsLW:
+ distribution_list = dsnsLW[name]
+ distribution_estimates[name] = getEmpiricalDistribution(distribution_list)
+
+# Read in the MH Sampling
+f = open(TMP_MH, "r")
+dsnsMH = getListDistributions(f.readlines())
+
+# Print out the KLs of each distribution for MH vs. LW
+print("%-25s %-10s %-10s %-10s %-10s" % ("Random Variable", "LW", "MH", "Log-LW", "Log-MH"))
+for name in distribution_estimates:
+ KL_LW = getAverageKL(dsnsLW[name], distribution_estimates[name])
+ KL_MH = getAverageKL(dsnsMH[name], distribution_estimates[name])
+ #KL_LW_DER = -1.0 / math.log(KL_LW)
+ #KL_MH_DER = -1.0 / math.log(KL_MH)
+ if KL_MH == 0.0:
+ KL_MH_DER = float("inf")
+ else:
+ KL_MH_DER = -1 * math.log(KL_MH)
+
+ if KL_LW == 0.0:
+ KL_LW_DER = float("inf")
+ else:
+ KL_LW_DER = -1 * math.log(KL_LW)
+ print("%-25s %-10.6f %-10.6f %-10.3f %-10.3f" % (name, KL_LW, KL_MH, KL_LW_DER, KL_MH_DER))