From ddd59663f4611d8bef9645a71a8d5808920cf2b7 Mon Sep 17 00:00:00 2001 From: td11 Date: Thu, 21 Aug 2014 11:09:51 -0700 Subject: [PATCH 01/10] fix the bug of MHSampler. --- .../java/blog/sample/GenericProposer.java | 146 ++++++++++---- .../blog/sample/ParentRecEvalContext.java | 136 ++++++------- .../sample/TraceParentRecEvalContext.java | 181 ++++++++++++++++++ 3 files changed, 364 insertions(+), 99 deletions(-) create mode 100644 src/main/java/blog/sample/TraceParentRecEvalContext.java diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index 450bb2a5..82f62f87 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -37,11 +37,12 @@ import java.util.HashSet; import java.util.Iterator; -import java.util.LinkedList; import java.util.Properties; import java.util.Set; +import blog.bn.BasicVar; import blog.bn.BayesNetVar; +import blog.bn.DerivedVar; import blog.bn.VarWithDistrib; import blog.common.Util; import blog.distrib.CondProbDistrib; @@ -147,46 +148,128 @@ public double proposeNextState(PartialWorldDiff world) { System.out.println(world); } - // Remove barren variables - LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars()); - while (!newlyBarren.isEmpty()) { - BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst(); - if (!evidenceVars.contains(var) && !queryVars.contains(var)) { - - // Remember its parents. - Set parentSet = world.getCBN().getParents(var); - - if (var instanceof VarWithDistrib) { - // Multiply in the probability of sampling this - // variable again. Since the parent value may have - // changed, must use the old world. - logProbBackward += world.getSaved().getLogProbOfValue(var); - - // Uninstantiate - world.setValue((VarWithDistrib) var, null); + // Remove unnecessary variables + + Set evidenceAndQueries = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar var = (BayesNetVar) iter.next(); + if (var instanceof BasicVar) { + evidenceAndQueries.add(var); + } else if (var instanceof DerivedVar) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + new PartialWorldDiff(world)); + var.ensureDetAndSupported(context); + if (context.getCorrespondingVar() != null) { + evidenceAndQueries.add(context.getCorrespondingVar()); } - - // Check to see if its parents are now barren. - for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) { - - // If parent is barren, add to the end of this - // linked list. Note that if a parent has two - // barren children, it will only be added to the - // end of the list once, when the last child is - // considered. - BayesNetVar parent = (BayesNetVar) parentIter.next(); - if (world.getCBN().getChildren(parent).isEmpty()) - newlyBarren.addLast(parent); + } + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar var = (BayesNetVar) iter.next(); + if (var instanceof BasicVar) { + evidenceAndQueries.add(var); + } else if (var instanceof DerivedVar) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + new PartialWorldDiff(world)); + var.ensureDetAndSupported(context); + if (context.getCorrespondingVar() != null) { + evidenceAndQueries.add(context.getCorrespondingVar()); } } } - + boolean OK; + do { + OK = true; + for (Iterator iter = world.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + Object value = world.getValue(var); + world.setValue(var, null); + if (!evidenceAndQueriesSupported(evidenceAndQueries, world)) { + world.setValue(var, value); + } else { + OK = false; + logProbBackward += world.getSaved().getLogProbOfValue(var); + } + } + } while (!OK); + + /* + * // Remove barren variables + * LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars()); + * while (!newlyBarren.isEmpty()) { + * BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst(); + * if (!evidenceVars.contains(var) && !queryVars.contains(var)) { + * + * // Remember its parents. + * Set parentSet = world.getCBN().getParents(var); + * + * if (var instanceof VarWithDistrib) { + * // Multiply in the probability of sampling this + * // variable again. Since the parent value may have + * // changed, must use the old world. + * logProbBackward += world.getSaved().getLogProbOfValue(var); + * + * // Uninstantiate + * world.setValue((VarWithDistrib) var, null); + * } + * + * // Check to see if its parents are now barren. + * for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) { + * + * // If parent is barren, add to the end of this + * // linked list. Note that if a parent has two + * // barren children, it will only be added to the + * // end of the list once, when the last child is + * // considered. + * BayesNetVar parent = (BayesNetVar) parentIter.next(); + * if (world.getCBN().getChildren(parent).isEmpty()) + * newlyBarren.addLast(parent); + * } + * } + * } + */ // Uniform sampling from new world. logProbBackward += (-Math.log(world.getInstantiatedVars().size() - numBasicEvidenceVars)); return (logProbBackward - logProbForward); } + protected boolean evidenceAndQueriesSupported(Set evidenceAndQueries, + PartialWorld world) { + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + for (Iterator iter = evidenceAndQueries.iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + for (Iterator iter = tmpWorld.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + return true; + } + // Samples a new value for the given variable (which must be // supported in world) and sets this new value as the // value of the variable in world. Then ensures that @@ -251,6 +334,5 @@ public double latestLogProbBackward() { public double latestLogProbForward() { return logProbForward; } - // End of debugger-only members. } diff --git a/src/main/java/blog/sample/ParentRecEvalContext.java b/src/main/java/blog/sample/ParentRecEvalContext.java index fc2fdd21..f4285a1e 100644 --- a/src/main/java/blog/sample/ParentRecEvalContext.java +++ b/src/main/java/blog/sample/ParentRecEvalContext.java @@ -35,7 +35,9 @@ package blog.sample; -import java.util.*; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; import blog.ObjectIdentifier; import blog.bn.BasicVar; @@ -53,80 +55,80 @@ * {@link #getOrComputeValue(BasicVar)} instead. */ public class ParentRecEvalContext extends DefaultEvalContext { - /** - * Creates a new ParentRecEvalContext using the given world. - */ - public ParentRecEvalContext(PartialWorld world) { - super(world); - } + /** + * Creates a new ParentRecEvalContext using the given world. + */ + public ParentRecEvalContext(PartialWorld world) { + super(world); + } - /** - * Creates a new ParentRecEvalContext using the given world. If the - * errorIfUndet flag is true, the access methods on this instance - * will print error messages and exit the program if the world is not complete - * enough to determine the correct return value. Otherwise they will just - * return null in such cases. - */ - public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { - super(world, errorIfUndet); - } + /** + * Creates a new ParentRecEvalContext using the given world. If the + * errorIfUndet flag is true, the access methods on this instance + * will print error messages and exit the program if the world is not complete + * enough to determine the correct return value. Otherwise they will just + * return null in such cases. + */ + public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { + super(world, errorIfUndet); + } - final public Object getValue(BasicVar var) { - Object value = getOrComputeValue(var); - if (value == null) { - latestUninstParent = var; - var.ensureStable(); - handleMissingVar(var); - } else { - if (parents.add(var)) { - var.ensureStable(); - } - } - return value; - } + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + } + } + return value; + } - protected Object getOrComputeValue(BasicVar var) { - return world.getValue(var); - } + protected Object getOrComputeValue(BasicVar var) { + return world.getValue(var); + } - // Note that we don't have to override getSatisfiers, because the - // DefaultEvalContext implementation of getSatisfiers calls getValue - // on the number variable + // Note that we don't have to override getSatisfiers, because the + // DefaultEvalContext implementation of getSatisfiers calls getValue + // on the number variable - public NumberVar getPOPAppSatisfied(Object obj) { - if (obj instanceof NonGuaranteedObject) { - return ((NonGuaranteedObject) obj).getNumberVar(); - } + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } - if (obj instanceof ObjectIdentifier) { - parents.add(new OriginVar((ObjectIdentifier) obj)); - return world.getPOPAppSatisfied(obj); - } + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + return world.getPOPAppSatisfied(obj); + } - // Must be guaranteed object, so not generated by any number var - return null; - } + // Must be guaranteed object, so not generated by any number var + return null; + } - /** - * Returns the set of basic random variables that are instantiated and whose - * values have been used in calls to the access methods. This set is backed by - * the ParentRecEvalContext and will change as more random variables are used. - * - * @return unmodifiable Set of BasicVar - */ - public Set getParents() { - return Collections.unmodifiableSet(parents); - } + /** + * Returns the set of basic random variables that are instantiated and whose + * values have been used in calls to the access methods. This set is backed by + * the ParentRecEvalContext and will change as more random variables are used. + * + * @return unmodifiable Set of BasicVar + */ + public Set getParents() { + return Collections.unmodifiableSet(parents); + } - /** - * Returns the variable whose value was most recently needed by an access - * method, but which is not instantiated. This method returns null if no such - * variable exists. - */ - public BasicVar getLatestUninstParent() { - return latestUninstParent; - } + /** + * Returns the variable whose value was most recently needed by an access + * method, but which is not instantiated. This method returns null if no such + * variable exists. + */ + public BasicVar getLatestUninstParent() { + return latestUninstParent; + } - protected Set parents = new LinkedHashSet(); // of BasicVar - protected BasicVar latestUninstParent = null; + protected Set parents = new LinkedHashSet(); // of BasicVar + protected BasicVar latestUninstParent = null; } diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java new file mode 100644 index 00000000..e9611a3d --- /dev/null +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -0,0 +1,181 @@ +/** + * + */ +package blog.sample; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; + +import blog.ObjectIdentifier; +import blog.bn.BasicVar; +import blog.bn.NumberVar; +import blog.bn.OriginVar; +import blog.bn.VarWithDistrib; +import blog.common.HashMapWithPreimages; +import blog.distrib.CondProbDistrib; +import blog.model.DependencyModel; +import blog.model.NonGuaranteedObject; +import blog.world.PartialWorld; + +/** + * @author Da Tang + * + */ +public class TraceParentRecEvalContext extends ClassicInstantiatingEvalContext { + + public TraceParentRecEvalContext(PartialWorld world) { + super(world); + } + + protected TraceParentRecEvalContext( + PartialWorld world, + LinkedHashMap respVarsAndContexts) { + super(world, respVarsAndContexts); + } + + protected Object getOrComputeValue(BasicVar var) { + Object value = world.getValue(var); + if (value == null) { + numCalculateNewVars++; + if (var instanceof VarWithDistrib) { + value = instantiate((VarWithDistrib) var); + } else { + throw new IllegalArgumentException("Don't know how to instantiate: " + + var); + } + } + return value; + } + + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + parentTrace.addLast(var); + caseLevelTrace.addLast(caseLevel); + } + } + correspondingVar = var; + return value; + } + + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } + + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + parentTrace.addLast(new OriginVar((ObjectIdentifier) obj)); + caseLevelTrace.addLast(caseLevelTrace); + return world.getPOPAppSatisfied(obj); + } + + // Must be guaranteed object, so not generated by any number var + return null; + } + + protected Object instantiate(VarWithDistrib var) { + var.ensureStable(); + + /* + * if (Util.verbose()) { System.out.println("Need to instantiate: " + var); + * } + */ + + if (respVarsAndContexts.containsKey(var)) { + cycleError(var); + } + + // Create a new "child" context and get the distribution for + // var in that context. + respVarsAndContexts.put(var, this); + TraceParentRecEvalContext spawn = new TraceParentRecEvalContext(world, + respVarsAndContexts); + spawn.afterSamplingListener = afterSamplingListener; + DependencyModel.Distrib distrib = var.getDistrib(spawn); + logProb += spawn.getLogProbability(); + respVarsAndContexts.remove(var); + List parentTrace = spawn.getParentTrace(), caseLevelTrace = spawn + .getCaseLevelTrace(); + for (int i = 0; i < parentTrace.size(); i++) { + this.parentTrace.addLast(parentTrace.get(i)); + this.caseLevelTrace.addLast((Integer) caseLevelTrace.get(i) + + this.caseLevel); + } + this.caseLevel += spawn.getCaseLevel(); + + // Sample new value for var + CondProbDistrib cpd = distrib.getCPD(); + cpd.setParams(distrib.getArgValues()); + Object newValue = cpd.sampleVal(); + double logProbForThisValue = cpd.getLogProb(newValue); + logProb += logProbForThisValue; + + // Assert any identifiers that are used by var + Object[] args = var.args(); + for (int i = 0; i < args.length; ++i) { + if (args[i] instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) args[i]); + } + } + if (newValue instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) newValue); + } + + // Actually set value + world.setValue(var, newValue); + + if (afterSamplingListener != null) { + afterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + if (staticAfterSamplingListener != null) { + staticAfterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + /* + * if (Util.verbose()) { System.out.println("Instantiated: " + var); } + */ + + return newValue; + } + + public List getParentTrace() { + return Collections.unmodifiableList(parentTrace); + } + + public void increaseCaseLevel() { + caseLevel++; + } + + public List getCaseLevelTrace() { + return Collections.unmodifiableList(caseLevelTrace); + } + + public int getCaseLevel() { + return caseLevel; + } + + public int getNumCalculateNewVars() { + return numCalculateNewVars; + } + + public BasicVar getCorrespondingVar() { + return correspondingVar; + } + + private LinkedList parentTrace = new LinkedList(); + private LinkedList caseLevelTrace = new LinkedList(); + private int caseLevel = 0; + private HashMapWithPreimages assignment; + private int numCalculateNewVars = 0; + private BasicVar correspondingVar = null; +} From 77927674ab203791f81cbe317d257084e45f65d9 Mon Sep 17 00:00:00 2001 From: td11 Date: Thu, 21 Aug 2014 17:05:07 -0700 Subject: [PATCH 02/10] Change to variable sets for derived variables. --- .../java/blog/sample/GenericProposer.java | 41 +------------------ .../sample/TraceParentRecEvalContext.java | 10 +++-- 2 files changed, 8 insertions(+), 43 deletions(-) diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index 82f62f87..75160419 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -160,7 +160,7 @@ public double proposeNextState(PartialWorldDiff world) { new PartialWorldDiff(world)); var.ensureDetAndSupported(context); if (context.getCorrespondingVar() != null) { - evidenceAndQueries.add(context.getCorrespondingVar()); + evidenceAndQueries.addAll(context.getCorrespondingVar()); } } } @@ -173,7 +173,7 @@ public double proposeNextState(PartialWorldDiff world) { new PartialWorldDiff(world)); var.ensureDetAndSupported(context); if (context.getCorrespondingVar() != null) { - evidenceAndQueries.add(context.getCorrespondingVar()); + evidenceAndQueries.addAll(context.getCorrespondingVar()); } } } @@ -193,43 +193,6 @@ public double proposeNextState(PartialWorldDiff world) { } } } while (!OK); - - /* - * // Remove barren variables - * LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars()); - * while (!newlyBarren.isEmpty()) { - * BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst(); - * if (!evidenceVars.contains(var) && !queryVars.contains(var)) { - * - * // Remember its parents. - * Set parentSet = world.getCBN().getParents(var); - * - * if (var instanceof VarWithDistrib) { - * // Multiply in the probability of sampling this - * // variable again. Since the parent value may have - * // changed, must use the old world. - * logProbBackward += world.getSaved().getLogProbOfValue(var); - * - * // Uninstantiate - * world.setValue((VarWithDistrib) var, null); - * } - * - * // Check to see if its parents are now barren. - * for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) { - * - * // If parent is barren, add to the end of this - * // linked list. Note that if a parent has two - * // barren children, it will only be added to the - * // end of the list once, when the last child is - * // considered. - * BayesNetVar parent = (BayesNetVar) parentIter.next(); - * if (world.getCBN().getChildren(parent).isEmpty()) - * newlyBarren.addLast(parent); - * } - * } - * } - */ - // Uniform sampling from new world. logProbBackward += (-Math.log(world.getInstantiatedVars().size() - numBasicEvidenceVars)); return (logProbBackward - logProbForward); diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java index e9611a3d..5c499fa0 100644 --- a/src/main/java/blog/sample/TraceParentRecEvalContext.java +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -4,9 +4,11 @@ package blog.sample; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; +import java.util.Set; import blog.ObjectIdentifier; import blog.bn.BasicVar; @@ -62,7 +64,7 @@ public Object getValue(BasicVar var) { caseLevelTrace.addLast(caseLevel); } } - correspondingVar = var; + correspondingVar.add(var); return value; } @@ -168,8 +170,8 @@ public int getNumCalculateNewVars() { return numCalculateNewVars; } - public BasicVar getCorrespondingVar() { - return correspondingVar; + public Set getCorrespondingVar() { + return Collections.unmodifiableSet(correspondingVar); } private LinkedList parentTrace = new LinkedList(); @@ -177,5 +179,5 @@ public BasicVar getCorrespondingVar() { private int caseLevel = 0; private HashMapWithPreimages assignment; private int numCalculateNewVars = 0; - private BasicVar correspondingVar = null; + private Set correspondingVar = new HashSet(); } From 2c48a8ad24a90e95d066628bf978e2c77ca9926f Mon Sep 17 00:00:00 2001 From: td11 Date: Thu, 21 Aug 2014 21:40:21 -0700 Subject: [PATCH 03/10] Change some comments and variable names. --- src/main/java/blog/sample/GenericProposer.java | 8 ++++---- .../blog/sample/TraceParentRecEvalContext.java | 14 +++++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index 75160419..a1e9d038 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -159,8 +159,8 @@ public double proposeNextState(PartialWorldDiff world) { TraceParentRecEvalContext context = new TraceParentRecEvalContext( new PartialWorldDiff(world)); var.ensureDetAndSupported(context); - if (context.getCorrespondingVar() != null) { - evidenceAndQueries.addAll(context.getCorrespondingVar()); + if (context.getDependentVar() != null) { + evidenceAndQueries.addAll(context.getDependentVar()); } } } @@ -172,8 +172,8 @@ public double proposeNextState(PartialWorldDiff world) { TraceParentRecEvalContext context = new TraceParentRecEvalContext( new PartialWorldDiff(world)); var.ensureDetAndSupported(context); - if (context.getCorrespondingVar() != null) { - evidenceAndQueries.addAll(context.getCorrespondingVar()); + if (context.getDependentVar() != null) { + evidenceAndQueries.addAll(context.getDependentVar()); } } } diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java index 5c499fa0..2b822df3 100644 --- a/src/main/java/blog/sample/TraceParentRecEvalContext.java +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -22,8 +22,12 @@ import blog.world.PartialWorld; /** - * @author Da Tang + * A Evaluation Context class that could store the trace of visiting BasicVars. + * This class could also store the case levels in order to get the CBN decision + * tree structure (useful for the Open Universe Gibbs Sampling). * + * @author Da Tang + * @since August 21, 2014 */ public class TraceParentRecEvalContext extends ClassicInstantiatingEvalContext { @@ -64,7 +68,7 @@ public Object getValue(BasicVar var) { caseLevelTrace.addLast(caseLevel); } } - correspondingVar.add(var); + dependentVar.add(var); return value; } @@ -170,8 +174,8 @@ public int getNumCalculateNewVars() { return numCalculateNewVars; } - public Set getCorrespondingVar() { - return Collections.unmodifiableSet(correspondingVar); + public Set getDependentVar() { + return Collections.unmodifiableSet(dependentVar); } private LinkedList parentTrace = new LinkedList(); @@ -179,5 +183,5 @@ public Set getCorrespondingVar() { private int caseLevel = 0; private HashMapWithPreimages assignment; private int numCalculateNewVars = 0; - private Set correspondingVar = new HashSet(); + private Set dependentVar = new HashSet(); } From 09d381fc9ef26681495f7a02f260ab69c17b457b Mon Sep 17 00:00:00 2001 From: td11 Date: Fri, 22 Aug 2014 10:19:31 -0700 Subject: [PATCH 04/10] Change some comments. --- src/main/java/blog/sample/TraceParentRecEvalContext.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java index 2b822df3..f8344b71 100644 --- a/src/main/java/blog/sample/TraceParentRecEvalContext.java +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -22,7 +22,9 @@ import blog.world.PartialWorld; /** - * A Evaluation Context class that could store the trace of visiting BasicVars. + * A Evaluation Context class that could store the trace of visited BasicVars + * (when you try to get the distributions or values for some Basic or Derived + * variables). * This class could also store the case levels in order to get the CBN decision * tree structure (useful for the Open Universe Gibbs Sampling). * From e37e1a9ae0a9ad76bd9dbf8b6b26eb6b792d9053 Mon Sep 17 00:00:00 2001 From: Christopher Gioia Date: Fri, 22 Aug 2014 14:07:04 -0700 Subject: [PATCH 05/10] original MH Sampler cannot run hurricane2.blog --- example/hurricane2.blog | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 example/hurricane2.blog diff --git a/example/hurricane2.blog b/example/hurricane2.blog new file mode 100644 index 00000000..6b5cdd9b --- /dev/null +++ b/example/hurricane2.blog @@ -0,0 +1,32 @@ +/** + * Hurricane + * Figure 4.2 in Milch's thesis + */ + +type City; +type PrepLevel; +type DamageLevel; + +random City First ~ UniformChoice({c for City c}); + +random City NotFirst ~ UniformChoice({c for City c: c != First}); + +random PrepLevel Prep(City c) ~ + if (First == c) then Categorical({High -> 0.5, Low -> 0.5}) + else case Damage(First) in + {Severe -> Categorical({High -> 0.9, Low -> 0.1}), + Mild -> Categorical({High -> 0.1, Low -> 0.9})} + ; + +random DamageLevel Damage(City c) ~ + case Prep(c) in {High -> Categorical({Severe -> 0.2, Mild -> 0.8}), + Low -> Categorical({Severe -> 0.8, Mild -> 0.2})} + ; + +distinct City A, B; +distinct PrepLevel Low, High; +distinct DamageLevel Severe, Mild; + +obs Damage(First) = Severe; + +query First; From 55cf84d50b2ac8728403491ddccf1716451fa039 Mon Sep 17 00:00:00 2001 From: Christopher Gioia Date: Fri, 22 Aug 2014 16:42:48 -0700 Subject: [PATCH 06/10] script for KL divergence and readme --- tools/testing/sample.sh | 32 ++++++++++ tools/testing/sample.txt | 31 +++++++++ tools/testing/sampling.py | 131 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100755 tools/testing/sample.sh create mode 100644 tools/testing/sample.txt create mode 100644 tools/testing/sampling.py diff --git a/tools/testing/sample.sh b/tools/testing/sample.sh new file mode 100755 index 00000000..e9c345ec --- /dev/null +++ b/tools/testing/sample.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +samples=$1 +blogfile_fullpath=$2 +trials=$3 +# The output file for the results +output=$4 +id=$5 + +echo "File: $blogfile_fullpath" + +LW=tools/testing/output/tmpLW$id +MH=tools/testing/output/tmpMH$id +rm -f $LW +rm -f $MH + +for i in `seq 1 $trials`; do + echo "Trial $i..." + ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.LWSampler $blogfile_fullpath >> $LW + ./blog -r -n $samples -q $samples --interval $samples -s blog.sample.MHSampler $blogfile_fullpath >> $MH +done + +echo "File: $blogfile_fullpath" >> $output +echo "Trials: $trials" >> $output +echo "Samples: $samples" >> $output +echo "" >> $output +python tools/testing/sampling.py $id >> $output +echo "---------------------------------" >> $output + +# Remove the temporary files +rm -f $LW +rm -f $MH diff --git a/tools/testing/sample.txt b/tools/testing/sample.txt new file mode 100644 index 00000000..eeee1c7f --- /dev/null +++ b/tools/testing/sample.txt @@ -0,0 +1,31 @@ +--- Instructions for sample.sh on Ubuntu Machine --- + +sample.sh calculates the symmetric KL divergence for every queried set of distributions of a particular BLOG example. +It estimates the true distribution of a queried variable by taking the expectation of the BLOG example ran with +the LW Sampler over a specified number of trials. +It then computes the average symmetric KL divergence between the LW Sampler and the "true estimate", and +the average symmetric KL divergence between the MH Sampler and the "true estimate" + +1. Install GNU parallel +>>> sudo apt-get install parallel + +2. checkout latest commit of fix-bug-MHSampler +>>> git checkout fix-bug-MHSampler + +3. compile +>>> sbt/sbt compile + +4. run script from top-level directory of blog +>>> pwd +~/blog +>>> find example -name "*.blog" | parallel -k -j 4 --eta "bash ./tools/testing/sample.sh 10000 {} 10 tools/testing/output/KL.txt {#}" + +- uses 10,000 samples +- runs 10 trials +- outputs the result to tools/testing/output/KL.txt +- uses 4 processes (you can adjust with "-j" option) +- prints out information (--eta) +- this script is buggy. + - doesn't work for vectors (throws exception on regex matching, and even if I did get it working, the state space would be too large). - doesn't work on continuous random variables. + - the KL divergence is misleading (sometimes Infinity) for discrete distributions where there values in the support with near zero probability (e.g. Poisson). + - the script crashes when the MH Sampler doesn't work. see "Blog Examples Status" spreadsheet for list of BLOG examples that are known to fail diff --git a/tools/testing/sampling.py b/tools/testing/sampling.py new file mode 100644 index 00000000..ca060521 --- /dev/null +++ b/tools/testing/sampling.py @@ -0,0 +1,131 @@ +import re +import math +import sys + +class Distribution(object): + def __init__(self): + self.probs = {} + + def add(self, key, value): + self.probs[key] = value + + def get(self, key): + return self.probs.get(key, 0.0) + + def __str__(self): + s = "" + for key in self.probs: + s += "\t" + key + " : " + str(self.probs[key]) + "\n" + return s + +""" + Returns the list of distribution + given a list of all the lines in a file +""" +def getListDistributions(lines): + distributions = {} + for index, line in enumerate(lines): + if len(line) >= 28 and line[:12] == "Distribution": + name = line[27:-1] + dsn = getDistribution(lines, index) + + dsn_list = distributions.get(name, 0) + if dsn_list == 0: + distributions[name] = [dsn] + else: + dsn_list.append(dsn) + + return distributions + +""" + startIndex = Line Number (0-indexed) where the Distribution starts + + Returns the distribution (assuming it is discrete), + which is represented as a dictionary where + keys = element of the support + values = corresponding probabilities of the distribution taking + on that element of the support +""" +def getDistribution(lines, startIndex): + distrib = Distribution() + count = startIndex + 1 + while (lines[count] != "======== Done ========\n" + and lines[count][:12] != "Distribution"): + line = lines[count] + p = re.compile(r'\s+(\w+)\s+([0-9.]+)') + probs = p.match(line) + if probs: + distrib.add(probs.group(1), float(probs.group(2))) + else: + raise Exception("expecting a regex match for line: '" + str(line[:-1]) + "'") + count += 1 + return distrib + +""" + Given a list of distributions that all have the same support and + all are sampled from the same origin distribution, + returns a distribution that is the average over all these distributions. +""" +def getEmpiricalDistribution(distributions): + dictionary = {} + for distribution in distributions: + for key in distribution.probs: + dictionary[key] = dictionary.get(key, 0.0) + distribution.probs[key] + N = len(distributions) + for key in dictionary: + dictionary[key] = dictionary[key] / N + dist = Distribution() + dist.probs = dictionary + return dist + +""" + Returns the symmetric KL Divergence between dsnA and dsnB +""" +def getSymmetricKL(dsnA, dsnB): + return (getKL(dsnA, dsnB) + getKL(dsnB, dsnA)) / 2 + +def getKL(dsnA, dsnB): + val = 0.0 + for key in dsnA.probs: + p = dsnA.get(key) + q = dsnB.get(key) + if q == 0.0: + return float("inf") + val += (math.log(p/q) * p) + return val + +def getAverageKL(distributions, actual_distribution): + KL = 0.0 + for distribution in distributions: + KL += getSymmetricKL(distribution, actual_distribution) + return KL / len(distributions) + +# Read in argument +# The first argument is used for parallel processes +# to distinguish the different temporary files open at one time +if len(sys.argv) != 2: + print "must provide exactly 2 arguments" + +TMP_LW="tools/testing/output/tmpLW" + sys.argv[1] +TMP_MH="tools/testing/output/tmpMH" + sys.argv[1] + +# Read in the LW Sampling +f = open(TMP_LW, "r") +dsnsLW = getListDistributions(f.readlines()) +distribution_estimates = {} +for name in dsnsLW: + distribution_list = dsnsLW[name] + distribution_estimates[name] = getEmpiricalDistribution(distribution_list) + +# Read in the MH Sampling +f = open(TMP_MH, "r") +dsnsMH = getListDistributions(f.readlines()) + +# Print out the KLs of each distribution for MH vs. LW +print("%-25s %-10s %-10s %-10s %-10s" % ("Random Variable", "LW", "MH", "Log-LW", "Log-MH")) +for name in distribution_estimates: + KL_LW = getAverageKL(dsnsLW[name], distribution_estimates[name]) + KL_MH = getAverageKL(dsnsMH[name], distribution_estimates[name]) + KL_LW_DER = -1.0 / math.log(KL_LW) + KL_MH_DER = -1.0 / math.log(KL_MH) + print("%-25s %-10.6f %-10.6f %-10.3f %-10.3f" % (name, KL_LW, KL_MH, KL_LW_DER, KL_MH_DER)) From c030d4963b684cf6fcdf7ced440c604407e5a0a5 Mon Sep 17 00:00:00 2001 From: Christopher Gioia Date: Fri, 22 Aug 2014 17:35:11 -0700 Subject: [PATCH 07/10] add script to run KL divergence serially --- tools/testing/sample-all | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100755 tools/testing/sample-all diff --git a/tools/testing/sample-all b/tools/testing/sample-all new file mode 100755 index 00000000..126cc021 --- /dev/null +++ b/tools/testing/sample-all @@ -0,0 +1,18 @@ +#!/bin/bash + +samples=$1 +trials=$2 +output=$3 + +if [ "$samples" = "-h" ]; then + echo "Usage of sample-all" + echo "sample_all.sh " + exit 0 +fi + +rm -f $output + +for f in $(find example -name '*.blog'); do + echo "Running $f" + bash tools/testing/sample.sh $samples $f $trials $output 0 +done From 1bc58589b32e58fcf5f24f2c727c8a153b808e15 Mon Sep 17 00:00:00 2001 From: Christopher Gioia Date: Fri, 22 Aug 2014 17:43:21 -0700 Subject: [PATCH 08/10] minor --- tools/testing/sample. | 0 tools/testing/sample.sh | 1 + 2 files changed, 1 insertion(+) create mode 100644 tools/testing/sample. diff --git a/tools/testing/sample. b/tools/testing/sample. new file mode 100644 index 00000000..e69de29b diff --git a/tools/testing/sample.sh b/tools/testing/sample.sh index e9c345ec..fb5a9250 100755 --- a/tools/testing/sample.sh +++ b/tools/testing/sample.sh @@ -9,6 +9,7 @@ id=$5 echo "File: $blogfile_fullpath" +mkdir -p tools/testing/output LW=tools/testing/output/tmpLW$id MH=tools/testing/output/tmpMH$id rm -f $LW From 8834e7b3248f5992360f7cf38c88f80225260032 Mon Sep 17 00:00:00 2001 From: Christopher Gioia Date: Tue, 26 Aug 2014 16:41:51 -0700 Subject: [PATCH 09/10] bug in sampling.py where element of random variable couldn't be real number --- tools/testing/sample.sh | 6 ++++++ tools/testing/sampling.py | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tools/testing/sample.sh b/tools/testing/sample.sh index fb5a9250..91302d5b 100755 --- a/tools/testing/sample.sh +++ b/tools/testing/sample.sh @@ -7,6 +7,12 @@ trials=$3 output=$4 id=$5 +if [ "$1" = "-h" ]; then + echo "Usage of sample.sh" + echo "sample.sh " + exit 0 +fi + echo "File: $blogfile_fullpath" mkdir -p tools/testing/output diff --git a/tools/testing/sampling.py b/tools/testing/sampling.py index ca060521..3ca2103b 100644 --- a/tools/testing/sampling.py +++ b/tools/testing/sampling.py @@ -52,7 +52,7 @@ def getDistribution(lines, startIndex): while (lines[count] != "======== Done ========\n" and lines[count][:12] != "Distribution"): line = lines[count] - p = re.compile(r'\s+(\w+)\s+([0-9.]+)') + p = re.compile(r'\s+([a-zA-Z0-9_.]+)\s+([0-9.]+)') probs = p.match(line) if probs: distrib.add(probs.group(1), float(probs.group(2))) @@ -126,6 +126,15 @@ def getAverageKL(distributions, actual_distribution): for name in distribution_estimates: KL_LW = getAverageKL(dsnsLW[name], distribution_estimates[name]) KL_MH = getAverageKL(dsnsMH[name], distribution_estimates[name]) - KL_LW_DER = -1.0 / math.log(KL_LW) - KL_MH_DER = -1.0 / math.log(KL_MH) + #KL_LW_DER = -1.0 / math.log(KL_LW) + #KL_MH_DER = -1.0 / math.log(KL_MH) + if KL_MH == 0.0: + KL_MH_DER = float("inf") + else: + KL_MH_DER = -1 * math.log(KL_MH) + + if KL_LW == 0.0: + KL_LW_DER = float("inf") + else: + KL_LW_DER = -1 * math.log(KL_LW) print("%-25s %-10.6f %-10.6f %-10.3f %-10.3f" % (name, KL_LW, KL_MH, KL_LW_DER, KL_MH_DER)) From 1e8efaff5bb9e218404123aaf44fd151a8de1074 Mon Sep 17 00:00:00 2001 From: td11 Date: Mon, 1 Sep 2014 10:17:53 -0700 Subject: [PATCH 10/10] Fix bug in Issue #303. --- src/main/java/blog/sample/GenericProposer.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index a1e9d038..1f002759 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -241,7 +241,10 @@ protected boolean evidenceAndQueriesSupported(Set evidenceAndQueries, // updates the logProbForward and logProbBackward variables. protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { // Save child set before graph becomes out of date - Set children = world.getCBN().getChildren(varToSample); + Set children = new HashSet(); + children.addAll(world.getCBN().getChildren(varToSample)); + children.addAll(evidenceVars); + children.addAll(queryVars); DependencyModel.Distrib distrib = varToSample .getDistrib(new DefaultEvalContext(world, true)); @@ -263,9 +266,10 @@ protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { for (Iterator childrenIter = children.iterator(); childrenIter.hasNext();) { BayesNetVar child = (BayesNetVar) childrenIter.next(); - if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT THING - // TO DO! CHECKING WITH BRIAN. - continue; + // if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT + // THING + // TO DO! CHECKING WITH BRIAN. + // continue; child.ensureDetAndSupported(instantiator); }