diff --git a/src/main/java/blog/bn/CBN.java b/src/main/java/blog/bn/CBN.java index 269d249a..52b60b66 100644 --- a/src/main/java/blog/bn/CBN.java +++ b/src/main/java/blog/bn/CBN.java @@ -36,6 +36,7 @@ package blog.bn; import blog.common.DGraph; +import blog.world.PartialWorld; /** * A contingent bayes net (CBN) contains a set of random variables V. For each @@ -43,16 +44,23 @@ * tree T_X. The decision tree is a binary tree where each node is a predicate * on some subset of V. Each leaf of T_X is a probability distribution * parametrized by a subset of V. (Summarized from Arora et. al, UAI-10) - * - * TODO: As the requirements for CBNs become clearer, add actual methods here - * - * @author rbharath - * @date Aug 11, 2012 + * + * @author Da Tang + * @since Sep 07, 2014 */ -public interface CBN extends DGraph { - /** - * An empty CBN - */ - static final CBN EMPTY_CBN = new DefaultCBN(); +public interface CBN extends DGraph { + /** + * An empty CBN + */ + static final CBN EMPTY_CBN = new DefaultCBN(); + + /** + * Calculating whether an edge Y -> Z is contingent on variable X or not in + * the + * PartialWorld or not. + * + */ + boolean isContingentOn(PartialWorld world, BayesNetVar X, BayesNetVar Y, + BayesNetVar Z); } diff --git a/src/main/java/blog/bn/DefaultCBN.java b/src/main/java/blog/bn/DefaultCBN.java index b8a59aa8..cc18b3ae 100644 --- a/src/main/java/blog/bn/DefaultCBN.java +++ b/src/main/java/blog/bn/DefaultCBN.java @@ -37,27 +37,76 @@ import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedList; import java.util.Map; import blog.common.DefaultDGraph; -import blog.common.DefaultDGraph.NodeInfo; +import blog.common.Util; +import blog.sample.TraceParentRecEvalContext; +import blog.world.PartialWorld; /** - * This class provides a default implementation of CBNs. Over the next few weeks, - * all inference algorithms will be modified to use CBNs. + * This class provides a default implementation of CBNs. + * Uses a DefaultCBN rather than a DefaultDGraph, so as to avoid + * ClassCastExceptions. + * + * @author Da Tang + * @since Sep 7, 2014 */ public class DefaultCBN extends DefaultDGraph implements CBN { - /** - * Uses a DefaultCBN rather than a DefaultDGraph, so as to avoid ClassCastExceptions - */ - public Object clone() { - DefaultCBN clone = new DefaultCBN(); - clone.nodeInfo = (Map) ((HashMap) nodeInfo).clone(); - for (Iterator iter = clone.nodeInfo.entrySet().iterator(); iter.hasNext();) { - Map.Entry entry = (Map.Entry) iter.next(); - entry.setValue(((NodeInfo) entry.getValue()).clone()); - } - return clone; - } + /** + * clone method for the class Default CBN. + */ + public Object clone() { + DefaultCBN clone = new DefaultCBN(); + clone.nodeInfo = (Map) ((HashMap) nodeInfo).clone(); + for (Iterator iter = clone.nodeInfo.entrySet().iterator(); iter.hasNext();) { + Map.Entry entry = (Map.Entry) iter.next(); + entry.setValue(((NodeInfo) entry.getValue()).clone()); + } + return clone; + } + + @Override + public boolean isContingentOn(PartialWorld world, BayesNetVar X, + BayesNetVar Y, BayesNetVar Z) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext(world); + if (Z instanceof VarWithDistrib) { + ((VarWithDistrib) Z).getDistrib(context); + } else if (Z instanceof DerivedVar) { + ((DerivedVar) Z).getValue(context); + } else { + return true; + } + + LinkedList parentTrace = new LinkedList(); + parentTrace.addAll(context.getParentTrace()); + + int x = parentTrace.indexOf(X), y = parentTrace.indexOf(Y); + if (x < 0 || y < 0) { + return false; + } + if (X instanceof NumberVar) { + if (x < y && world.getCBN().getAncestors(Y).contains(X)) { + if (Util.verbose()) { + System.out.println("\t Contingent relations type 1: " + X.toString() + + " " + Y.toString() + " " + Z.toString()); + } + return true; + } else { + return false; + } + } else { + if (x < y) { + if (Util.verbose()) { + System.out.println("\t Contingent relations type 2: " + X.toString() + + " " + Y.toString() + " " + Z.toString()); + } + return true; + } else { + return false; + } + } + } } diff --git a/src/main/java/blog/bn/PatchCBN.java b/src/main/java/blog/bn/PatchCBN.java index b0ab9445..eda97d70 100644 --- a/src/main/java/blog/bn/PatchCBN.java +++ b/src/main/java/blog/bn/PatchCBN.java @@ -35,23 +35,70 @@ package blog.bn; -import blog.common.ParentUpdateDGraph; +import java.util.LinkedList; + import blog.common.DGraph; +import blog.common.ParentUpdateDGraph; +import blog.common.Util; +import blog.sample.TraceParentRecEvalContext; +import blog.world.PartialWorld; /** * A patch data structure to an underlying CBN that represents changes to - * the set of nodes and to the parent sets of existing nodes. Currently is a + * the set of nodes and to the parent sets of existing nodes. Currently is a * thin wrapper over ParentUpdateDGraph. - * + * * @author rbharath * @date August 12, 2012 */ public class PatchCBN extends ParentUpdateDGraph implements CBN { - /** - * Creates a new ParentUpdateDGraph that represents no changes to the given - * underlying graph. - */ - public PatchCBN(CBN underlying) { - super((DGraph) underlying); - } + /** + * Creates a new ParentUpdateDGraph that represents no changes to the given + * underlying graph. + */ + public PatchCBN(CBN underlying) { + super((DGraph) underlying); + } + + @Override + public boolean isContingentOn(PartialWorld world, BayesNetVar X, + BayesNetVar Y, BayesNetVar Z) { + TraceParentRecEvalContext context = new TraceParentRecEvalContext(world); + if (Z instanceof VarWithDistrib) { + ((VarWithDistrib) Z).getDistrib(context); + } else if (Z instanceof DerivedVar) { + ((DerivedVar) Z).getValue(context); + } else { + return true; + } + + LinkedList parentTrace = new LinkedList(); + parentTrace.addAll(context.getParentTrace()); + + int x = parentTrace.indexOf(X), y = parentTrace.indexOf(Y); + if (x < 0 || y < 0) { + return false; + } + if (X instanceof NumberVar) { + if (x < y && world.getCBN().getAncestors(Y).contains(X)) { + if (Util.verbose()) { + System.out.println("\t Contingent relations type 1: " + X.toString() + + " " + Y.toString() + " " + Z.toString()); + } + return true; + } else { + return false; + } + } else { + if (x < y) { + if (Util.verbose()) { + System.out.println("\t Contingent relations type 2: " + X.toString() + + " " + Y.toString() + " " + Z.toString()); + } + return true; + } else { + return false; + } + } + } } diff --git a/src/main/java/blog/sample/AbstractProposer.java b/src/main/java/blog/sample/AbstractProposer.java index df1c42fc..4031e3ec 100644 --- a/src/main/java/blog/sample/AbstractProposer.java +++ b/src/main/java/blog/sample/AbstractProposer.java @@ -123,7 +123,7 @@ public PartialWorldDiff reduceToCore(PartialWorld curWorld, BayesNetVar var) { } public double proposeNextState(PartialWorldDiff proposedWorld, - BayesNetVar var, int i) { + BayesNetVar var, Object value) { return 0; } diff --git a/src/main/java/blog/sample/GenericProposer.java b/src/main/java/blog/sample/GenericProposer.java index 450bb2a5..3888fdef 100644 --- a/src/main/java/blog/sample/GenericProposer.java +++ b/src/main/java/blog/sample/GenericProposer.java @@ -41,7 +41,9 @@ import java.util.Properties; import java.util.Set; +import blog.bn.BasicVar; import blog.bn.BayesNetVar; +import blog.bn.DerivedVar; import blog.bn.VarWithDistrib; import blog.common.Util; import blog.distrib.CondProbDistrib; @@ -148,38 +150,122 @@ public double proposeNextState(PartialWorldDiff world) { } // Remove barren variables - LinkedList newlyBarren = new LinkedList(world.getNewlyBarrenVars()); - while (!newlyBarren.isEmpty()) { - BayesNetVar var = (BayesNetVar) newlyBarren.removeFirst(); - if (!evidenceVars.contains(var) && !queryVars.contains(var)) { - - // Remember its parents. - Set parentSet = world.getCBN().getParents(var); - - if (var instanceof VarWithDistrib) { - // Multiply in the probability of sampling this - // variable again. Since the parent value may have - // changed, must use the old world. - logProbBackward += world.getSaved().getLogProbOfValue(var); - - // Uninstantiate - world.setValue((VarWithDistrib) var, null); + Set evidenceAndQueryVars = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); + } + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); + } + } + boolean OK; + do { + OK = true; + for (Iterator iter = world.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar curVar = (BasicVar) iter.next(); + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + tmpWorld.setValue(curVar, null); + if (evidenceAndQueriesAreSupported(evidenceAndQueryVars, tmpWorld)) { + world.setValue(curVar, null); + OK = false; + logProbBackward += world.getSaved().getLogProbOfValue(curVar); } + } + } while (!OK); - // Check to see if its parents are now barren. - for (Iterator parentIter = parentSet.iterator(); parentIter.hasNext();) { - - // If parent is barren, add to the end of this - // linked list. Note that if a parent has two - // barren children, it will only be added to the - // end of the list once, when the last child is - // considered. - BayesNetVar parent = (BayesNetVar) parentIter.next(); - if (world.getCBN().getChildren(parent).isEmpty()) - newlyBarren.addLast(parent); - } + // Uniform sampling from new world. + logProbBackward += (-Math.log(world.getInstantiatedVars().size() + - numBasicEvidenceVars)); + return (logProbBackward - logProbForward); + } + + /** + * Proposes a next state for the Markov chain given the current state. The + * proposedWorld argument is a PartialWorldDiff that the proposer can modify + * to create the proposal; the saved version of this PartialWorldDiff is the + * state before the proposal. Returns the log proposal ratio: log (q(x | x') / + * q(x' | x)) + */ + public double proposeNextState(PartialWorldDiff world, + VarWithDistrib varToSample) { + if (evidence == null) { + throw new IllegalStateException( + "initialize() has not been called on proposer."); + } + + logProbForward = 0; + logProbBackward = 0; + + PickVarToSampleResult result = pickVarToSample(world); + chosenVar = varToSample; + + if (result.varToSample == null) + return 1.0; + + if (Util.verbose()) + System.out.println(" sampling " + varToSample); + + // Multiply in the probability of this uniform sample. + logProbForward += (-Math.log(result.numberOfChoices)); + + if (Util.verbose()) { + System.out.println("GenericProposer: world right before sampling" + + " new value for " + varToSample + ".\n"); + System.out.println(world); + } + + // Sample value for variable and update forward and backward probs + sampleValue(varToSample, world); + + if (Util.verbose()) { + System.out.println("GenericProposer: world right before getting" + + " newly barren vars.\n"); + System.out.println(world); + } + + // Remove barren variables + Set evidenceAndQueryVars = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); + } + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); } } + boolean OK; + do { + OK = true; + for (Iterator iter = world.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar curVar = (BasicVar) iter.next(); + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + tmpWorld.setValue(curVar, null); + if (evidenceAndQueriesAreSupported(evidenceAndQueryVars, tmpWorld)) { + world.setValue(curVar, null); + OK = false; + logProbBackward += world.getSaved().getLogProbOfValue(curVar); + } + } + } while (!OK); // Uniform sampling from new world. logProbBackward += (-Math.log(world.getInstantiatedVars().size() @@ -187,6 +273,120 @@ public double proposeNextState(PartialWorldDiff world) { return (logProbBackward - logProbForward); } + /** + * Proposes a next state for the Markov chain given the current state and the + * current value. This method is written for Gibbs Sampler and therefore we + * need to pass the objective value for the sampled variable. The method + * returns the log probability of \log\frac{\Pr[var|\mathcal + * \sigma_{T_{var}^{var=value}]}}{|V(var)|}, which is the part of Gibbs weight + * no related to the core variable set. The other part will be computed in the + * Gibbs Sampler itself. + * + */ + public double proposeNextState(PartialWorldDiff world, BayesNetVar var, + Object value) { + if (evidence == null) { + throw new IllegalStateException( + "initialize() has not been called on proposer."); + } + + logProbGibbs = 0; + + if (!(var instanceof VarWithDistrib)) { + throw new IllegalStateException( + "Try to sample a non-distribution variable in Gibbs Sampling."); + } + + // Sample value for variable and update forward and backward probs + sampleValue((VarWithDistrib) var, value, world); + + if (Util.verbose()) { + System.out.println("GenericProposer: world right before getting" + + " newly barren vars.\n"); + System.out.println(world); + } + + Set evidenceAndQueryVars = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); + } + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar instanceof BasicVar) { + evidenceAndQueryVars.add(curVar); + } else if (curVar instanceof DerivedVar) { + evidenceAndQueryVars.addAll(curVar.getParents(world)); + } + } + boolean OK; + do { + OK = true; + for (Iterator iter = world.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar curVar = (BasicVar) iter.next(); + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + tmpWorld.setValue(curVar, null); + if (evidenceAndQueriesAreSupported(evidenceAndQueryVars, tmpWorld)) { + world.setValue(curVar, null); + OK = false; + logProbBackward += world.getSaved().getLogProbOfValue(curVar); + } + } + } while (!OK); + + // Uniform sampling from new world. + logProbGibbs += (-Math.log(world.getInstantiatedVars().size() + - numBasicEvidenceVars)); + return logProbGibbs; + } + + /** + * Check whether the set of evidence and queries are all supported or not in + * the given partial world. Also check whether the world is self-supported or + * not. Returns true if all the evidence and queries are supported and also + * the world is self-supported. + * + */ + protected boolean evidenceAndQueriesAreSupported(Set evidenceAndQueries, + PartialWorld world) { + PartialWorldDiff tmpWorld = new PartialWorldDiff(world); + for (Iterator iter = evidenceAndQueries.iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + for (Iterator iter = tmpWorld.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (!tmpWorld.isInstantiated(var)) { + return false; + } + TraceParentRecEvalContext context = new TraceParentRecEvalContext( + tmpWorld); + if (var instanceof VarWithDistrib) { + ((VarWithDistrib) var).getDistrib(context); + if (context.getNumCalculateNewVars() > 0) { + return false; + } + } + } + return true; + } + // Samples a new value for the given variable (which must be // supported in world) and sets this new value as the // value of the variable in world. Then ensures that @@ -195,8 +395,10 @@ public double proposeNextState(PartialWorldDiff world) { // updates the logProbForward and logProbBackward variables. protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { // Save child set before graph becomes out of date - Set children = world.getCBN().getChildren(varToSample); - + Set children = new HashSet(); + children.addAll(world.getCBN().getChildren(varToSample)); + children.addAll(evidenceVars); + children.addAll(queryVars); DependencyModel.Distrib distrib = varToSample .getDistrib(new DefaultEvalContext(world, true)); CondProbDistrib cpd = distrib.getCPD(); @@ -217,8 +419,7 @@ protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { for (Iterator childrenIter = children.iterator(); childrenIter.hasNext();) { BayesNetVar child = (BayesNetVar) childrenIter.next(); - if (!world.isInstantiated(child)) // NOT SURE YET THIS IS THE RIGHT THING - // TO DO! CHECKING WITH BRIAN. + if (!world.isInstantiated(child) && !(child instanceof DerivedVar)) continue; child.ensureDetAndSupported(instantiator); } @@ -226,6 +427,85 @@ protected void sampleValue(VarWithDistrib varToSample, PartialWorld world) { logProbForward += instantiator.getLogProbability(); } + protected void sampleValue(VarWithDistrib varToSample, Object value, + PartialWorld world) { + // Save child set before graph becomes out of date + Set children = new HashSet(); + children.addAll(world.getCBN().getChildren(varToSample)); + children.addAll(evidenceVars); + children.addAll(queryVars); + world.setValue(varToSample, value); + logProbGibbs += varToSample.getDistrib(new DefaultEvalContext(world, true)) + .getCPD().getLogProb(value); + + // Make the world self-supporting. The only variables whose active + // parent sets could have changed are the children of varToSample. + ClassicInstantiatingEvalContext instantiator = new ClassicInstantiatingEvalContext( + world); + + for (Iterator childrenIter = children.iterator(); childrenIter.hasNext();) { + BayesNetVar child = (BayesNetVar) childrenIter.next(); + + if (!world.isInstantiated(child) && !(child instanceof DerivedVar)) + continue; + + child.ensureDetAndSupported(instantiator); + } + } + + /** + * Reduce the current partial world to the core. That is, only leave the core + * variables and the sampled variable instantiated and uninstantiate the other + * variables. + */ + public PartialWorldDiff reduceToCore(PartialWorld curWorld, BayesNetVar var) { + if (!(var instanceof BasicVar) || !curWorld.isInstantiated(var)) { + throw new IllegalStateException( + "Sampled variable is not an instantiated BasicVar."); + } + LinkedList coreBFS = new LinkedList(); + Set core = new HashSet(); + for (Iterator iter = evidenceVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (!coreBFS.contains(curVar)) + coreBFS.addLast(curVar); + } + for (Iterator iter = queryVars.iterator(); iter.hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (!coreBFS.contains(curVar)) + coreBFS.addLast(curVar); + } + while (!coreBFS.isEmpty()) { + BayesNetVar curVar = (BayesNetVar) coreBFS.removeFirst(); + if (curVar instanceof BasicVar) { + if (curWorld.isInstantiated(curVar)) { + core.add(curVar); + } + } else if (!(curVar instanceof DerivedVar)) { + continue; + } + Set parentVars = curVar.getParents(curWorld); + for (Iterator iter = parentVars.iterator(); iter.hasNext();) { + BayesNetVar parVar = (BayesNetVar) iter.next(); + if (!curWorld.getCBN().isContingentOn(curWorld, var, parVar, curVar)) { + coreBFS.addLast(parVar); + } + } + } + Set varsToUninstantiate = new HashSet(); + varsToUninstantiate.addAll(curWorld.getInstantiatedVars()); + varsToUninstantiate.removeAll(core); + varsToUninstantiate.remove(var); + PartialWorldDiff newWorld = new PartialWorldDiff(curWorld); + for (Iterator iter = varsToUninstantiate.iterator(); iter.hasNext();) { + BasicVar curVar = (BasicVar) iter.next(); + newWorld.setValue(curVar, null); + } + return newWorld; + } + + private double logProbGibbs; + // The following are for debugger use only! private VarWithDistrib chosenVar = null; diff --git a/src/main/java/blog/sample/GibbsSampler.java b/src/main/java/blog/sample/GibbsSampler.java index 289c33df..5453602c 100644 --- a/src/main/java/blog/sample/GibbsSampler.java +++ b/src/main/java/blog/sample/GibbsSampler.java @@ -35,180 +35,190 @@ package blog.sample; -import java.util.*; +import java.lang.reflect.Constructor; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Properties; +import java.util.Set; import blog.bn.BayesNetVar; import blog.bn.VarWithDistrib; -import blog.bn.NumberVar; -import blog.bn.RandFuncAppVar; -import blog.model.Type; import blog.common.Util; import blog.model.Model; -import blog.model.Evidence; -import blog.world.PartialWorld; import blog.world.PartialWorldDiff; -import blog.model.Query; -import java.util.List; -import java.util.Properties; /** - * An implementation of the open universe Gibbs Sampler described by - * Arora et. al. This sampler differs from a standard Gibbs sampler + * An implementation of the open universe Gibbs Sampler described by + * Arora et. al. This sampler differs from a standard Gibbs sampler * in the fact that it shrinks and expands the CBN during sampling steps * to account for the changes in CBN structure as the values of random * variables change. - * + * * This implementation is built as a modification of the MH sampler since * many of the CBN manipulations are the same for Gibbs as for MH and since this * Gibbs sampler reverts to MH sampling for variables of infinite domain. - * - * TODO: Although the structure of this class is correct, a number of the - * called methods are only stubs. These will be filled in over the next few - * commits. This class will probably have MHSampler as superclass in the next - * iteration since both classes share significant amounts of structure. - * - * @author rbharath - * @date Aug 10, 2012 + * + * + * @author Da Tang + * @date Sep 7, 2014 */ public class GibbsSampler extends MHSampler { - /** - * Creates a new Gibbs Sampler for a given BLOG model. - */ - public GibbsSampler(Model model, Properties properties) { - super(model); - // For now, only the generic proposer is allowed - properties.setProperty("proposerClass", "blog.GenericProposer"); - constructProposer(properties); + /** + * Creates a new Gibbs Sampler for a given BLOG model. + */ + public GibbsSampler(Model model) { + super(model); + } + + /** + * Creates a new Gibbs Sampler for a given BLOG model. + */ + public GibbsSampler(Model model, Properties properties) { + super(model); + // For now, only the generic proposer is allowed + // properties.setProperty("proposerClass", "blog.VariableGibbsProposer"); + constructProposer(properties); + } + + /** Method responsible for initializing the proposer field. */ + protected void constructProposer(Properties properties) { + String proposerClassName = properties.getProperty("proposerClass", + "blog.sample.GenericProposer"); + System.out.println("Constructing Gibbs Sampling proposer of class " + + proposerClassName); + + try { + Class proposerClass = Class.forName(proposerClassName); + Class[] paramTypes = { Model.class, Properties.class }; + Constructor constructor = proposerClass.getConstructor(paramTypes); + + Object[] args = { model, properties }; + proposer = (Proposer) constructor.newInstance(args); + } catch (Exception e) { + Util.fatalError(e); } - - /** - * Generates the next partial world by Gibbs sampling: Randomly selects a - * non-evidence variable X. Reduces the current instantiation - * to its core, and resamples X from this core. Ensures that the resulting - * world is minimal and self supported. - */ - public void nextSample() { - // Find Nonevidence Variables in Current World - Set eligibleVars = new HashSet(curWorld.getInstantiatedVars()); - eligibleVars.removeAll(evidenceVars); - ++totalNumSamples; - ++numSamplesThisTrial; - - // Return if no vars to sample - if (eligibleVars.isEmpty()) - return; - - // Find Variable to Sample - VarWithDistrib varToSample = - (VarWithDistrib) Util.uniformSample(eligibleVars); - - if (Util.verbose()) - System.out.println("Sampling " + varToSample); - - int domSize = 0; - - // Find the domain size of this variable - if (varToSample instanceof NumberVar) { - // Number Variables have infinite domain - // TODO: If evidence restricts this number var, does it still - // have infinite domain? - domSize = -1; + } + + /** + * Generates the next partial world by Gibbs sampling: Randomly selects a + * non-evidence variable X. Reduces the current instantiation + * to its core, and resamples X from this core. Ensures that the resulting + * world is minimal and self supported. + */ + public void nextSample() { + // Find Nonevidence Variables in Current World + Set eligibleVars = new HashSet(curWorld.getInstantiatedVars()); + eligibleVars.removeAll(evidence.getEvidenceVars()); + ++totalNumSamples; + ++numSamplesThisTrial; + + // Return if no vars to sample + if (eligibleVars.isEmpty()) + return; + + // Find Variable to Sample + VarWithDistrib varToSample = (VarWithDistrib) Util + .uniformSample(eligibleVars); + + if (Util.verbose()) + System.out.println("Sampling " + varToSample); + + Object[] finiteSupport = varToSample + .getDistrib(new DefaultEvalContext(curWorld)).getCPD() + .getFiniteSupport(); + + // If domain size is finite + if (finiteSupport != null) { + // Calculate possible transitions and their weights + curWorld.save(); + int supportSize = finiteSupport.length; + double[] weights = new double[supportSize]; + PartialWorldDiff[] diffs = new PartialWorldDiff[supportSize]; + + PartialWorldDiff reducedWorld = proposer.reduceToCore(curWorld, + varToSample); + + for (int i = 0; i < supportSize; i++) { + // Set varToSample to i-th value in domain + // Ensure Minimal Self-Supported Instantiation + PartialWorldDiff proposedWorld; + if (finiteSupport[i].equals(curWorld.getValue(varToSample))) { + proposedWorld = new PartialWorldDiff(curWorld); } else { - RandFuncAppVar randFunc = (RandFuncAppVar) varToSample; - Type retType = randFunc.getType(); - if (!retType.hasFiniteGuaranteed()) { - domSize = -1; - } else { - domSize = retType.range().size(); - } + proposedWorld = new PartialWorldDiff(curWorld, reducedWorld); } - - // If domain size is finite - if (domSize >= 0) { - // Calculate possible transitions and their weights - double[] weights = new double[domSize]; - PartialWorldDiff[] diffs = new PartialWorldDiff[domSize]; - - for (int i = 0; i < domSize; i++) { - PartialWorldDiff reducedWorld = - proposer.reduceToCore(curWorld, varToSample); - // Set varToSample to i-th value in domain - // Ensure Minimal Self-Supported Instantiation - double logProposalRatio = - proposer.proposeNextState(reducedWorld, varToSample, i); - double weight = evidence.getEvidenceProb(curWorld); - weights[i] = weight; - diffs[i] = reducedWorld; - } - - int idx = Util.sampleWithProbs(weights); - PartialWorldDiff selected = diffs[idx]; - - // Save the selected world - selected.save(); - } else { - // Infinite Domain Size so we fall back to MH Sampling - curWorld.save(); // make sure we start with saved world. - double logProposalRatio = - proposer.proposeNextState(curWorld, varToSample); - - if (Util.verbose()) { - System.out.println(); - System.out.println("\tlog proposal ratio: " + logProposalRatio); - } - - double logProbRatio = - computeLogProbRatio(curWorld.getSaved(), curWorld); - if (Util.verbose()) { - System.out.println("\tlog probability ratio: " + logProbRatio); - } - double logAcceptRatio = logProbRatio + logProposalRatio; - if (Util.verbose()) { - System.out.println("\tlog acceptance ratio: " + logAcceptRatio); - } - - // Accept or reject proposal - if ((logAcceptRatio >= 0) || - (Util.random() < Math.exp(logAcceptRatio))) { - curWorld.save(); - if (Util.verbose()) { - System.out.println("\taccepted"); - } - ++totalNumAccepted; - ++numAcceptedThisTrial; - proposer.updateStats(true); - } else { - curWorld.revert(); // clean slate for next proposal - if (Util.verbose()) { - System.out.println("\trejected"); - } - proposer.updateStats(false); - } - + double logProposalRatio = proposer.proposeNextState(proposedWorld, + varToSample, finiteSupport[i]); + // double weight = evidence.getEvidenceProb(curWorld); + for (Iterator iter = reducedWorld.getInstantiatedVars().iterator(); iter + .hasNext();) { + BayesNetVar curVar = (BayesNetVar) iter.next(); + if (curVar.getParents(proposedWorld).contains(varToSample)) { + logProposalRatio += proposedWorld.getLogProbOfValue(curVar); + } } - } - - // Num Samples Drawn Thus Far - protected int totalNumSamples = 0; - protected int totalNumAccepted = 0; - protected int numSamplesThisTrial = 0; - protected int numAcceptedThisTrial = 0; - - // Generic Proposer to Handle State Transitions - protected GenericProposer proposer; - - // Properties - protected Properties properties; - - // Types of Identifiers - protected Set idTypes; - // The set of query variables - protected List queryVars = new ArrayList(); - protected Set evidenceVars = new HashSet(); + if (!proposedWorld.getVarsWithValue(Model.NULL).isEmpty() + || !evidence.isTrue(proposedWorld)) + logProposalRatio = Double.NEGATIVE_INFINITY; + weights[i] = Math.exp(logProposalRatio + + computeLogMultRatio(curWorld, proposedWorld)); + diffs[i] = proposedWorld; + } + + int idx = Util.sampleWithProbs(Util.normalize(weights)); + if (idx < 0) { + int mt = 250; + mt = mt + mt; + } + PartialWorldDiff selected = diffs[idx]; + + // Save the selected world + selected.save(); + } else { + // Infinite Domain Size so we fall back to MH Sampling + curWorld.save(); // make sure we start with saved world. + double logProposalRatio = ((GenericProposer) proposer).proposeNextState( + curWorld, varToSample); + + if (Util.verbose()) { + System.out.println(); + System.out.println("\tlog proposal ratio: " + logProposalRatio); + } + + double logProbRatio = computeLogProbRatio(curWorld.getSaved(), curWorld); + if (Util.verbose()) { + System.out.println("\tlog probability ratio: " + logProbRatio); + } + double logAcceptRatio = logProbRatio + logProposalRatio; + if (Util.verbose()) { + System.out.println("\tlog acceptance ratio: " + logAcceptRatio); + } + + // Accept or reject proposal + if ((logAcceptRatio >= 0) || (Util.random() < Math.exp(logAcceptRatio))) { + curWorld.save(); + if (Util.verbose()) { + System.out.println("\taccepted"); + } + ++totalNumAccepted; + ++numAcceptedThisTrial; + proposer.updateStats(true); + } else { + curWorld.revert(); // clean slate for next proposal + if (Util.verbose()) { + System.out.println("\trejected"); + } + proposer.updateStats(false); + } - // Generic Proposer to Handle the + } + } - protected PartialWorldDiff curWorld = null; + // Num Samples Drawn Thus Far + protected int totalNumSamples = 0; + protected int totalNumAccepted = 0; + protected int numSamplesThisTrial = 0; + protected int numAcceptedThisTrial = 0; } diff --git a/src/main/java/blog/sample/MHSampler.java b/src/main/java/blog/sample/MHSampler.java index bb00781a..93dae6dc 100644 --- a/src/main/java/blog/sample/MHSampler.java +++ b/src/main/java/blog/sample/MHSampler.java @@ -307,7 +307,7 @@ public double computeLogProbRatio(PartialWorld savedWorld, return logProbRatio; } - private double computeLogMultRatio(PartialWorld savedWorld, + public double computeLogMultRatio(PartialWorld savedWorld, PartialWorldDiff proposedWorld) { double logMultRatio = 0; diff --git a/src/main/java/blog/sample/ParentRecEvalContext.java b/src/main/java/blog/sample/ParentRecEvalContext.java index fc2fdd21..f4285a1e 100644 --- a/src/main/java/blog/sample/ParentRecEvalContext.java +++ b/src/main/java/blog/sample/ParentRecEvalContext.java @@ -35,7 +35,9 @@ package blog.sample; -import java.util.*; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; import blog.ObjectIdentifier; import blog.bn.BasicVar; @@ -53,80 +55,80 @@ * {@link #getOrComputeValue(BasicVar)} instead. */ public class ParentRecEvalContext extends DefaultEvalContext { - /** - * Creates a new ParentRecEvalContext using the given world. - */ - public ParentRecEvalContext(PartialWorld world) { - super(world); - } + /** + * Creates a new ParentRecEvalContext using the given world. + */ + public ParentRecEvalContext(PartialWorld world) { + super(world); + } - /** - * Creates a new ParentRecEvalContext using the given world. If the - * errorIfUndet flag is true, the access methods on this instance - * will print error messages and exit the program if the world is not complete - * enough to determine the correct return value. Otherwise they will just - * return null in such cases. - */ - public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { - super(world, errorIfUndet); - } + /** + * Creates a new ParentRecEvalContext using the given world. If the + * errorIfUndet flag is true, the access methods on this instance + * will print error messages and exit the program if the world is not complete + * enough to determine the correct return value. Otherwise they will just + * return null in such cases. + */ + public ParentRecEvalContext(PartialWorld world, boolean errorIfUndet) { + super(world, errorIfUndet); + } - final public Object getValue(BasicVar var) { - Object value = getOrComputeValue(var); - if (value == null) { - latestUninstParent = var; - var.ensureStable(); - handleMissingVar(var); - } else { - if (parents.add(var)) { - var.ensureStable(); - } - } - return value; - } + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + } + } + return value; + } - protected Object getOrComputeValue(BasicVar var) { - return world.getValue(var); - } + protected Object getOrComputeValue(BasicVar var) { + return world.getValue(var); + } - // Note that we don't have to override getSatisfiers, because the - // DefaultEvalContext implementation of getSatisfiers calls getValue - // on the number variable + // Note that we don't have to override getSatisfiers, because the + // DefaultEvalContext implementation of getSatisfiers calls getValue + // on the number variable - public NumberVar getPOPAppSatisfied(Object obj) { - if (obj instanceof NonGuaranteedObject) { - return ((NonGuaranteedObject) obj).getNumberVar(); - } + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } - if (obj instanceof ObjectIdentifier) { - parents.add(new OriginVar((ObjectIdentifier) obj)); - return world.getPOPAppSatisfied(obj); - } + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + return world.getPOPAppSatisfied(obj); + } - // Must be guaranteed object, so not generated by any number var - return null; - } + // Must be guaranteed object, so not generated by any number var + return null; + } - /** - * Returns the set of basic random variables that are instantiated and whose - * values have been used in calls to the access methods. This set is backed by - * the ParentRecEvalContext and will change as more random variables are used. - * - * @return unmodifiable Set of BasicVar - */ - public Set getParents() { - return Collections.unmodifiableSet(parents); - } + /** + * Returns the set of basic random variables that are instantiated and whose + * values have been used in calls to the access methods. This set is backed by + * the ParentRecEvalContext and will change as more random variables are used. + * + * @return unmodifiable Set of BasicVar + */ + public Set getParents() { + return Collections.unmodifiableSet(parents); + } - /** - * Returns the variable whose value was most recently needed by an access - * method, but which is not instantiated. This method returns null if no such - * variable exists. - */ - public BasicVar getLatestUninstParent() { - return latestUninstParent; - } + /** + * Returns the variable whose value was most recently needed by an access + * method, but which is not instantiated. This method returns null if no such + * variable exists. + */ + public BasicVar getLatestUninstParent() { + return latestUninstParent; + } - protected Set parents = new LinkedHashSet(); // of BasicVar - protected BasicVar latestUninstParent = null; + protected Set parents = new LinkedHashSet(); // of BasicVar + protected BasicVar latestUninstParent = null; } diff --git a/src/main/java/blog/sample/Proposer.java b/src/main/java/blog/sample/Proposer.java index 4c2366e8..4efca796 100644 --- a/src/main/java/blog/sample/Proposer.java +++ b/src/main/java/blog/sample/Proposer.java @@ -35,12 +35,12 @@ package blog.sample; -import java.util.*; +import java.util.List; +import blog.bn.BayesNetVar; import blog.model.Evidence; -import blog.world.PartialWorldDiff; import blog.world.PartialWorld; -import blog.bn.BayesNetVar; +import blog.world.PartialWorldDiff; /** * Interface for Metropolis-Hastings and Gibbs proposal distributions. @@ -48,7 +48,7 @@ * chain, and propose a new state x' given any state x. It must also be able to * compute the proposal ratio q(x | x') / q(x' | x), where q is the proposal * distribution. - * + * *

