Skip to content

Commit

Permalink
Added elliptical slice sampling
Browse files Browse the repository at this point in the history
HMC accepted step is green rather than blue
shermanlo77 committed Apr 7, 2021

Verified

This commit was signed with the committer’s verified signature.
emma-sg Emma Segal-Grossman
1 parent d4d7ade commit 5239397
Showing 10 changed files with 362 additions and 18 deletions.
28 changes: 17 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# oxwasp_exchange_mcmc

*Java* implementations of the **Metropolis-Hastings** (Metropolis et al., 1953) (Hastings, 1970), **Adaptive Metropolis-Hastings** (Haario et al., 2001) (Roberts and Rosenthal, 2009), **Hamiltonian Monte Carlo** (Neal, 2011) and **No U-Turn Sampler** (Hoffman and Gelman, 2014) including the dual averaging version.
*Java* implementations of the **Metropolis-Hastings** (Metropolis et al., 1953) (Hastings, 1970), **Adaptive Metropolis-Hastings** (Haario et al., 2001) (Roberts and Rosenthal, 2009), **Hamiltonian Monte Carlo** (Neal, 2011), **No U-Turn Sampler** (Hoffman and Gelman, 2014) including the dual averaging version and **Elliptical Slice Sampler** (Murray et al., 2010).

Also included are *Processing* implementations of the algorithms for visualising these algorithms sampling a 2D Normal distribution.

