-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
HMC accepted step is green rather than blue
1 parent
d4d7ade
commit 5239397
Showing
10 changed files
with
362 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Copyright 2018-2021 Sherman Lo | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package uk.ac.warwick.sip.mcmc; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Iterator; | ||
|
||
import org.apache.commons.math3.random.MersenneTwister; | ||
import org.ejml.dense.row.CommonOps_DDRM; | ||
import org.ejml.simple.SimpleMatrix; | ||
|
||
/**CLASS: ELLIPTICAL SLICE SAMPLER | ||
* Sampler which uses elliptical slice sampling | ||
* Reference: Elliptical slice sampling (Murray, I., Adams, R. P., and MacKay, D. J. (2010). | ||
* Elliptical slice sampling. In Proceedings of the 13th International Conference on Artificial | ||
* Intelligence and Statistics.) | ||
* At a step, a sample from the prior is taken. A weighted average between the current step and that | ||
* sample is used as the next step in the MCMC according to the slice sampling scheme. | ||
*/ | ||
public class EllipticalSlice extends Mcmc{ | ||
|
||
//the likelihood of the model, elliptical slice sampling works on the likelihood rather than the | ||
//posterior | ||
protected TargetDistribution likelihood; | ||
//the prior distribution | ||
protected NormalDistribution prior; | ||
//position vector of the start of the ellipse (the current step) | ||
protected SimpleMatrix ellipse0; | ||
//position vector of the end of the ellipse (the sample from the prior) | ||
protected SimpleMatrix proposal; | ||
//after a step, an array of points looked at on the ellipse | ||
protected ArrayList<SimpleMatrix> ellipticalPositions; | ||
//slice variable, log uniform number plus log likelihood | ||
protected double sliceVariable; | ||
|
||
/**CONSTRUCTOR | ||
* Elliptical slice sampling algorithm. Samples from the prior (called the proposal in this code) | ||
* and then uses slice sampling to find a weighted combination of the proposed and the current | ||
* value, which turns out to make an ellipse. | ||
* @param target Object which has a method to call the pdf | ||
* @param likelihood the likelihood of the model | ||
* @param prior NormalDistribution representing the prior | ||
* @param chainLength Length of the chain to be obtained | ||
* @param rng Random number generator all the random numbers | ||
*/ | ||
public EllipticalSlice(TargetDistribution target, TargetDistribution likelihood, | ||
NormalDistribution prior, int chainLength, MersenneTwister rng) { | ||
super(target, chainLength, rng); | ||
this.likelihood = likelihood; | ||
this.prior = prior; | ||
this.ellipticalPositions = new ArrayList<SimpleMatrix>(); | ||
} | ||
|
||
/**CONSTRUCTOR | ||
* Constructor for extending the length of the chain and resume running it | ||
* Does a shallow copy of the provided chain and extending the member variable chainArray | ||
* @param chain Chain to be extended | ||
* @param nMoreSteps Number of steps to be extended | ||
*/ | ||
public EllipticalSlice(EllipticalSlice chain, int nMoreSteps) { | ||
//shallow copies of member variables | ||
super(chain, nMoreSteps); | ||
this.likelihood = chain.likelihood; | ||
this.prior = chain.prior; | ||
this.proposal = chain.proposal; | ||
this.ellipticalPositions = chain.ellipticalPositions; | ||
this.sliceVariable = chain.sliceVariable; | ||
} | ||
|
||
@Override | ||
public void step(SimpleMatrix position) { | ||
//elliptical slice sampling is a weighted sum of position and a sample from the prior | ||
this.ellipse0 = new SimpleMatrix(position); //start of ellipse | ||
this.proposal = this.prior.sample(this.rng); //end of ellipse | ||
//array to store all points looked at on the ellipse | ||
this.ellipticalPositions = new ArrayList<SimpleMatrix>(); | ||
|
||
double angle = this.sampleAngle(0, 2*Math.PI); | ||
double angleMin = angle - 2*Math.PI; | ||
double angleMax = angle; | ||
|
||
//sliceVariable is compared with the likelihood at points on the ellipse | ||
this.sliceVariable = Math.log(this.rng.nextDouble()) - this.likelihood.getPotential(position); | ||
boolean gotSample = false; | ||
|
||
//look through the ellipse until got valid sample | ||
while (!gotSample) { | ||
position.set(this.getPointOnEllipse(angle)); | ||
if (this.isValidPointOnEllipse(position)) { | ||
gotSample = true; | ||
} else { | ||
//else this is not a valid point, look elsewhere on the ellipse | ||
this.ellipticalPositions.add(new SimpleMatrix(position)); | ||
if (angle < 0) { | ||
angleMin = angle; | ||
} else { | ||
angleMax = angle; | ||
} | ||
angle = this.sampleAngle(angleMin, angleMax); | ||
} | ||
} | ||
this.updateStatistics(position); | ||
} | ||
|
||
/**METHOD: IS VALID POINT ON ELLIPSE | ||
* Check if this position vector would be accepted under the slice sampling scheme. This would | ||
* require the member variable sliceVariable initalised with the correct value. | ||
* @param position position vector (on the ellipse) | ||
* @return boolean if the position vector is valid under slice sampling | ||
*/ | ||
public boolean isValidPointOnEllipse(SimpleMatrix position) { | ||
return -this.likelihood.getPotential(position) > this.sliceVariable; | ||
} | ||
|
||
/**METHOD: GET POINT ON ELLIPSE | ||
* Return a weighted sum of initial and proposal, the weight using cosine and sine of angle | ||
* @param initial vector of current position of chain | ||
* @param proposal vector of sample from prior | ||
* @param angle of the ellipse | ||
* @return vector, weighted sum of initial and proposal | ||
*/ | ||
public SimpleMatrix getPointOnEllipse(double angle) { | ||
SimpleMatrix initialWeighted = new SimpleMatrix(this.ellipse0); | ||
SimpleMatrix proposalWeighted = new SimpleMatrix(this.proposal); | ||
|
||
//centre the position vectors at the prior mean | ||
CommonOps_DDRM.subtractEquals(initialWeighted.getDDRM(), this.prior.mean.getDDRM()); | ||
CommonOps_DDRM.subtractEquals(proposalWeighted.getDDRM(), this.prior.mean.getDDRM()); | ||
|
||
CommonOps_DDRM.scale(Math.cos(angle), initialWeighted.getDDRM()); | ||
CommonOps_DDRM.scale(Math.sin(angle), proposalWeighted.getDDRM()); | ||
|
||
CommonOps_DDRM.addEquals(proposalWeighted.getDDRM(), initialWeighted.getDDRM()); | ||
CommonOps_DDRM.addEquals(proposalWeighted.getDDRM(), this.prior.mean.getDDRM()); | ||
|
||
return proposalWeighted; | ||
} | ||
|
||
/**METHOD: GET ELLIPTICAL POSITIONS ITERATOR | ||
* Return an iterator which iterates through all the points looked at on the ellipse | ||
*/ | ||
public Iterator<SimpleMatrix> getEllipticalPositionsIterator(){ | ||
return this.ellipticalPositions.iterator(); | ||
} | ||
|
||
/**METHOD: SAMPLE ANGLE` | ||
* Sample from the uniform distribution | ||
* @param min minimun value | ||
* @param max maximum value | ||
*/ | ||
protected double sampleAngle(double min, double max) { | ||
double angle = this.rng.nextDouble(); | ||
angle *= max - min; | ||
angle += min; | ||
return angle; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 151 additions & 0 deletions
151
src/uk/ac/warwick/sip/mcmcprocessing/EllipticalSlice.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
package uk.ac.warwick.sip.mcmcprocessing; | ||
|
||
import java.util.Iterator; | ||
|
||
import org.apache.commons.math3.random.MersenneTwister; | ||
import org.ejml.simple.SimpleMatrix; | ||
|
||
import processing.core.PApplet; | ||
import uk.ac.warwick.sip.mcmc.NormalDistribution; | ||
import uk.ac.warwick.sip.mcmc.TargetDistribution; | ||
|
||
/**CLASS: ELLIPTICAL SLICE SAMPLING | ||
* Simulation for elliptical slice sampling, click to add a chain | ||
* Draws ellipse search space, draw red dots for rejected samples | ||
*/ | ||
public class EllipticalSlice extends McmcApplet { | ||
|
||
static public final int N_PLOT = 1000; | ||
|
||
//chain to simulate (hides the super class version) | ||
uk.ac.warwick.sip.mcmc.EllipticalSlice chain; | ||
//target distribution | ||
protected TargetDistribution target; | ||
//Normal prior distribution, | ||
protected NormalDistribution prior; | ||
//likelihood function | ||
protected TargetDistribution likelihood; | ||
|
||
/**OVERRIDE: SETUP | ||
* Set target distribution, prior and likelihood | ||
*/ | ||
@Override | ||
public void setup() { | ||
super.setup(); | ||
this.target = this.getNormalDistribution(); | ||
this.prior = this.getPrior(); | ||
this.likelihood = this.getLikelihood(); | ||
} | ||
|
||
/**OVERRIDE: DRAW MCMC | ||
* Draw all samples except for the last one | ||
* Draw leap frog steps | ||
* Draw accepted sample from the leap frog steps | ||
*/ | ||
@Override | ||
protected void drawMcmc() { | ||
double [] chainArray = this.chain.getChain(); | ||
this.drawAllButLastSamples(); | ||
|
||
//draw all points on the ellipse | ||
Iterator<SimpleMatrix> ellipticalPoints = this.chain.getEllipticalPositionsIterator(); | ||
this.stroke(255, 0, 0); | ||
this.fill(255, 0, 0); | ||
float x, y; | ||
while(ellipticalPoints.hasNext()) { | ||
SimpleMatrix point = ellipticalPoints.next(); | ||
x = (float) point.get(0); | ||
y = (float) point.get(1); | ||
//draw the first sample | ||
this.ellipse(x, y , CIRCLE_SIZE, CIRCLE_SIZE); | ||
} | ||
|
||
//draw the ellipse as a parametric plot, using angle as a parameter | ||
//yellow line for acceptable region of the ellipse | ||
//red line for rejected region of the ellipse, according to slice sampling | ||
double angle = -Math.PI; | ||
double angleDiff = 2*Math.PI/N_PLOT; | ||
float x0 = Float.NaN; | ||
float y0 = Float.NaN; | ||
SimpleMatrix ellipticalPoint; | ||
this.strokeWeight(2); | ||
if (this.chain.getNStep() > 0) { | ||
for (int i=0; i<N_PLOT; i++) { | ||
ellipticalPoint = this.chain.getPointOnEllipse(angle); | ||
x = (float) ellipticalPoint.get(0); | ||
y = (float) ellipticalPoint.get(1); | ||
|
||
if (i > 0) { | ||
|
||
if (this.chain.isValidPointOnEllipse(ellipticalPoint)) { | ||
this.stroke(255, 255, 0); | ||
} else { | ||
this.stroke(255, 0, 0); | ||
} | ||
|
||
this.line(x0, y0, x, y); | ||
} | ||
|
||
x0 = x; | ||
y0 = y; | ||
angle += angleDiff; | ||
} | ||
} | ||
this.strokeWeight(1); | ||
|
||
//draw the latest sample | ||
this.stroke(0,255,0); | ||
this.fill(0,255,0); | ||
x = (float) chainArray[this.chain.getNStep()*2]; | ||
y = (float) chainArray[this.chain.getNStep()*2 + 1]; | ||
this.ellipse(x, y , CIRCLE_SIZE, CIRCLE_SIZE); | ||
} | ||
|
||
/**OVERRIDE: MOUSE RELEASED | ||
* Instantiate a new chain, if the mouse of not or clicked on a gui | ||
*/ | ||
@Override | ||
public void mouseReleased() { | ||
//if the mouse hasn't clicked on a gui or on one | ||
if (!this.isMouseClickOnGui) { | ||
if (!this.isMouseOnGui()) { | ||
//get mouse position | ||
double [] mousePosition = new double [2]; | ||
mousePosition[0] = (double) this.mouseX; | ||
mousePosition[1] = (double) this.mouseY; | ||
//instantiate chain | ||
MersenneTwister rng = new MersenneTwister(this.millis()); | ||
this.chain = new uk.ac.warwick.sip.mcmc.EllipticalSlice(this.target, this.likelihood, | ||
this.prior, CHAIN_LENGTH, rng); | ||
this.chain.setInitialValue(mousePosition); | ||
super.chain = this.chain; //save copy to the superclass | ||
this.isInit = true; | ||
} | ||
} | ||
super.mouseReleased(); | ||
} | ||
|
||
/**METHOD: GET LIKELIHOOD | ||
* Return the default likelihood function | ||
*/ | ||
private NormalDistribution getLikelihood() { | ||
//instantiate the covariance | ||
SimpleMatrix targetCovariance = new SimpleMatrix(2, 2); | ||
targetCovariance.set(0, 0, 2*NORMAL_TARGET_VARIANCE); | ||
targetCovariance.set(1, 1, 2*NORMAL_TARGET_VARIANCE); | ||
//instantiate the mean | ||
SimpleMatrix mean = new SimpleMatrix(2, 1, true, this.getCentre()); | ||
return new NormalDistribution(2, mean, targetCovariance); | ||
} | ||
|
||
/**METHOD: GET PRIOR | ||
* Return the default prior distribution | ||
*/ | ||
private NormalDistribution getPrior() { | ||
return this.getLikelihood(); | ||
} | ||
|
||
public static void main(String[] args) { | ||
PApplet.main("uk.ac.warwick.sip.mcmcprocessing.EllipticalSlice"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.