diff --git a/example/hurricane2.blog b/example/hurricane2.blog new file mode 100644 index 00000000..6b5cdd9b --- /dev/null +++ b/example/hurricane2.blog @@ -0,0 +1,32 @@ +/** + * Hurricane + * Figure 4.2 in Milch's thesis + */ + +type City; +type PrepLevel; +type DamageLevel; + +random City First ~ UniformChoice({c for City c}); + +random City NotFirst ~ UniformChoice({c for City c: c != First}); + +random PrepLevel Prep(City c) ~ + if (First == c) then Categorical({High -> 0.5, Low -> 0.5}) + else case Damage(First) in + {Severe -> Categorical({High -> 0.9, Low -> 0.1}), + Mild -> Categorical({High -> 0.1, Low -> 0.9})} + ; + +random DamageLevel Damage(City c) ~ + case Prep(c) in {High -> Categorical({Severe -> 0.2, Mild -> 0.8}), + Low -> Categorical({Severe -> 0.8, Mild -> 0.2})} + ; + +distinct City A, B; +distinct PrepLevel Low, High; +distinct DamageLevel Severe, Mild; + +obs Damage(First) = Severe; + +query First; diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index 450bb2a5..1f002759 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -37,11 +37,12 @@ import java.util.HashSet; import java.util.Iterator; -import java.util.LinkedList; import java.util.Properties; import java.util.Set; +import blog.bn.BasicVar; import blog.bn.BayesNetVar; +import blog.bn.DerivedVar; import blog.bn.VarWithDistrib; import blog.common.Util; import blog.distrib.CondProbDistrib; @@ -147,46 +148,91 @@ public double proposeNextState(PartialWorldDiff world) { System.out.println(world); } - // Remove barren variables - LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars()); - while (!newlyBarren.isEmpty()) { - BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst(); - if (!evidenceVars.contains(var) && !queryVars.contains(var)) { - - // Remember its parents. - Set parentSet = world.getCBN().getParents(var); - - if (var instanceof VarWithDistrib) { - // Multiply in the probability of sampling this - // variable again. Since the parent value may have - // changed, must use the old world. - logProbBackward += world.getSaved().getLogProbOfValue(var); - - // Uninstantiate - world.setValue((VarWithDistrib) var, null); + // Remove unnecessary variables + + Set evidenceAndQueries = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar var = (BayesNetVar) iter.next(); + if (var instanceof BasicVar) { + evidenceAndQueries.add(var); + } else if (var instanceof DerivedVar) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + new PartialWorldDiff(world)); + var.ensureDetAndSupported(context); + if (context.getDependentVar() != null) { + evidenceAndQueries.addAll(context.getDependentVar()); } - - // Check to see if its parents are now barren. - for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) { - - // If parent is barren, add to the end of this - // linked list. Note that if a parent has two - // barren children, it will only be added to the - // end of the list once, when the last child is - // considered. - BayesNetVar parent = (BayesNetVar) parentIter.next(); - if (world.getCBN().getChildren(parent).isEmpty()) - newlyBarren.addLast(parent); + } + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar var = (BayesNetVar) iter.next(); + if (var instanceof BasicVar) { + evidenceAndQueries.add(var); + } else if (var instanceof DerivedVar) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + new PartialWorldDiff(world)); + var.ensureDetAndSupported(context); + if (context.getDependentVar() != null) { + evidenceAndQueries.addAll(context.getDependentVar()); } } } - - // Uniform sampling from new world. + boolean OK; + do { + OK = true; + for (Iterator iter = world.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + Object value = world.getValue(var); + world.setValue(var, null); + if (!evidenceAndQueriesSupported(evidenceAndQueries, world)) { + world.setValue(var, value); + } else { + OK = false; + logProbBackward += world.getSaved().getLogProbOfValue(var); + } + } + } while (!OK); logProbBackward += (-Math.log(world.getInstantiatedVars().size() - numBasicEvidenceVars)); return (logProbBackward - logProbForward); } + protected boolean evidenceAndQueriesSupported(Set evidenceAndQueries, + PartialWorld world) { + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + for (Iterator iter = evidenceAndQueries.iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + for (Iterator iter = tmpWorld.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + return true; + } + // Samples a new value for the given variable (which must be // supported in world) and sets this new value as the // value of the variable in world. Then ensures that @@ -195,7 +241,10 @@ public double proposeNextState(PartialWorldDiff world) { // updates the logProbForward and logProbBackward variables. protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { // Save child set before graph becomes out of date - Set children = world.getCBN().getChildren(varToSample); + Set children = new HashSet(); + children.addAll(world.getCBN().getChildren(varToSample)); + children.addAll(evidenceVars); + children.addAll(queryVars); DependencyModel.Distrib distrib = varToSample .getDistrib(new DefaultEvalContext(world, true)); @@ -217,9 +266,10 @@ protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { for (Iterator childrenIter = children.iterator(); childrenIter.hasNext();) { BayesNetVar child = (BayesNetVar) childrenIter.next(); - if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT THING - // TO DO! CHECKING WITH BRIAN. - continue; + // if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT + // THING + // TO DO! CHECKING WITH BRIAN. + // continue; child.ensureDetAndSupported(instantiator); } @@ -251,6 +301,5 @@ public double latestLogProbBackward() { public double latestLogProbForward() { return logProbForward; } - // End of debugger-only members. } diff --git a/src/main/java/blog/sample/ParentRecEvalContext.java b/src/main/java/blog/sample/ParentRecEvalContext.java index fc2fdd21..f4285a1e 100644 --- a/src/main/java/blog/sample/ParentRecEvalContext.java +++ b/src/main/java/blog/sample/ParentRecEvalContext.java @@ -35,7 +35,9 @@ package blog.sample; -import java.util.*; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; import blog.ObjectIdentifier; import blog.bn.BasicVar; @@ -53,80 +55,80 @@ * {@link #getOrComputeValue(BasicVar)} instead. */ public class ParentRecEvalContext extends DefaultEvalContext { - /** - * Creates a new ParentRecEvalContext using the given world. - */ - public ParentRecEvalContext(PartialWorld world) { - super(world); - } + /** + * Creates a new ParentRecEvalContext using the given world. + */ + public ParentRecEvalContext(PartialWorld world) { + super(world); + } - /** - * Creates a new ParentRecEvalContext using the given world. If the - * errorIfUndet flag is true, the access methods on this instance - * will print error messages and exit the program if the world is not complete - * enough to determine the correct return value. Otherwise they will just - * return null in such cases. - */ - public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { - super(world, errorIfUndet); - } + /** + * Creates a new ParentRecEvalContext using the given world. If the + * errorIfUndet flag is true, the access methods on this instance + * will print error messages and exit the program if the world is not complete + * enough to determine the correct return value. Otherwise they will just + * return null in such cases. + */ + public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { + super(world, errorIfUndet); + } - final public Object getValue(BasicVar var) { - Object value = getOrComputeValue(var); - if (value == null) { - latestUninstParent = var; - var.ensureStable(); - handleMissingVar(var); - } else { - if (parents.add(var)) { - var.ensureStable(); - } - } - return value; - } + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + } + } + return value; + } - protected Object getOrComputeValue(BasicVar var) { - return world.getValue(var); - } + protected Object getOrComputeValue(BasicVar var) { + return world.getValue(var); + } - // Note that we don't have to override getSatisfiers, because the - // DefaultEvalContext implementation of getSatisfiers calls getValue - // on the number variable + // Note that we don't have to override getSatisfiers, because the + // DefaultEvalContext implementation of getSatisfiers calls getValue + // on the number variable - public NumberVar getPOPAppSatisfied(Object obj) { - if (obj instanceof NonGuaranteedObject) { - return ((NonGuaranteedObject) obj).getNumberVar(); - } + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } - if (obj instanceof ObjectIdentifier) { - parents.add(new OriginVar((ObjectIdentifier) obj)); - return world.getPOPAppSatisfied(obj); - } + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + return world.getPOPAppSatisfied(obj); + } - // Must be guaranteed object, so not generated by any number var - return null; - } + // Must be guaranteed object, so not generated by any number var + return null; + } - /** - * Returns the set of basic random variables that are instantiated and whose - * values have been used in calls to the access methods. This set is backed by - * the ParentRecEvalContext and will change as more random variables are used. - * - * @return unmodifiable Set of BasicVar - */ - public Set getParents() { - return Collections.unmodifiableSet(parents); - } + /** + * Returns the set of basic random variables that are instantiated and whose + * values have been used in calls to the access methods. This set is backed by + * the ParentRecEvalContext and will change as more random variables are used. + * + * @return unmodifiable Set of BasicVar + */ + public Set getParents() { + return Collections.unmodifiableSet(parents); + } - /** - * Returns the variable whose value was most recently needed by an access - * method, but which is not instantiated. This method returns null if no such - * variable exists. - */ - public BasicVar getLatestUninstParent() { - return latestUninstParent; - } + /** + * Returns the variable whose value was most recently needed by an access + * method, but which is not instantiated. This method returns null if no such + * variable exists. + */ + public BasicVar getLatestUninstParent() { + return latestUninstParent; + } - protected Set parents = new LinkedHashSet(); // of BasicVar - protected BasicVar latestUninstParent = null; + protected Set parents = new LinkedHashSet(); // of BasicVar + protected BasicVar latestUninstParent = null; } diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java new file mode 100644 index 00000000..f8344b71 --- /dev/null +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -0,0 +1,189 @@ +/** + * + */ +package blog.sample; + +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +import blog.ObjectIdentifier; +import blog.bn.BasicVar; +import blog.bn.NumberVar; +import blog.bn.OriginVar; +import blog.bn.VarWithDistrib; +import blog.common.HashMapWithPreimages; +import blog.distrib.CondProbDistrib; +import blog.model.DependencyModel; +import blog.model.NonGuaranteedObject; +import blog.world.PartialWorld; + +/** + * A Evaluation Context class that could store the trace of visited BasicVars + * (when you try to get the distributions or values for some Basic or Derived + * variables). + * This class could also store the case levels in order to get the CBN decision + * tree structure (useful for the Open Universe Gibbs Sampling). + * + * @author Da Tang + * @since August 21, 2014 + */ +public class TraceParentRecEvalContext extends ClassicInstantiatingEvalContext { + + public TraceParentRecEvalContext(PartialWorld world) { + super(world); + } + + protected TraceParentRecEvalContext( + PartialWorld world, + LinkedHashMap respVarsAndContexts) { + super(world, respVarsAndContexts); + } + + protected Object getOrComputeValue(BasicVar var) { + Object value = world.getValue(var); + if (value == null) { + numCalculateNewVars++; + if (var instanceof VarWithDistrib) { + value = instantiate((VarWithDistrib) var); + } else { + throw new IllegalArgumentException("Don't know how to instantiate: " + + var); + } + } + return value; + } + + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + parentTrace.addLast(var); + caseLevelTrace.addLast(caseLevel); + } + } + dependentVar.add(var); + return value; + } + + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } + + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + parentTrace.addLast(new OriginVar((ObjectIdentifier) obj)); + caseLevelTrace.addLast(caseLevelTrace); + return world.getPOPAppSatisfied(obj); + } + + // Must be guaranteed object, so not generated by any number var + return null; + } + + protected Object instantiate(VarWithDistrib var) { + var.ensureStable(); + + /* + * if (Util.verbose()) { System.out.println("Need to instantiate: " + var); + * } + */ + + if (respVarsAndContexts.containsKey(var)) { + cycleError(var); + } + + // Create a new "child" context and get the distribution for + // var in that context. + respVarsAndContexts.put(var, this); + TraceParentRecEvalContext spawn = new TraceParentRecEvalContext(world, + respVarsAndContexts); + spawn.afterSamplingListener = afterSamplingListener; + DependencyModel.Distrib distrib = var.getDistrib(spawn); + logProb += spawn.getLogProbability(); + respVarsAndContexts.remove(var); + List parentTrace = spawn.getParentTrace(), caseLevelTrace = spawn + .getCaseLevelTrace(); + for (int i = 0; i < parentTrace.size(); i++) { + this.parentTrace.addLast(parentTrace.get(i)); + this.caseLevelTrace.addLast((Integer) caseLevelTrace.get(i) + + this.caseLevel); + } + this.caseLevel += spawn.getCaseLevel(); + + // Sample new value for var + CondProbDistrib cpd = distrib.getCPD(); + cpd.setParams(distrib.getArgValues()); + Object newValue = cpd.sampleVal(); + double logProbForThisValue = cpd.getLogProb(newValue); + logProb += logProbForThisValue; + + // Assert any identifiers that are used by var + Object[] args = var.args(); + for (int i = 0; i < args.length; ++i) { + if (args[i] instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) args[i]); + } + } + if (newValue instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) newValue); + } + + // Actually set value + world.setValue(var, newValue); + + if (afterSamplingListener != null) { + afterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + if (staticAfterSamplingListener != null) { + staticAfterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + /* + * if (Util.verbose()) { System.out.println("Instantiated: " + var); } + */ + + return newValue; + } + + public List getParentTrace() { + return Collections.unmodifiableList(parentTrace); + } + + public void increaseCaseLevel() { + caseLevel++; + } + + public List getCaseLevelTrace() { + return Collections.unmodifiableList(caseLevelTrace); + } + + public int getCaseLevel() { + return caseLevel; + } + + public int getNumCalculateNewVars() { + return numCalculateNewVars; + } + + public Set getDependentVar() { + return Collections.unmodifiableSet(dependentVar); + } + + private LinkedList parentTrace = new LinkedList(); + private LinkedList caseLevelTrace = new LinkedList(); + private int caseLevel = 0; + private HashMapWithPreimages assignment; + private int numCalculateNewVars = 0; + private Set dependentVar = new HashSet(); +} diff --git a/tools/testing/sample-all b/tools/testing/sample-all new file mode 100755 index 00000000..126cc021 --- /dev/null +++ b/tools/testing/sample-all @@ -0,0 +1,18 @@ +#!/bin/bash + +samples=$1 +trials=$2 +output=$3 + +if [ "$samples" = "-h" ]; then + echo "Usage of sample-all" + echo "sample_all.sh " + exit 0 +fi + +rm -f $output + +for f in $(find example -name '*.blog'); do + echo "Running $f" + bash tools/testing/sample.sh $samples $f $trials $output 0 +done diff --git a/tools/testing/sample. b/tools/testing/sample. new file mode 100644 index 00000000..e69de29b diff --git a/tools/testing/sample.sh b/tools/testing/sample.sh new file mode 100755 index 00000000..91302d5b --- /dev/null +++ b/tools/testing/sample.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +samples=$1 +blogfile_fullpath=$2 +trials=$3 +# The output file for the results +output=$4 +id=$5 + +if [ "$1" = "-h" ]; then + echo "Usage of sample.sh" + echo "sample.sh " + exit 0 +fi + +echo "File: $blogfile_fullpath" + +mkdir -p tools/testing/output +LW=tools/testing/output/tmpLW$id +MH=tools/testing/output/tmpMH$id +rm -f $LW +rm -f $MH + +for i in `seq 1 $trials`; do + echo "Trial $i..." + ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.LWSampler $blogfile_fullpath >> $LW + ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.MHSampler $blogfile_fullpath >> $MH +done + +echo "File: $blogfile_fullpath" >> $output +echo "Trials: $trials" >> $output +echo "Samples: $samples" >> $output +echo "" >> $output +python tools/testing/sampling.py $id >> $output +echo "---------------------------------" >> $output + +# Remove the temporary files +rm -f $LW +rm -f $MH diff --git a/tools/testing/sample.txt b/tools/testing/sample.txt new file mode 100644 index 00000000..eeee1c7f --- /dev/null +++ b/tools/testing/sample.txt @@ -0,0 +1,31 @@ +--- Instructions for sample.sh on Ubuntu Machine --- + +sample.sh calculates the symmetric KL divergence for every queried set of distributions of a particular BLOG example. +It estimates the true distribution of a queried variable by taking the expectation of the BLOG example ran with +the LW Sampler over a specified number of trials. +It then computes the average symmetric KL divergence between the LW Sampler and the "true estimate", and +the average symmetric KL divergence between the MH Sampler and the "true estimate" + +1. Install GNU parallel +>>> sudo apt-get install parallel + +2. checkout latest commit of fix-bug-MHSampler +>>> git checkout fix-bug-MHSampler + +3. compile +>>> sbt/sbt compile + +4. run script from top-level directory of blog +>>> pwd +~/blog +>>> find example -name "*.blog" | parallel -k -j 4 --eta "bash ./tools/testing/sample.sh 10000 {} 10 tools/testing/output/KL.txt {#}" + +- uses 10,000 samples +- runs 10 trials +- outputs the result to tools/testing/output/KL.txt +- uses 4 processes (you can adjust with "-j" option) +- prints out information (--eta) +- this script is buggy. + - doesn't work for vectors (throws exception on regex matching, and even if I did get it working, the state space would be too large). - doesn't work on continuous random variables. + - the KL divergence is misleading (sometimes Infinity) for discrete distributions where there values in the support with near zero probability (e.g. Poisson). + - the script crashes when the MH Sampler doesn't work. see "Blog Examples Status" spreadsheet for list of BLOG examples that are known to fail diff --git a/tools/testing/sampling.py b/tools/testing/sampling.py new file mode 100644 index 00000000..3ca2103b --- /dev/null +++ b/tools/testing/sampling.py @@ -0,0 +1,140 @@ +import re +import math +import sys + +class Distribution(object): + def __init__(self): + self.probs = {} + + def add(self, key, value): + self.probs[key] = value + + def get(self, key): + return self.probs.get(key, 0.0) + + def __str__(self): + s = "" + for key in self.probs: + s += "\t" + key + " : " + str(self.probs[key]) + "\n" + return s + +""" + Returns the list of distribution + given a list of all the lines in a file +""" +def getListDistributions(lines): + distributions = {} + for index, line in enumerate(lines): + if len(line) >= 28 and line[:12] == "Distribution": + name = line[27:-1] + dsn = getDistribution(lines, index) + + dsn_list = distributions.get(name, 0) + if dsn_list == 0: + distributions[name] = [dsn] + else: + dsn_list.append(dsn) + + return distributions + +""" + startIndex = Line Number (0-indexed) where the Distribution starts + + Returns the distribution (assuming it is discrete), + which is represented as a dictionary where + keys = element of the support + values = corresponding probabilities of the distribution taking + on that element of the support +""" +def getDistribution(lines, startIndex): + distrib = Distribution() + count = startIndex + 1 + while (lines[count] != "======== Done ========\n" + and lines[count][:12] != "Distribution"): + line = lines[count] + p = re.compile(r'\s+([a-zA-Z0-9_.]+)\s+([0-9.]+)') + probs = p.match(line) + if probs: + distrib.add(probs.group(1), float(probs.group(2))) + else: + raise Exception("expecting a regex match for line: '" + str(line[:-1]) + "'") + count += 1 + return distrib + +""" + Given a list of distributions that all have the same support and + all are sampled from the same origin distribution, + returns a distribution that is the average over all these distributions. +""" +def getEmpiricalDistribution(distributions): + dictionary = {} + for distribution in distributions: + for key in distribution.probs: + dictionary[key] = dictionary.get(key, 0.0) + distribution.probs[key] + N = len(distributions) + for key in dictionary: + dictionary[key] = dictionary[key] / N + dist = Distribution() + dist.probs = dictionary + return dist + +""" + Returns the symmetric KL Divergence between dsnA and dsnB +""" +def getSymmetricKL(dsnA, dsnB): + return (getKL(dsnA, dsnB) + getKL(dsnB, dsnA)) / 2 + +def getKL(dsnA, dsnB): + val = 0.0 + for key in dsnA.probs: + p = dsnA.get(key) + q = dsnB.get(key) + if q == 0.0: + return float("inf") + val += (math.log(p/q) * p) + return val + +def getAverageKL(distributions, actual_distribution): + KL = 0.0 + for distribution in distributions: + KL += getSymmetricKL(distribution, actual_distribution) + return KL / len(distributions) + +# Read in argument +# The first argument is used for parallel processes +# to distinguish the different temporary files open at one time +if len(sys.argv) != 2: + print "must provide exactly 2 arguments" + +TMP_LW="tools/testing/output/tmpLW" + sys.argv[1] +TMP_MH="tools/testing/output/tmpMH" + sys.argv[1] + +# Read in the LW Sampling +f = open(TMP_LW, "r") +dsnsLW = getListDistributions(f.readlines()) +distribution_estimates = {} +for name in dsnsLW: + distribution_list = dsnsLW[name] + distribution_estimates[name] = getEmpiricalDistribution(distribution_list) + +# Read in the MH Sampling +f = open(TMP_MH, "r") +dsnsMH = getListDistributions(f.readlines()) + +# Print out the KLs of each distribution for MH vs. LW +print("%-25s %-10s %-10s %-10s %-10s" % ("Random Variable", "LW", "MH", "Log-LW", "Log-MH")) +for name in distribution_estimates: + KL_LW = getAverageKL(dsnsLW[name], distribution_estimates[name]) + KL_MH = getAverageKL(dsnsMH[name], distribution_estimates[name]) + #KL_LW_DER = -1.0 / math.log(KL_LW) + #KL_MH_DER = -1.0 / math.log(KL_MH) + if KL_MH == 0.0: + KL_MH_DER = float("inf") + else: + KL_MH_DER = -1 * math.log(KL_MH) + + if KL_LW == 0.0: + KL_LW_DER = float("inf") + else: + KL_LW_DER = -1 * math.log(KL_LW) + print("%-25s %-10.6f %-10.6f %-10.3f %-10.3f" % (name, KL_LW, KL_MH, KL_LW_DER, KL_MH_DER))