@@ -11,18 +11,10 @@ For dependencies, see the `pom.xml` file or:
* `org.apache.commons.math3` [Commons Math: The Apache Commons Mathematics Library](http://commons.apache.org/proper/commons-math/)
* `org.ejml` [Efficient Java Matrix Library](http://ejml.org/wiki/index.php?title=Main_Page)

## References
* Haario, H., Saksman, E., Tamminen, J., et al. (2001). An adaptive Metropolis algorithm. _Bernoulli_, 7(2):223-242.
* Hastings, W. K. (1970). Monte Carlo sampling methods using Markov chains and their applications. _Biometrika_ 57(1):97-109.
* Hoffman, M. D. and Gelman, A. (2014). The No-U-turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. _Journal of Machine Learning Research_, 15(1):1593-1623.
* Metropolis, N., Rosenbluth, A. W., Rosenbluth, M. N., Teller, A. H., and Teller, E. (1953). Equation of state calculations by fast computing machines. _The Journal of Chemical Physics_, 21(6):1087–1092.
* Neal, R. M. (2011). MCMC using Hamiltonian dynamics. In Brooks, S., Gelman, A., Jones, G., and Meng, X.-L., editors, _Handbook of Markov Chain Monte Carlo_, chapter 5, pages 113–162. CRC press.
* Roberts, G. O. and Rosenthal, J. S. (2009). Examples of adaptive MCMC. _Journal of Computational and Graphical Statistics_, 18(2):349–367.

## How to use (Linux recommended)
Call the `.jar` file using
```
java -jar oxwasp_exchange_mcmc-0.0.1-jar-with-dependencies.jar -option
java -jar oxwasp_exchange_mcmc-1.0.0-jar-with-dependencies.jar -option
```
where `-option` can be one of the following options below. Click on `pause` to pause the simulation, once paused, click on `step` to run the simulation one step at a time. Click on `quit` to quit the simulation.

@@ -48,6 +40,11 @@ Click on the screen to start a No U-Turn Sampler. The yellow dots show the leap

![alt text](tex/processing_nuts.png "No U-Turn Sampler")

### `-slice` Elliptical Slice Sampler
Click on the screen to start an elliptical slice sampler. The next step of the chain is searched for on the ellipse. The yellow sections show acceptable regions of the ellipse. The red sections show unacceptable regions of the ellipse. Red dots shows rejected samples.

![alt text](tex/processing_slice.png "Elliptical Slice Sampler")

## How to compile (Linux recommended)
*Maven* required.

@@ -62,4 +59,13 @@ Go to the repository and run
```
mvn package
```
and the `.jar` files are located in `/target/`.
and the `.jar` files are located in `target/`.

## References
* Haario, H., Saksman, E., Tamminen, J., et al. (2001). An adaptive Metropolis algorithm. _Bernoulli_, 7(2):223-242.
* Hastings, W. K. (1970). Monte Carlo sampling methods using Markov chains and their applications. _Biometrika_ 57(1):97-109.
* Hoffman, M. D. and Gelman, A. (2014). The No-U-turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. _Journal of Machine Learning Research_, 15(1):1593-1623.
* Metropolis, N., Rosenbluth, A. W., Rosenbluth, M. N., Teller, A. H., and Teller, E. (1953). Equation of state calculations by fast computing machines. _The Journal of Chemical Physics_, 21(6):1087–1092.
* Murray, I., Adams, R. P., and MacKay, D. J. (2010). Elliptical slice sampling. _In Proceedings of the 13th International Conference on Artificial Intelligence and Statistics_.
* Neal, R. M. (2011). MCMC using Hamiltonian dynamics. In Brooks, S., Gelman, A., Jones, G., and Meng, X.-L., editors, _Handbook of Markov Chain Monte Carlo_, chapter 5, pages 113–162. CRC press.
* Roberts, G. O. and Rosenthal, J. S. (2009). Examples of adaptive MCMC. _Journal of Computational and Graphical Statistics_, 18(2):349–367.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>uk.ac.warwick.sip</groupId>
<artifactId>oxwasp_exchange_mcmc</artifactId>
<version>0.0.1</version>
<version>1.0.0</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
172 changes: 172 additions & 0 deletions src/uk/ac/warwick/sip/mcmc/EllipticalSlice.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright 2018-2021 Sherman Lo
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package uk.ac.warwick.sip.mcmc;

import java.util.ArrayList;
import java.util.Iterator;

import org.apache.commons.math3.random.MersenneTwister;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.simple.SimpleMatrix;

/**CLASS: ELLIPTICAL SLICE SAMPLER
* Sampler which uses elliptical slice sampling
* Reference: Elliptical slice sampling (Murray, I., Adams, R. P., and MacKay, D. J. (2010).
* Elliptical slice sampling. In Proceedings of the 13th International Conference on Artificial
* Intelligence and Statistics.)
* At a step, a sample from the prior is taken. A weighted average between the current step and that
* sample is used as the next step in the MCMC according to the slice sampling scheme.
*/
public class EllipticalSlice extends Mcmc{

//the likelihood of the model, elliptical slice sampling works on the likelihood rather than the
//posterior
protected TargetDistribution likelihood;
//the prior distribution
protected NormalDistribution prior;
//position vector of the start of the ellipse (the current step)
protected SimpleMatrix ellipse0;
//position vector of the end of the ellipse (the sample from the prior)
protected SimpleMatrix proposal;
//after a step, an array of points looked at on the ellipse
protected ArrayList<SimpleMatrix> ellipticalPositions;
//slice variable, log uniform number plus log likelihood
protected double sliceVariable;

/**CONSTRUCTOR
* Elliptical slice sampling algorithm. Samples from the prior (called the proposal in this code)
* and then uses slice sampling to find a weighted combination of the proposed and the current
* value, which turns out to make an ellipse.
* @param target Object which has a method to call the pdf
* @param likelihood the likelihood of the model
* @param prior NormalDistribution representing the prior
* @param chainLength Length of the chain to be obtained
* @param rng Random number generator all the random numbers
*/
public EllipticalSlice(TargetDistribution target, TargetDistribution likelihood,
NormalDistribution prior, int chainLength, MersenneTwister rng) {
super(target, chainLength, rng);
this.likelihood = likelihood;
this.prior = prior;
this.ellipticalPositions = new ArrayList<SimpleMatrix>();
}

/**CONSTRUCTOR
* Constructor for extending the length of the chain and resume running it
* Does a shallow copy of the provided chain and extending the member variable chainArray
* @param chain Chain to be extended
* @param nMoreSteps Number of steps to be extended
*/
public EllipticalSlice(EllipticalSlice chain, int nMoreSteps) {
//shallow copies of member variables
super(chain, nMoreSteps);
this.likelihood = chain.likelihood;
this.prior = chain.prior;
this.proposal = chain.proposal;
this.ellipticalPositions = chain.ellipticalPositions;
this.sliceVariable = chain.sliceVariable;
}

@Override
public void step(SimpleMatrix position) {
//elliptical slice sampling is a weighted sum of position and a sample from the prior
this.ellipse0 = new SimpleMatrix(position); //start of ellipse
this.proposal = this.prior.sample(this.rng); //end of ellipse
//array to store all points looked at on the ellipse
this.ellipticalPositions = new ArrayList<SimpleMatrix>();

double angle = this.sampleAngle(0, 2*Math.PI);
double angleMin = angle - 2*Math.PI;
double angleMax = angle;

//sliceVariable is compared with the likelihood at points on the ellipse
this.sliceVariable = Math.log(this.rng.nextDouble()) - this.likelihood.getPotential(position);
boolean gotSample = false;

//look through the ellipse until got valid sample
while (!gotSample) {
position.set(this.getPointOnEllipse(angle));
if (this.isValidPointOnEllipse(position)) {
gotSample = true;
} else {
//else this is not a valid point, look elsewhere on the ellipse
this.ellipticalPositions.add(new SimpleMatrix(position));
if (angle < 0) {
angleMin = angle;
} else {
angleMax = angle;
}
angle = this.sampleAngle(angleMin, angleMax);
}
}
this.updateStatistics(position);
}

/**METHOD: IS VALID POINT ON ELLIPSE
* Check if this position vector would be accepted under the slice sampling scheme. This would
* require the member variable sliceVariable initalised with the correct value.
* @param position position vector (on the ellipse)
* @return boolean if the position vector is valid under slice sampling
*/
public boolean isValidPointOnEllipse(SimpleMatrix position) {
return -this.likelihood.getPotential(position) > this.sliceVariable;
}

/**METHOD: GET POINT ON ELLIPSE
* Return a weighted sum of initial and proposal, the weight using cosine and sine of angle
* @param initial vector of current position of chain
* @param proposal vector of sample from prior
* @param angle of the ellipse
* @return vector, weighted sum of initial and proposal
*/
public SimpleMatrix getPointOnEllipse(double angle) {
SimpleMatrix initialWeighted = new SimpleMatrix(this.ellipse0);
SimpleMatrix proposalWeighted = new SimpleMatrix(this.proposal);

//centre the position vectors at the prior mean
CommonOps_DDRM.subtractEquals(initialWeighted.getDDRM(), this.prior.mean.getDDRM());
CommonOps_DDRM.subtractEquals(proposalWeighted.getDDRM(), this.prior.mean.getDDRM());

CommonOps_DDRM.scale(Math.cos(angle), initialWeighted.getDDRM());
CommonOps_DDRM.scale(Math.sin(angle), proposalWeighted.getDDRM());

CommonOps_DDRM.addEquals(proposalWeighted.getDDRM(), initialWeighted.getDDRM());
CommonOps_DDRM.addEquals(proposalWeighted.getDDRM(), this.prior.mean.getDDRM());

return proposalWeighted;
}

/**METHOD: GET ELLIPTICAL POSITIONS ITERATOR
* Return an iterator which iterates through all the points looked at on the ellipse
*/
public Iterator<SimpleMatrix> getEllipticalPositionsIterator(){
return this.ellipticalPositions.iterator();
}

/**METHOD: SAMPLE ANGLE`
* Sample from the uniform distribution
* @param min minimun value
* @param max maximum value
*/
protected double sampleAngle(double min, double max) {
double angle = this.rng.nextDouble();
angle *= max - min;
angle += min;
return angle;
}

}
2 changes: 2 additions & 0 deletions src/uk/ac/warwick/sip/mcmc/Global.java
Original file line number Diff line number Diff line change
@@ -48,6 +48,8 @@ public static void main(String[] args) {
uk.ac.warwick.sip.mcmcprocessing.NoUTurnSampler.main(args);
} else if (userArg.equals("-rwmh")) {
uk.ac.warwick.sip.mcmcprocessing.RandomWalkMetropolisHastings.main(args);
} else if (userArg.equals("-slice")) {
uk.ac.warwick.sip.mcmcprocessing.EllipticalSlice.main(args);
} else if (userArg.equals("-example")) {
example();
} else if (userArg.equals("-test")) {
13 changes: 13 additions & 0 deletions src/uk/ac/warwick/sip/mcmc/NormalDistribution.java
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@

import java.lang.Math;

import org.apache.commons.math3.random.MersenneTwister;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.decomposition.TriangularSolver_DDRM;
import org.ejml.simple.SimpleMatrix;
@@ -98,7 +99,19 @@ public SimpleMatrix getDPotential(SimpleMatrix x) {
SimpleMatrix covariance = new SimpleMatrix(this.nDim, this.nDim);
CommonOps_DDRM.multInner(covarianceCholInverse.getDDRM(), covariance.getDDRM());
return covariance.mult(x.minus(this.mean));
}

/**METHOD: SAMPLE
* Sample a Gaussian distribution with a mean and covariance
* @return vector, sample from Gaussian distribution
*/
public SimpleMatrix sample(MersenneTwister rng) {
double [] xArray = new double[this.nDim];
for (int i=0; i<this.nDim; i++){
xArray[i] = rng.nextGaussian();
}
SimpleMatrix x = new SimpleMatrix(this.nDim, 1, true, xArray);
return this.covarianceChol.mult(x).plus(this.mean);
}

}
151 changes: 151 additions & 0 deletions src/uk/ac/warwick/sip/mcmcprocessing/EllipticalSlice.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package uk.ac.warwick.sip.mcmcprocessing;

import java.util.Iterator;

import org.apache.commons.math3.random.MersenneTwister;
import org.ejml.simple.SimpleMatrix;

import processing.core.PApplet;
import uk.ac.warwick.sip.mcmc.NormalDistribution;
import uk.ac.warwick.sip.mcmc.TargetDistribution;

/**CLASS: ELLIPTICAL SLICE SAMPLING
* Simulation for elliptical slice sampling, click to add a chain
* Draws ellipse search space, draw red dots for rejected samples
*/
public class EllipticalSlice extends McmcApplet {

static public final int N_PLOT = 1000;

//chain to simulate (hides the super class version)
uk.ac.warwick.sip.mcmc.EllipticalSlice chain;
//target distribution
protected TargetDistribution target;
//Normal prior distribution,
protected NormalDistribution prior;
//likelihood function
protected TargetDistribution likelihood;

/**OVERRIDE: SETUP
* Set target distribution, prior and likelihood
*/
@Override
public void setup() {
super.setup();
this.target = this.getNormalDistribution();
this.prior = this.getPrior();
this.likelihood = this.getLikelihood();
}

/**OVERRIDE: DRAW MCMC
* Draw all samples except for the last one
* Draw leap frog steps
* Draw accepted sample from the leap frog steps
*/
@Override
protected void drawMcmc() {
double [] chainArray = this.chain.getChain();
this.drawAllButLastSamples();

//draw all points on the ellipse
Iterator<SimpleMatrix> ellipticalPoints = this.chain.getEllipticalPositionsIterator();
this.stroke(255, 0, 0);
this.fill(255, 0, 0);
float x, y;
while(ellipticalPoints.hasNext()) {
SimpleMatrix point = ellipticalPoints.next();
x = (float) point.get(0);
y = (float) point.get(1);
//draw the first sample
this.ellipse(x, y , CIRCLE_SIZE, CIRCLE_SIZE);
}

//draw the ellipse as a parametric plot, using angle as a parameter
//yellow line for acceptable region of the ellipse
//red line for rejected region of the ellipse, according to slice sampling
double angle = -Math.PI;
double angleDiff = 2*Math.PI/N_PLOT;
float x0 = Float.NaN;
float y0 = Float.NaN;
SimpleMatrix ellipticalPoint;
this.strokeWeight(2);
if (this.chain.getNStep() > 0) {
for (int i=0; i<N_PLOT; i++) {
ellipticalPoint = this.chain.getPointOnEllipse(angle);
x = (float) ellipticalPoint.get(0);
y = (float) ellipticalPoint.get(1);

if (i > 0) {

if (this.chain.isValidPointOnEllipse(ellipticalPoint)) {
this.stroke(255, 255, 0);
} else {
this.stroke(255, 0, 0);
}

this.line(x0, y0, x, y);
}

x0 = x;
y0 = y;
angle += angleDiff;
}
}
this.strokeWeight(1);

//draw the latest sample
this.stroke(0,255,0);
this.fill(0,255,0);
x = (float) chainArray[this.chain.getNStep()*2];
y = (float) chainArray[this.chain.getNStep()*2 + 1];
this.ellipse(x, y , CIRCLE_SIZE, CIRCLE_SIZE);
}

/**OVERRIDE: MOUSE RELEASED
* Instantiate a new chain, if the mouse of not or clicked on a gui
*/
@Override
public void mouseReleased() {
//if the mouse hasn't clicked on a gui or on one
if (!this.isMouseClickOnGui) {
if (!this.isMouseOnGui()) {
//get mouse position
double [] mousePosition = new double [2];
mousePosition[0] = (double) this.mouseX;
mousePosition[1] = (double) this.mouseY;
//instantiate chain
MersenneTwister rng = new MersenneTwister(this.millis());
this.chain = new uk.ac.warwick.sip.mcmc.EllipticalSlice(this.target, this.likelihood,
this.prior, CHAIN_LENGTH, rng);
this.chain.setInitialValue(mousePosition);
super.chain = this.chain; //save copy to the superclass
this.isInit = true;
}
}
super.mouseReleased();
}

/**METHOD: GET LIKELIHOOD
* Return the default likelihood function
*/
private NormalDistribution getLikelihood() {
//instantiate the covariance
SimpleMatrix targetCovariance = new SimpleMatrix(2, 2);
targetCovariance.set(0, 0, 2*NORMAL_TARGET_VARIANCE);
targetCovariance.set(1, 1, 2*NORMAL_TARGET_VARIANCE);
//instantiate the mean
SimpleMatrix mean = new SimpleMatrix(2, 1, true, this.getCentre());
return new NormalDistribution(2, mean, targetCovariance);
}

/**METHOD: GET PRIOR
* Return the default prior distribution
*/
private NormalDistribution getPrior() {
return this.getLikelihood();
}

public static void main(String[] args) {
PApplet.main("uk.ac.warwick.sip.mcmcprocessing.EllipticalSlice");
}
}
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ public class HamiltonianMonteCarlo extends McmcApplet{
@Override
public void setup() {
super.setup();
target = this.getNormalDistribution();
this.target = this.getNormalDistribution();

//instantiate slider bar
this.nLeapFrogSlider = new GSlider(this, 110, 150, 300, 150, 30);
4 changes: 2 additions & 2 deletions src/uk/ac/warwick/sip/mcmcprocessing/McmcApplet.java
Original file line number Diff line number Diff line change
@@ -69,8 +69,8 @@ public void settings() {
*/
@Override
public void setup() {
this.pauseButton = new GButton(this, 10,10,50,50,"pause");
this.stepButton = new GButton(this, 10,80,50,50,"step");
this.pauseButton = new GButton(this, 10, 10, 50,50, "pause");
this.stepButton = new GButton(this, 10, 80, 50, 50, "step");
this.quitButton = new GButton(this, 10, this.height-80, 50, 50, "quit");
}

6 changes: 3 additions & 3 deletions src/uk/ac/warwick/sip/mcmcprocessing/NoUTurnSampler.java
Original file line number Diff line number Diff line change
@@ -80,9 +80,9 @@ protected void drawMcmc() {
x1 = x2;
y1 = y2;
}
//draw in blue the accepted sample
this.stroke(0,0,255);
this.fill(0,0,255);
//draw in green the accepted sample
this.stroke(0,255,0);
this.fill(0,255,0);
x2 = (float) chainArray[this.chain.getNStep()*2];
y2 = (float) chainArray[this.chain.getNStep()*2+1];
this.ellipse(x2, y2 , CIRCLE_SIZE, CIRCLE_SIZE);
Binary file added tex/processing_slice.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 5239397

Please sign in to comment.