* Implementations of the Proposer interface should have a constructor with two * arguments: a blog.Model object defining the prior distribution, and a @@ -62,7 +62,7 @@ public interface Proposer { * enough to answer the given queries. Furthermore, the proposer stores the * given evidence and queries so that proposeNextState can also * maintain these properties. - * + * * @param queries * List of Query objects */ @@ -74,16 +74,16 @@ public interface Proposer { * modify to create the proposal; the saved version that underlies this * PartialWorldDiff is the state before the proposal. Returns the log proposal * ratio: log (q(x | x') / q(x' | x)) - * + * *

* The proposed world satisfies the evidence and is complete enough to answer * the queries specified in the last call to initialize. - * + * * Note that if this proposal distribution is a mixture or cycle of more * elementary proposal distributions, the proposal probabilities q(x | x') and * q(x' | x) may be specific to the elementary distribution used for this * proposal. - * + * * @throws IllegalStateException * if initialize has not been called */ @@ -115,21 +115,24 @@ public interface Proposer { * X to the proposed value i. The proposed world satisfies the evidence * and is complete enough to answer the queries specified in the last * call to initialize. - * - *

Note that this function is only used within the Gibbs Sampler. No - * need to implement the function if Gibbs Sampler is not used. + * + *

+ * Note that this function is only used within the Gibbs Sampler. No need to + * implement the function if Gibbs Sampler is not used. */ - double proposeNextState(PartialWorldDiff proposedWorld, - BayesNetVar var, int i); + double proposeNextState(PartialWorldDiff proposedWorld, BayesNetVar var, + Object value); /** - * Propose a next state for the Markov Chain which resamples the given variable + * Propose a next state for the Markov Chain which resamples the given + * variable * X The proposed world satisfies the evidence * and is complete enough to answer the queries specified in the last * call to initialize. - * - *

Note that this function is only used within the Gibbs Sampler. No - * need to implement the function if Gibbs Sampler is not used. + * + *

+ * Note that this function is only used within the Gibbs Sampler. No need to + * implement the function if Gibbs Sampler is not used. */ double proposeNextState(PartialWorldDiff proposedWorld, BayesNetVar var); diff --git a/src/main/java/blog/sample/TraceParentRecEvalContext.java b/src/main/java/blog/sample/TraceParentRecEvalContext.java new file mode 100644 index 00000000..f8344b71 --- /dev/null +++ b/src/main/java/blog/sample/TraceParentRecEvalContext.java @@ -0,0 +1,189 @@ +/** + * + */ +package blog.sample; + +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +import blog.ObjectIdentifier; +import blog.bn.BasicVar; +import blog.bn.NumberVar; +import blog.bn.OriginVar; +import blog.bn.VarWithDistrib; +import blog.common.HashMapWithPreimages; +import blog.distrib.CondProbDistrib; +import blog.model.DependencyModel; +import blog.model.NonGuaranteedObject; +import blog.world.PartialWorld; + +/** + * A Evaluation Context class that could store the trace of visited BasicVars + * (when you try to get the distributions or values for some Basic or Derived + * variables). + * This class could also store the case levels in order to get the CBN decision + * tree structure (useful for the Open Universe Gibbs Sampling). + * + * @author Da Tang + * @since August 21, 2014 + */ +public class TraceParentRecEvalContext extends ClassicInstantiatingEvalContext { + + public TraceParentRecEvalContext(PartialWorld world) { + super(world); + } + + protected TraceParentRecEvalContext( + PartialWorld world, + LinkedHashMap respVarsAndContexts) { + super(world, respVarsAndContexts); + } + + protected Object getOrComputeValue(BasicVar var) { + Object value = world.getValue(var); + if (value == null) { + numCalculateNewVars++; + if (var instanceof VarWithDistrib) { + value = instantiate((VarWithDistrib) var); + } else { + throw new IllegalArgumentException("Don't know how to instantiate: " + + var); + } + } + return value; + } + + public Object getValue(BasicVar var) { + Object value = getOrComputeValue(var); + if (value == null) { + latestUninstParent = var; + var.ensureStable(); + handleMissingVar(var); + } else { + if (parents.add(var)) { + var.ensureStable(); + parentTrace.addLast(var); + caseLevelTrace.addLast(caseLevel); + } + } + dependentVar.add(var); + return value; + } + + public NumberVar getPOPAppSatisfied(Object obj) { + if (obj instanceof NonGuaranteedObject) { + return ((NonGuaranteedObject) obj).getNumberVar(); + } + + if (obj instanceof ObjectIdentifier) { + parents.add(new OriginVar((ObjectIdentifier) obj)); + parentTrace.addLast(new OriginVar((ObjectIdentifier) obj)); + caseLevelTrace.addLast(caseLevelTrace); + return world.getPOPAppSatisfied(obj); + } + + // Must be guaranteed object, so not generated by any number var + return null; + } + + protected Object instantiate(VarWithDistrib var) { + var.ensureStable(); + + /* + * if (Util.verbose()) { System.out.println("Need to instantiate: " + var); + * } + */ + + if (respVarsAndContexts.containsKey(var)) { + cycleError(var); + } + + // Create a new "child" context and get the distribution for + // var in that context. + respVarsAndContexts.put(var, this); + TraceParentRecEvalContext spawn = new TraceParentRecEvalContext(world, + respVarsAndContexts); + spawn.afterSamplingListener = afterSamplingListener; + DependencyModel.Distrib distrib = var.getDistrib(spawn); + logProb += spawn.getLogProbability(); + respVarsAndContexts.remove(var); + List parentTrace = spawn.getParentTrace(), caseLevelTrace = spawn + .getCaseLevelTrace(); + for (int i = 0; i < parentTrace.size(); i++) { + this.parentTrace.addLast(parentTrace.get(i)); + this.caseLevelTrace.addLast((Integer) caseLevelTrace.get(i) + + this.caseLevel); + } + this.caseLevel += spawn.getCaseLevel(); + + // Sample new value for var + CondProbDistrib cpd = distrib.getCPD(); + cpd.setParams(distrib.getArgValues()); + Object newValue = cpd.sampleVal(); + double logProbForThisValue = cpd.getLogProb(newValue); + logProb += logProbForThisValue; + + // Assert any identifiers that are used by var + Object[] args = var.args(); + for (int i = 0; i < args.length; ++i) { + if (args[i] instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) args[i]); + } + } + if (newValue instanceof ObjectIdentifier) { + world.assertIdentifier((ObjectIdentifier) newValue); + } + + // Actually set value + world.setValue(var, newValue); + + if (afterSamplingListener != null) { + afterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + if (staticAfterSamplingListener != null) { + staticAfterSamplingListener.evaluate(var, newValue, logProbForThisValue); + } + + /* + * if (Util.verbose()) { System.out.println("Instantiated: " + var); } + */ + + return newValue; + } + + public List getParentTrace() { + return Collections.unmodifiableList(parentTrace); + } + + public void increaseCaseLevel() { + caseLevel++; + } + + public List getCaseLevelTrace() { + return Collections.unmodifiableList(caseLevelTrace); + } + + public int getCaseLevel() { + return caseLevel; + } + + public int getNumCalculateNewVars() { + return numCalculateNewVars; + } + + public Set getDependentVar() { + return Collections.unmodifiableSet(dependentVar); + } + + private LinkedList parentTrace = new LinkedList(); + private LinkedList caseLevelTrace = new LinkedList(); + private int caseLevel = 0; + private HashMapWithPreimages assignment; + private int numCalculateNewVars = 0; + private Set dependentVar = new HashSet(); +} diff --git a/src/main/java/blog/world/PartialWorldDiff.java b/src/main/java/blog/world/PartialWorldDiff.java index 7e4b668d..655232f6 100644 --- a/src/main/java/blog/world/PartialWorldDiff.java +++ b/src/main/java/blog/world/PartialWorldDiff.java @@ -49,11 +49,9 @@ import blog.bn.PatchCBN; import blog.common.HashMapDiff; import blog.common.HashMultiMapDiff; -import blog.common.IndexedMultiMap; import blog.common.IndexedMultiMapDiff; import blog.common.MapDiff; import blog.common.MapWithPreimagesDiff; -import blog.common.MultiMap; import blog.common.MultiMapDiff; import blog.common.Util; @@ -65,332 +63,372 @@ * PartialWorldDiff unless they are removed with removeIdentifier. */ public class PartialWorldDiff extends AbstractPartialWorld { - /** - * Creates a new PartialWorldDiff with the given underlying world. This world - * uses object identifiers for the same types as the underlying world does. - */ - public PartialWorldDiff(PartialWorld underlying) { - super(underlying.getIdTypes()); - basicVarToValue = new HashMapDiff(underlying.basicVarToValueMap()); + /** + * Creates a new PartialWorldDiff with the given underlying world. This world + * uses object identifiers for the same types as the underlying world does. + */ + public PartialWorldDiff(PartialWorld underlying) { + super(underlying.getIdTypes()); + basicVarToValue = new HashMapDiff(underlying.basicVarToValueMap()); nameToBasicVar = new HashMapDiff(underlying.nameToBasicVarMap()); - objToUsesAsValue = new HashMultiMapDiff(underlying.objToUsesAsValueMap()); - objToUsesAsArg = new HashMultiMapDiff(underlying.objToUsesAsArgMap()); - assertedIdToPOPApp = new HashMapDiff(underlying.assertedIdToPOPAppMap()); - popAppToAssertedIds = new IndexedMultiMapDiff( - underlying.popAppToAssertedIdsMap()); - commIdToPOPApp = new HashMapDiff(underlying.assertedIdToPOPAppMap()); - popAppToCommIds = new IndexedMultiMapDiff( - underlying.popAppToAssertedIdsMap()); - cbn = new PatchCBN(underlying.getCBN()); - varToUninstParent = new MapWithPreimagesDiff( - underlying.varToUninstParentMap()); - varToLogProb = new HashMapDiff(underlying.varToLogProbMap()); - derivedVarToValue = new HashMapDiff(underlying.derivedVarToValueMap()); - - savedWorld = underlying; - } - - /** - * Creates a new PartialWorldDiff whose underlying world is - * underlying, and whose current version is set equal to - * toCopy. - */ - public PartialWorldDiff(PartialWorld underlying, PartialWorld toCopy) { - this(underlying); - - for (Iterator iter = toCopy.getAssertedIdentifiers().iterator(); iter - .hasNext();) { - ObjectIdentifier id = (ObjectIdentifier) iter.next(); - assertIdentifier(id, toCopy.getPOPAppSatisfied(id)); - } - - for (Iterator iter = toCopy.getInstantiatedVars().iterator(); iter - .hasNext();) { - BasicVar var = (BasicVar) iter.next(); - setValue(var, toCopy.getValue(var)); - } - - for (Iterator iter = toCopy.getDerivedVars().iterator(); iter.hasNext();) { - addDerivedVar((DerivedVar) iter.next()); - } - } - - /** - * Returns the saved version of this world. The returned PartialWorld object - * is updated as new versions are saved. - */ - public PartialWorld getSaved() { - return savedWorld; - } - - /** - * Changes the saved version of this world to equal the current version. - */ - public void save() { - for (Iterator iter = getIdsWithChangedPOPApps().iterator(); iter.hasNext();) { - ObjectIdentifier id = (ObjectIdentifier) iter.next(); - NumberVar newPOPApp = (NumberVar) assertedIdToPOPApp.get(id); - if (newPOPApp == null) { - savedWorld.removeIdentifier(id); - } else { - savedWorld.assertIdentifier(id, newPOPApp); - } - } - - for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { - BasicVar var = (BasicVar) iter.next(); - savedWorld.setValue(var, getValue(var)); - } - - Set derivedVars = ((MapDiff) derivedVarToValue).getChangedKeys(); - for (Iterator iter = derivedVars.iterator(); iter.hasNext();) { - DerivedVar var = (DerivedVar) iter.next(); - if (derivedVarToValue.containsKey(var)) { // not removed - savedWorld.addDerivedVar(var); // no effect if already there - } - } - - updateParentsAndProbs(); - savedWorld.updateCBN(cbn, varToUninstParent, varToLogProb, - derivedVarToValue); - - clearChanges(); // since underlying is now updated - - for (Iterator iter = diffListeners.iterator(); iter.hasNext();) { - WorldDiffListener listener = (WorldDiffListener) iter.next(); - listener.notifySaved(); - } - } - - /** - * Changes this world to equal the saved version. Warning: WorldListener - * objects will not be notified of changes to the values of basic variables - * made by this method. - */ - public void revert() { - clearChanges(); - clearCommIdChanges(); - - for (Iterator iter = diffListeners.iterator(); iter.hasNext();) { - WorldDiffListener listener = (WorldDiffListener) iter.next(); - listener.notifyReverted(); - } - } - - /** - * Returns the set of variables that have different values in the current - * world than they do in the saved world. This includes variables that are - * instantiated in this world and not the saved world, or vice versa. - * - * @return unmodifiable Set of BasicVar - */ - public Set getChangedVars() { - return ((MapDiff) basicVarToValue).getChangedKeys(); - } - - /** - * Returns the set of objects that serve as values for a different set of - * basic RVs in this world than they do in the saved world. This may include - * objects that exist in this world but not the saved world, or vice versa. - */ - public Set getObjsWithChangedUsesAsValue() { - return ((MultiMapDiff) objToUsesAsValue).getChangedKeys(); - } - - /** - * Returns the set of object identifiers that are asserted in either this - * world or the saved world, and that satisfy a different POP application in - * this world than in the saved world. This includes object identifiers that - * are asserted in this world and not the saved world, or vice versa. - */ - public Set getIdsWithChangedPOPApps() { - return ((MapDiff) assertedIdToPOPApp).getChangedKeys(); - } - - /** - * Returns the set of POP applications whose set of asserted identifiers is - * different in this world from in the saved world. - */ - public Set getPOPAppsWithChangedIds() { - return ((MultiMapDiff) popAppToAssertedIds).getChangedKeys(); - } - - /** - * Returns the Set of BayesNetVar objects V such that the probability P(V | - * parents(V)) is not the same in this world as in the saved world. This may - * be because the value of V has changed or because the values of some of V's - * parents have changed. The returned set also includes any DerivedVars whose - * value has changed. - * - * @return unmodifiable Set of BayesNetVar - */ - public Set getVarsWithChangedProbs() { - updateParentsAndProbs(); - - HashSet results = new HashSet(); - results.addAll(((MapDiff) varToLogProb).getChangedKeys()); - results.addAll(((MapDiff) derivedVarToValue).getChangedKeys()); - return results; - } - - /** - * Returns the set of variables that are barren in this world but either are - * not in the graph or are not barren in the saved world. A barren variable is - * one with no children. - * - * @return unmodifiable Set of BayesNetVar - */ - public Set getNewlyBarrenVars() { - updateParentsAndProbs(); - return ((PatchCBN) cbn).getNewlyBarrenNodes(); - } - - /** - * Returns the set of identifiers that are floating in this world and not the - * saved world. An identifier is floating if it is used as an argument of some - * basic variable, but is not the value of any basic variable. - * - * @return unmodifiable Set of ObjectIdentifier - */ - public Set getNewlyFloatingIds() { - Set newlyFloating = new HashSet(); - - // Scan changed and newly instantiated vars, looking for new - // arguments and old values that are now floating. - for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { - BasicVar var = (BasicVar) iter.next(); - if (getValue(var) != null) { - for (int i = 0; i < var.args().length; ++i) { - Object arg = var.args()[i]; - if ((arg instanceof ObjectIdentifier) - && getVarsWithValue(arg).isEmpty() - && (getSaved().getVarsWithArg(arg).isEmpty() || !getSaved() - .getVarsWithValue(arg).isEmpty())) { - newlyFloating.add(arg); - } - } - } - - if (getSaved().getValue(var) != null) { - Object oldValue = getSaved().getValue(var); - if ((oldValue instanceof ObjectIdentifier) - && (assertedIdToPOPApp.get(oldValue) != null) - && getVarsWithValue(oldValue).isEmpty() - && !getVarsWithArg(oldValue).isEmpty()) { - newlyFloating.add(oldValue); - } - } - } - - return Collections.unmodifiableSet(newlyFloating); - } - - /** - * Returns the set of number variables that are overloaded in this world but - * not the saved world. A number variable is overloaded if the number of - * identifiers asserted to satisfy it is greater than its value, or it is not - * instantiated and one or more identifiers are still asserted to satisfy it. - * - * @return unmodifiable Set of NumberVar - */ - public Set getNewlyOverloadedNumberVars() { - Set newlyOverloaded = new HashSet(); - - // Iterate over variables whose values changed, look for number vars - for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { - BasicVar var = (BasicVar) iter.next(); - if (var instanceof NumberVar) { - NumberVar nv = (NumberVar) var; - if (isOverloaded(nv) && !getSaved().isOverloaded(nv)) { - newlyOverloaded.add(nv); + objToUsesAsValue = new HashMultiMapDiff(underlying.objToUsesAsValueMap()); + objToUsesAsArg = new HashMultiMapDiff(underlying.objToUsesAsArgMap()); + assertedIdToPOPApp = new HashMapDiff(underlying.assertedIdToPOPAppMap()); + popAppToAssertedIds = new IndexedMultiMapDiff( + underlying.popAppToAssertedIdsMap()); + commIdToPOPApp = new HashMapDiff(underlying.assertedIdToPOPAppMap()); + popAppToCommIds = new IndexedMultiMapDiff( + underlying.popAppToAssertedIdsMap()); + cbn = new PatchCBN(underlying.getCBN()); + varToUninstParent = new MapWithPreimagesDiff( + underlying.varToUninstParentMap()); + varToLogProb = new HashMapDiff(underlying.varToLogProbMap()); + derivedVarToValue = new HashMapDiff(underlying.derivedVarToValueMap()); + + savedWorld = underlying; + } + + /** + * Creates a new PartialWorldDiff whose underlying world is + * underlying, and whose current version is set equal to + * toCopy. + */ + public PartialWorldDiff(PartialWorld underlying, PartialWorld toCopy) { + this(underlying); + + for (Iterator iter = toCopy.getAssertedIdentifiers().iterator(); iter + .hasNext();) { + ObjectIdentifier id = (ObjectIdentifier) iter.next(); + assertIdentifier(id, toCopy.getPOPAppSatisfied(id)); + } + + Set toRemoveAssertedIndentifiers = new HashSet(); + toRemoveAssertedIndentifiers.addAll(underlying.getAssertedIdentifiers()); + toRemoveAssertedIndentifiers.removeAll(toCopy.getAssertedIdentifiers()); + for (Iterator iter = toRemoveAssertedIndentifiers.iterator(); iter + .hasNext();) { + ObjectIdentifier id = (ObjectIdentifier) iter.next(); + removeIdentifier(id); + } + + for (Iterator iter = toCopy.getInstantiatedVars().iterator(); iter + .hasNext();) { + BasicVar var = (BasicVar) iter.next(); + setValue(var, toCopy.getValue(var)); + } + + Set toRemoveBasicVars = new HashSet(); + toRemoveBasicVars.addAll(underlying.getInstantiatedVars()); + toRemoveBasicVars.removeAll(toCopy.getInstantiatedVars()); + for (Iterator iter = toRemoveBasicVars.iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + setValue(var, null); + } + + for (Iterator iter = toCopy.getDerivedVars().iterator(); iter.hasNext();) { + addDerivedVar((DerivedVar) iter.next()); + } + + Set toRemoveDerivedVars = new HashSet(); + toRemoveDerivedVars.addAll(underlying.getDerivedVars()); + toRemoveDerivedVars.removeAll(toCopy.getDerivedVars()); + for (Iterator iter = toRemoveDerivedVars.iterator(); iter.hasNext();) { + removeDerivedVar((DerivedVar) iter.next()); + } + + updateParentsAndProbs(); + updateCBN(cbn, varToUninstParent, varToLogProb, derivedVarToValue); + } + + /** + * Returns the saved version of this world. The returned PartialWorld object + * is updated as new versions are saved. + */ + public PartialWorld getSaved() { + return savedWorld; + } + + /** + * Changes the saved version of this world to equal the current version. + */ + public void save() { + for (Iterator iter = getIdsWithChangedPOPApps().iterator(); iter.hasNext();) { + ObjectIdentifier id = (ObjectIdentifier) iter.next(); + NumberVar newPOPApp = (NumberVar) assertedIdToPOPApp.get(id); + if (newPOPApp == null) { + savedWorld.removeIdentifier(id); + } else { + savedWorld.assertIdentifier(id, newPOPApp); + } + } + + for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + savedWorld.setValue(var, getValue(var)); + } + + Set derivedVars = ((MapDiff) derivedVarToValue).getChangedKeys(); + for (Iterator iter = derivedVars.iterator(); iter.hasNext();) { + DerivedVar var = (DerivedVar) iter.next(); + if (derivedVarToValue.containsKey(var)) { // not removed + savedWorld.addDerivedVar(var); // no effect if already there + } + } + + updateParentsAndProbs(); + savedWorld.updateCBN(cbn, varToUninstParent, varToLogProb, + derivedVarToValue); + + changeUnderlying(); + + for (Iterator iter = diffListeners.iterator(); iter.hasNext();) { + WorldDiffListener listener = (WorldDiffListener) iter.next(); + listener.notifySaved(); + } + } + + /** + * Changes this world to equal the saved version. Warning: WorldListener + * objects will not be notified of changes to the values of basic variables + * made by this method. + */ + public void revert() { + clearChanges(); + clearCommIdChanges(); + + for (Iterator iter = diffListeners.iterator(); iter.hasNext();) { + WorldDiffListener listener = (WorldDiffListener) iter.next(); + listener.notifyReverted(); + } + } + + /** + * Returns the set of variables that have different values in the current + * world than they do in the saved world. This includes variables that are + * instantiated in this world and not the saved world, or vice versa. + * + * @return unmodifiable Set of BasicVar + */ + public Set getChangedVars() { + return ((MapDiff) basicVarToValue).getChangedKeys(); + } + + /** + * Returns the set of objects that serve as values for a different set of + * basic RVs in this world than they do in the saved world. This may include + * objects that exist in this world but not the saved world, or vice versa. + */ + public Set getObjsWithChangedUsesAsValue() { + return ((MultiMapDiff) objToUsesAsValue).getChangedKeys(); + } + + /** + * Returns the set of object identifiers that are asserted in either this + * world or the saved world, and that satisfy a different POP application in + * this world than in the saved world. This includes object identifiers that + * are asserted in this world and not the saved world, or vice versa. + */ + public Set getIdsWithChangedPOPApps() { + return ((MapDiff) assertedIdToPOPApp).getChangedKeys(); + } + + /** + * Returns the set of POP applications whose set of asserted identifiers is + * different in this world from in the saved world. + */ + public Set getPOPAppsWithChangedIds() { + return ((MultiMapDiff) popAppToAssertedIds).getChangedKeys(); + } + + /** + * Returns the Set of BayesNetVar objects V such that the probability P(V | + * parents(V)) is not the same in this world as in the saved world. This may + * be because the value of V has changed or because the values of some of V's + * parents have changed. The returned set also includes any DerivedVars whose + * value has changed. + * + * @return unmodifiable Set of BayesNetVar + */ + public Set getVarsWithChangedProbs() { + updateParentsAndProbs(); + + HashSet results = new HashSet(); + results.addAll(((MapDiff) varToLogProb).getChangedKeys()); + results.addAll(((MapDiff) derivedVarToValue).getChangedKeys()); + return results; + } + + /** + * Returns the set of variables that are barren in this world but either are + * not in the graph or are not barren in the saved world. A barren variable is + * one with no children. + * + * @return unmodifiable Set of BayesNetVar + */ + public Set getNewlyBarrenVars() { + updateParentsAndProbs(); + return ((PatchCBN) cbn).getNewlyBarrenNodes(); + } + + /** + * Returns the set of identifiers that are floating in this world and not the + * saved world. An identifier is floating if it is used as an argument of some + * basic variable, but is not the value of any basic variable. + * + * @return unmodifiable Set of ObjectIdentifier + */ + public Set getNewlyFloatingIds() { + Set newlyFloating = new HashSet(); + + // Scan changed and newly instantiated vars, looking for new + // arguments and old values that are now floating. + for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (getValue(var) != null) { + for (int i = 0; i < var.args().length; ++i) { + Object arg = var.args()[i]; + if ((arg instanceof ObjectIdentifier) + && getVarsWithValue(arg).isEmpty() + && (getSaved().getVarsWithArg(arg).isEmpty() || !getSaved() + .getVarsWithValue(arg).isEmpty())) { + newlyFloating.add(arg); + } + } + } + + if (getSaved().getValue(var) != null) { + Object oldValue = getSaved().getValue(var); + if ((oldValue instanceof ObjectIdentifier) + && (assertedIdToPOPApp.get(oldValue) != null) + && getVarsWithValue(oldValue).isEmpty() + && !getVarsWithArg(oldValue).isEmpty()) { + newlyFloating.add(oldValue); + } + } + } + + return Collections.unmodifiableSet(newlyFloating); + } + + /** + * Returns the set of number variables that are overloaded in this world but + * not the saved world. A number variable is overloaded if the number of + * identifiers asserted to satisfy it is greater than its value, or it is not + * instantiated and one or more identifiers are still asserted to satisfy it. + * + * @return unmodifiable Set of NumberVar + */ + public Set getNewlyOverloadedNumberVars() { + Set newlyOverloaded = new HashSet(); + + // Iterate over variables whose values changed, look for number vars + for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (var instanceof NumberVar) { + NumberVar nv = (NumberVar) var; + if (isOverloaded(nv) && !getSaved().isOverloaded(nv)) { + newlyOverloaded.add(nv); Util.debug("Number var ", nv, " with value ", getValue(nv), " is overloaded by ", popAppToAssertedIds.get(nv)); - } - } - } - - // Iterate over number variables with a changed set of asserted IDs - for (Iterator iter = getPOPAppsWithChangedIds().iterator(); iter.hasNext();) { - NumberVar nv = (NumberVar) iter.next(); - if (isOverloaded(nv) && !getSaved().isOverloaded(nv)) { - newlyOverloaded.add(nv); - } - } - - return newlyOverloaded; - } - - /** - * Returns the set of number variables that yield different probability - * multipliers in this world than they do in the saved world. These are the - * number variables that have different values or different numbers of - * asserted identifiers in this world and the saved world, and have at least - * one asserted identifier in this world or the saved world. - * - * @return unmodifiable Set of NumberVar - */ - public Set getVarsWithChangedMultipliers() { - Set changedMultipliers = new HashSet(); - - // Iterate over variables whose values changed, look for number vars - for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { - BasicVar var = (BasicVar) iter.next(); - if (var instanceof NumberVar) { - NumberVar nv = (NumberVar) var; - if ((getAssertedIdsForPOPApp(nv).size() > 0) - || (getSaved().getAssertedIdsForPOPApp(nv).size() > 0)) { - changedMultipliers.add(nv); - } - } - } - - // Iterate over number variables whose set of asserted identifiers - // changed - for (Iterator iter = getPOPAppsWithChangedIds().iterator(); iter.hasNext();) { - changedMultipliers.add(iter.next()); - } - - return changedMultipliers; - } - - /** - * Adds the given object to the list of listeners that will be notified when - * this PartialWorldDiff is saved or reverted. - */ - public void addDiffListener(WorldDiffListener listener) { - if (!diffListeners.contains(listener)) { - diffListeners.add(listener); - } - } - - /** - * Removes the given object from the list of listeners that will be notified - * when this PartialWorldDiff is saved or reverted. - */ - public void removeDiffListener(WorldDiffListener listener) { - diffListeners.remove(listener); - } - - private void clearChanges() { - ((MapDiff) basicVarToValue).clearChanges(); - ((MultiMapDiff) objToUsesAsValue).clearChanges(); - ((MultiMapDiff) objToUsesAsArg).clearChanges(); - ((MapDiff) assertedIdToPOPApp).clearChanges(); - ((MultiMapDiff) popAppToAssertedIds).clearChanges(); - ((PatchCBN) cbn).clearChanges(); - ((MapDiff) varToLogProb).clearChanges(); - ((MapDiff) derivedVarToValue).clearChanges(); - - dirtyVars.clear(); - } - - private void clearCommIdChanges() { - ((MapDiff) commIdToPOPApp).clearChanges(); - ((MultiMapDiff) popAppToCommIds).clearChanges(); - } - - private PartialWorld savedWorld; -// private PatchCBN cbn; - - private List diffListeners = new ArrayList(); // of WorldDiffListener + } + } + } + + // Iterate over number variables with a changed set of asserted IDs + for (Iterator iter = getPOPAppsWithChangedIds().iterator(); iter.hasNext();) { + NumberVar nv = (NumberVar) iter.next(); + if (isOverloaded(nv) && !getSaved().isOverloaded(nv)) { + newlyOverloaded.add(nv); + } + } + + return newlyOverloaded; + } + + /** + * Returns the set of number variables that yield different probability + * multipliers in this world than they do in the saved world. These are the + * number variables that have different values or different numbers of + * asserted identifiers in this world and the saved world, and have at least + * one asserted identifier in this world or the saved world. + * + * @return unmodifiable Set of NumberVar + */ + public Set getVarsWithChangedMultipliers() { + Set changedMultipliers = new HashSet(); + + // Iterate over variables whose values changed, look for number vars + for (Iterator iter = getChangedVars().iterator(); iter.hasNext();) { + BasicVar var = (BasicVar) iter.next(); + if (var instanceof NumberVar) { + NumberVar nv = (NumberVar) var; + if ((getAssertedIdsForPOPApp(nv).size() > 0) + || (getSaved().getAssertedIdsForPOPApp(nv).size() > 0)) { + changedMultipliers.add(nv); + } + } + } + + // Iterate over number variables whose set of asserted identifiers + // changed + for (Iterator iter = getPOPAppsWithChangedIds().iterator(); iter.hasNext();) { + changedMultipliers.add(iter.next()); + } + + return changedMultipliers; + } + + /** + * Adds the given object to the list of listeners that will be notified when + * this PartialWorldDiff is saved or reverted. + */ + public void addDiffListener(WorldDiffListener listener) { + if (!diffListeners.contains(listener)) { + diffListeners.add(listener); + } + } + + /** + * Removes the given object from the list of listeners that will be notified + * when this PartialWorldDiff is saved or reverted. + */ + public void removeDiffListener(WorldDiffListener listener) { + diffListeners.remove(listener); + } + + private void changeUnderlying() { + ((MapDiff) basicVarToValue).changeUnderlying(); + ((MultiMapDiff) objToUsesAsValue).changeUnderlying(); + ((MultiMapDiff) objToUsesAsArg).changeUnderlying(); + ((MapDiff) assertedIdToPOPApp).changeUnderlying(); + ((MultiMapDiff) popAppToAssertedIds).changeUnderlying(); + ((PatchCBN) cbn).changeUnderlying(); + ((MapDiff) varToLogProb).changeUnderlying(); + ((MapDiff) derivedVarToValue).changeUnderlying(); + + dirtyVars.clear(); + } + + private void clearChanges() { + ((MapDiff) basicVarToValue).clearChanges(); + ((MultiMapDiff) objToUsesAsValue).clearChanges(); + ((MultiMapDiff) objToUsesAsArg).clearChanges(); + ((MapDiff) assertedIdToPOPApp).clearChanges(); + ((MultiMapDiff) popAppToAssertedIds).clearChanges(); + ((PatchCBN) cbn).clearChanges(); + ((MapDiff) varToLogProb).clearChanges(); + ((MapDiff) derivedVarToValue).clearChanges(); + + dirtyVars.clear(); + } + + private void clearCommIdChanges() { + ((MapDiff) commIdToPOPApp).clearChanges(); + ((MultiMapDiff) popAppToCommIds).clearChanges(); + } + + private PartialWorld savedWorld; + // private PatchCBN cbn; + + private List diffListeners = new ArrayList(); // of WorldDiffListener }