Skip to content

Commit

Permalink
syntactic sugar for RandomField and BaselineIncrement formulation
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 4, 2023
1 parent 54fe8b5 commit a6af952
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,10 @@ private double[][] getUnconstrainedValues() {
values[0] = new double[1];
values[0][0] = concatenated[0];
values[1] = new double[dim - 1];
for (int i = 0; i < dim - 1; i++) {
values[1][i] = concatenated[i+1];
}
System.arraycopy(concatenated, 1, values[1], 0, dim -1);
return values;
}

// private double[] getFirstElementVariableValues() {
// double[] vals = new double[1];
// vals[0] = variables.getParameterValue(0);
// return vals;
// }
//
// private double[] getBridgeVariableValues() {
// double[] vals = new double[dim - 1];
// for (int i = 0; i < dim - 1; i++) {
// vals[i] = variables.getParameterValue(i + 1);
// }
// return vals;
// }

@Override
public double getLogLikelihood() {
double[][] unconstrained = getUnconstrainedValues();
Expand All @@ -103,9 +87,7 @@ public double[] getGradientLogDensity() {
grad[0] = ((GradientProvider)firstElementDistribution).getGradientLogDensity(transformedVariables[0])[0];

double[] bridgeGrad = bridge.getGradientLogDensity(transformedVariables[1]);
for (int i = 0; i < dim - 1; i++) {
grad[i + 1] = bridgeGrad[i];
}
System.arraycopy(bridgeGrad, 0, grad, 1, dim -1);

return transform.updateGradientLogDensity(grad, transform.inverse(variables.getParameterValues(), 0, dim), 0, dim);

Expand Down
3 changes: 1 addition & 2 deletions src/dr/inference/distribution/RandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ public BayesianBridge(String name, Parameter field, RandomFieldDistribution dist

@Override
public double getCoefficient(int i) {
double[] mean = getDistribution().getMean();
return (field.getParameterValue(i) - mean[i]) - (field.getParameterValue(i + 1) - mean[i + 1]);
return getDistribution().getIncrement(i, field);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import dr.inference.model.Parameter;
import dr.inference.model.PriorPreconditioningProvider;
import dr.math.MathUtils;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.NormalDistribution;

/**
Expand Down Expand Up @@ -99,10 +97,6 @@ public double[] nextRandom() {
return draws;
}

// public void setGlobalScale(double draw) {
// globalScale.setParameterValue(0, draw);
// }

private final Parameter localScale;
private final Parameter slabWidth;
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public Parameter getSlabWidth() {
for (int i = 0; i < dim; ++i) {
gradient[i] = MarginalizedAlphaStableDistribution.gradLogPdf(x[i], scale, alpha);
}
} else if (alpha == 1.0) {
} else {
for (int i = 0; i < dim; ++i) {
gradient[i] = LaplaceDistribution.gradLogPdf(x[i], 0, scale);
}
Expand All @@ -63,7 +63,7 @@ public double logPdf(double[] v) {
for (double x : v) {
sum += MarginalizedAlphaStableDistribution.logPdf(x, scale, alpha);
}
} else if (alpha == 1.0) {
} else {
for (int i = 0; i < dim; ++i) {
sum += LaplaceDistribution.logPdf(v[i], 0, scale);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import dr.inference.distribution.RandomField;
import dr.inference.model.Parameter;
import dr.math.distributions.BaselineIncrementField;
import dr.math.distributions.Distribution;
import dr.xml.*;

import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE;
Expand All @@ -43,9 +44,10 @@ public class BaselineIncrementFieldParser extends AbstractXMLObjectParser {

public Object parseXMLObject(XMLObject xo) throws XMLParseException {

Parameter baseline = (Parameter) xo.getElementFirstChild(BASELINE);
Parameter increments = (Parameter) xo.getElementFirstChild(INCREMENTS);
RandomField.WeightProvider weights = parseWeightProvider(xo, increments.getDimension() + 1);
Distribution baseline = (Distribution) xo.getElementFirstChild(BASELINE);
Distribution increments = (Distribution) xo.getElementFirstChild(INCREMENTS);

RandomField.WeightProvider weights = parseWeightProvider(xo, 0);

String id = xo.hasId() ? xo.getId() : PARSER_NAME;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.distribution.shrinkage.*;
import dr.inference.model.Parameter;
import dr.inference.model.ParameterParser;
import dr.inferencexml.distribution.shrinkage.BayesianBridgeDistributionModelParser;
import dr.util.FirstOrderFiniteDifferenceTransform;
import dr.util.InverseFirstOrderFiniteDifferenceTransform;
import dr.util.Transform;
import dr.xml.*;
Expand All @@ -51,22 +48,20 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

Parameter variables = (Parameter) xo.getChild(Parameter.class);

// BayesianBridgeDistributionModelParser bridgeParser = new BayesianBridgeDistributionModelParser();
// BayesianBridgeDistributionModel bridge = (BayesianBridgeDistributionModel) bridgeParser.parseXMLObject(xo.getChild(BAYESIAN_BRIDGE_DISTRIBUTION).getChild(0));
BayesianBridgeDistributionModel bridge = (BayesianBridgeDistributionModel) xo.getChild(BAYESIAN_BRIDGE_DISTRIBUTION).getChild(0);

ParametricDistributionModel firstElementDistribution = (ParametricDistributionModel) xo.getChild(FIRST_ELEMENT_DISTRIBUTION).getChild(0);

double upper = xo.getAttribute("upper", 1.0);
double lower = xo.getAttribute("lower", 0.0);

String ttype = (String) xo.getAttribute(INCREMENT_TRANSFORM, "none");
String transformType = xo.getAttribute(INCREMENT_TRANSFORM, "none");
Transform.UnivariableTransform incrementTransform;
if ( ttype.equalsIgnoreCase("none") ) {
if ( transformType.equalsIgnoreCase("none") ) {
incrementTransform = new Transform.NoTransform();
} else if ( ttype.equalsIgnoreCase("log") ) {
} else if ( transformType.equalsIgnoreCase("log") ) {
incrementTransform = new Transform.LogTransform();
} else if ( ttype.equalsIgnoreCase("logit") ) {
} else if ( transformType.equalsIgnoreCase("logit") ) {
incrementTransform = new Transform.ScaledLogitTransform(lower, upper);
} else {
throw new RuntimeException("Invalid option for "+ INCREMENT_TRANSFORM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import dr.inference.distribution.shrinkage.JointBayesianBridgeDistributionModel;
import dr.inference.model.Parameter;
import dr.math.distributions.BayesianBridgeMarkovRandomField;
import dr.math.distributions.GaussianMarkovRandomField;
import dr.xml.*;

import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE;
Expand Down
128 changes: 103 additions & 25 deletions src/dr/math/distributions/BaselineIncrementField.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
package dr.math.distributions;

import dr.inference.distribution.RandomField;
import dr.inference.distribution.shrinkage.BayesianBridgeStatisticsProvider;
import dr.inference.model.GradientProvider;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
Expand All @@ -36,29 +37,45 @@
* @author Yucai Shao
* @author Andy Magee
*/
public class BaselineIncrementField extends RandomFieldDistribution {
public class BaselineIncrementField extends RandomFieldDistribution
implements BayesianBridgeStatisticsProvider {

public static final String TYPE = "BaselineIncrementField";

private final Parameter baseline;
private final Parameter increments;
private final RandomField.WeightProvider weights;
private final Distribution baseline;
private final Distribution increments;

private final GradientProvider baselineGradient;
private final GradientProvider incrementGradient;

private final BayesianBridgeStatisticsProvider bayesianBridge;

/**
 * Constructs a random field whose first element is governed by {@code baseline}
 * and whose remaining elements are governed (independently) by {@code increments}.
 *
 * @param name       model name forwarded to the superclass
 * @param baseline   distribution of the first element of the field
 * @param increments distribution of each subsequent element of the field
 * @param weights    unsupported; must be {@code null} (throws otherwise)
 */
public BaselineIncrementField(String name,
                              Distribution baseline,
                              Distribution increments,
                              RandomField.WeightProvider weights) {
    super(name);

    this.baseline = baseline;
    this.increments = increments;

    // Register sub-distributions as sub-models so their changes propagate.
    if (baseline instanceof Model) {
        addModel((Model) baseline);
    }
    if (increments instanceof Model) {
        addModel((Model) increments);
    }

    baselineGradient = (baseline instanceof GradientProvider) ?
            (GradientProvider) baseline : null;

    incrementGradient = (increments instanceof GradientProvider) ?
            (GradientProvider) increments : null;

    // BUG FIX: the original tested 'increments instanceof BayesianBridgeStatisticsProvider'
    // but then cast 'baseline', which throws ClassCastException (or stores the wrong
    // provider) whenever the increment distribution is a Bayesian bridge and the
    // baseline is not. Cast the same object that was tested.
    bayesianBridge = (increments instanceof BayesianBridgeStatisticsProvider) ?
            (BayesianBridgeStatisticsProvider) increments : null;

    if (weights != null) {
        addModel(weights);
        throw new IllegalArgumentException("Unsure how weights influence this field");
    }
}
Expand All @@ -75,37 +92,91 @@ public double[] nextRandom() {

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {

// if (model == baseline || model == increments) {
// Do nothing
// TODO do we need a fireModelChangedEvent()?
// }
}

@Override
protected void storeState() {
protected void storeState() { }

}
@Override
protected void restoreState() { }

@Override
protected void restoreState() {
protected void acceptState() { }

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
// This class appears to register its sub-distributions as sub-models rather than
// as Variables, so any variable-change callback indicates a wiring error.
// NOTE(review): confirm no addVariable() call remains anywhere in this class.
throw new IllegalArgumentException("Unknown variable");
}

@Override
protected void acceptState() {
public double getCoefficient(int i) {
throw new RuntimeException("Should not be called");
}

@Override
public Parameter getGlobalScale() {
    // Delegate to the underlying Bayesian-bridge increment distribution, if any.
    if (bayesianBridge == null) {
        throw new IllegalArgumentException("Not a Bayesian bridge");
    }
    return bayesianBridge.getGlobalScale();
}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
public Parameter getLocalScale() {
if (bayesianBridge != null) {
return bayesianBridge.getLocalScale();
} else {
throw new IllegalArgumentException("Not a Bayesian bridge");
}
}

@Override
public Parameter getExponent() {
    // Delegate to the underlying Bayesian-bridge increment distribution, if any.
    if (bayesianBridge == null) {
        throw new IllegalArgumentException("Not a Bayesian bridge");
    }
    return bayesianBridge.getExponent();
}

@Override
public Parameter getSlabWidth() {
    // Delegate to the underlying Bayesian-bridge increment distribution, if any.
    if (bayesianBridge == null) {
        throw new IllegalArgumentException("Not a Bayesian bridge");
    }
    return bayesianBridge.getSlabWidth();
}

@Override
public int getDimension() {
    // FIX: removed a leftover 'throw new RuntimeException("Not yet implemented");'
    // ahead of this logic; in Java the statements after it are unreachable, which
    // is a compile-time error (JLS 14.22).
    // NOTE(review): this returns the increment distribution's dimension; confirm
    // whether the full field dimension (increments + 1 baseline element) is intended.
    if (bayesianBridge == null) {
        throw new IllegalArgumentException("Not a Bayesian bridge");
    }
    return bayesianBridge.getDimension();
}

/**
 * Gradient of log p(field) = log p_baseline(field[0]) + sum_i log p_increments(field[i]),
 * returned in field coordinates.
 *
 * @param x a {@code double[]} holding the field values; element 0 is the baseline
 * @throws RuntimeException if either distribution does not provide gradients
 */
// FIX: removed a leftover 'throw new RuntimeException("Not yet implemented");' that
// made the body unreachable (a compile-time error in Java).
@Override
public double[] getGradientLogDensity(Object x) {
    double[] field = (double[]) x;

    // Fail fast with a clear message instead of a NullPointerException when a
    // sub-distribution does not implement GradientProvider.
    if (baselineGradient == null || incrementGradient == null) {
        throw new RuntimeException("Baseline and increment distributions must provide gradients");
    }

    // Increments occupy field[1..end]; the baseline is field[0].
    double[] sub = new double[field.length - 1];
    System.arraycopy(field, 1, sub, 0, field.length - 1);

    double[] baselineGrad = baselineGradient.getGradientLogDensity(new double[]{field[0]});
    double[] incrementGrad = incrementGradient.getGradientLogDensity(sub);

    // Stitch the two gradients back together in field coordinates.
    double[] gradient = new double[field.length];
    gradient[0] = baselineGrad[0];
    System.arraycopy(incrementGrad, 0, gradient, 1, field.length - 1);

    return gradient;
}

@Override
Expand All @@ -120,18 +191,20 @@ public double[][] getHessianLogDensity(Object x) {

/**
 * Log-density of the field: the baseline density of x[0] plus the (independent)
 * increment density of each remaining element.
 *
 * @param x field values; must have length >= 1 (x[0] is the baseline element)
 */
// FIX: removed a leftover 'throw new RuntimeException("Not yet implemented");' that
// made the body unreachable (a compile-time error in Java).
@Override
public double logPdf(double[] x) {
    double logPdf = baseline.logPdf(x[0]);
    for (int i = 1; i < x.length; ++i) {
        logPdf += increments.logPdf(x[i]);
    }
    return logPdf;
}

@Override
public double[][] getScaleMatrix() {
throw new RuntimeException("Not yet implemented");
}
public double[][] getScaleMatrix() { throw new RuntimeException("Not yet implemented");}

@Override
public double[] getMean() {
throw new RuntimeException("Not yet implemented");
}
public double[] getMean() { throw new RuntimeException("Not yet implemented"); }

@Override
public String getType() { return TYPE; }
Expand All @@ -140,4 +213,9 @@ public double[] getMean() {
public GradientProvider getGradientWrt(Parameter parameter) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double getIncrement(int i, Parameter field) {
// Increment i is field element i+1 itself (element 0 is the baseline): this model
// parameterizes increments directly rather than as differences of adjacent field
// values (contrast GaussianMarkovRandomField.getIncrement).
return field.getParameterValue(i + 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ protected SymmetricTriDiagonalMatrix getQ() {
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model == bayesianBridge) {
qKnown = false;
// TODO do we need a fireModelChangedEvent()?
} else {
throw new IllegalArgumentException("Unknown model");
}
Expand Down
10 changes: 9 additions & 1 deletion src/dr/math/distributions/GaussianMarkovRandomField.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ private double[][] getPrecision() {
return precision;
}

@Override
public double getIncrement(int i, Parameter field) {
    // Increment i is the difference between consecutive mean-centered field values.
    final double[] mu = getMean();
    final double centeredHere = field.getParameterValue(i) - mu[i];
    final double centeredNext = field.getParameterValue(i + 1) - mu[i + 1];
    return centeredHere - centeredNext;
}

@Override
public GradientProvider getGradientWrt(Parameter parameter) {

if (parameter == precisionParameter) {
Expand Down Expand Up @@ -251,7 +259,7 @@ public double[][] getScaleMatrix() {

@Override
public Variable<Double> getLocationVariable() {
return null;
return meanParameter;
}

@Override
Expand Down
2 changes: 2 additions & 0 deletions src/dr/math/distributions/RandomFieldDistribution.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,6 @@ public RandomFieldDistribution(String name) {
}

/** Returns a provider for the gradient of this density with respect to {@code parameter}. */
public abstract GradientProvider getGradientWrt(Parameter parameter);

/**
 * Returns the i-th increment of {@code field} under this distribution's own
 * parameterization (e.g. the difference of adjacent mean-centered field values
 * for a Gaussian Markov random field).
 */
public abstract double getIncrement(int i, Parameter field);
}

0 comments on commit a6af952

Please sign in to comment.