From a6af952f606714da219292894241b996a91edb1f Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Sat, 4 Nov 2023 16:18:11 -0700 Subject: [PATCH] syntaxic sugar for RandomField and BaselineIncrement formulation --- ...sianBridgeMarkovRandomFieldLikelihood.java | 22 +-- .../inference/distribution/RandomField.java | 3 +- .../JointBayesianBridgeDistributionModel.java | 6 - ...rginalBayesianBridgeDistributionModel.java | 4 +- .../BaselineIncrementFieldParser.java | 8 +- ...idgeMarkovRandomFieldLikelihoodParser.java | 13 +- ...BayesianBridgeMarkovRandomFieldParser.java | 1 - .../distributions/BaselineIncrementField.java | 128 ++++++++++++++---- .../BayesianBridgeMarkovRandomField.java | 1 + .../GaussianMarkovRandomField.java | 10 +- .../RandomFieldDistribution.java | 2 + 11 files changed, 129 insertions(+), 69 deletions(-) diff --git a/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java b/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java index 65ef2e729b..161c85c1c6 100644 --- a/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java +++ b/src/dr/inference/distribution/BayesianBridgeMarkovRandomFieldLikelihood.java @@ -64,26 +64,10 @@ private double[][] getUnconstrainedValues() { values[0] = new double[1]; values[0][0] = concatenated[0]; values[1] = new double[dim - 1]; - for (int i = 0; i < dim - 1; i++) { - values[1][i] = concatenated[i+1]; - } + System.arraycopy(concatenated, 1, values[1], 0, dim -1); return values; } -// private double[] getFirstElementVariableValues() { -// double[] vals = new double[1]; -// vals[0] = variables.getParameterValue(0); -// return vals; -// } -// -// private double[] getBridgeVariableValues() { -// double[] vals = new double[dim - 1]; -// for (int i = 0; i < dim - 1; i++) { -// vals[i] = variables.getParameterValue(i + 1); -// } -// return vals; -// } - @Override public double getLogLikelihood() { double[][] unconstrained = getUnconstrainedValues(); @@ -103,9 +87,7 @@ public double[] getGradientLogDensity() { grad[0] = ((GradientProvider)firstElementDistribution).getGradientLogDensity(transformedVariables[0])[0]; double[] bridgeGrad = bridge.getGradientLogDensity(transformedVariables[1]); - for (int i = 0; i < dim - 1; i++) { - grad[i + 1] = bridgeGrad[i]; - } + System.arraycopy(bridgeGrad, 0, grad, 1, dim -1); return transform.updateGradientLogDensity(grad, transform.inverse(variables.getParameterValues(), 0, dim), 0, dim); diff --git a/src/dr/inference/distribution/RandomField.java b/src/dr/inference/distribution/RandomField.java index cd72c29699..942075e68b 100644 --- a/src/dr/inference/distribution/RandomField.java +++ b/src/dr/inference/distribution/RandomField.java @@ -143,8 +143,7 @@ public BayesianBridge(String name, Parameter field, RandomFieldDistribution dist @Override public double getCoefficient(int i) { - double[] mean = getDistribution().getMean(); - return (field.getParameterValue(i) - mean[i]) - (field.getParameterValue(i + 1) - mean[i + 1]); + return getDistribution().getIncrement(i, field); } @Override diff --git a/src/dr/inference/distribution/shrinkage/JointBayesianBridgeDistributionModel.java b/src/dr/inference/distribution/shrinkage/JointBayesianBridgeDistributionModel.java index 6ffeca2c28..98ce55b7e1 100644 --- a/src/dr/inference/distribution/shrinkage/JointBayesianBridgeDistributionModel.java +++ b/src/dr/inference/distribution/shrinkage/JointBayesianBridgeDistributionModel.java @@ -2,8 +2,6 @@ import dr.inference.model.Parameter; import dr.inference.model.PriorPreconditioningProvider; -import dr.math.MathUtils; -import dr.math.distributions.GammaDistribution; import dr.math.distributions.NormalDistribution; /** @@ -99,10 +97,6 @@ public double[] nextRandom() { return draws; } -// public void setGlobalScale(double draw) { -// globalScale.setParameterValue(0, draw); -// } - private final Parameter localScale; private final Parameter slabWidth; } \ No newline at end of file diff --git a/src/dr/inference/distribution/shrinkage/MarginalBayesianBridgeDistributionModel.java b/src/dr/inference/distribution/shrinkage/MarginalBayesianBridgeDistributionModel.java index a150287ebf..b06c9cf758 100644 --- a/src/dr/inference/distribution/shrinkage/MarginalBayesianBridgeDistributionModel.java +++ b/src/dr/inference/distribution/shrinkage/MarginalBayesianBridgeDistributionModel.java @@ -44,7 +44,7 @@ public Parameter getSlabWidth() { for (int i = 0; i < dim; ++i) { gradient[i] = MarginalizedAlphaStableDistribution.gradLogPdf(x[i], scale, alpha); } - } else if (alpha == 1.0) { + } else { for (int i = 0; i < dim; ++i) { gradient[i] = LaplaceDistribution.gradLogPdf(x[i], 0, scale); } @@ -63,7 +63,7 @@ public double logPdf(double[] v) { for (double x : v) { sum += MarginalizedAlphaStableDistribution.logPdf(x, scale, alpha); } - } else if (alpha == 1.0) { + } else { for (int i = 0; i < dim; ++i) { sum += LaplaceDistribution.logPdf(v[i], 0, scale); } diff --git a/src/dr/inferencexml/distribution/BaselineIncrementFieldParser.java b/src/dr/inferencexml/distribution/BaselineIncrementFieldParser.java index 122d71b8aa..32f4ec3c71 100644 --- a/src/dr/inferencexml/distribution/BaselineIncrementFieldParser.java +++ b/src/dr/inferencexml/distribution/BaselineIncrementFieldParser.java @@ -28,6 +28,7 @@ import dr.inference.distribution.RandomField; import dr.inference.model.Parameter; import dr.math.distributions.BaselineIncrementField; +import dr.math.distributions.Distribution; import dr.xml.*; import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE; @@ -43,9 +44,10 @@ public class BaselineIncrementFieldParser extends AbstractXMLObjectParser { public Object parseXMLObject(XMLObject xo) throws XMLParseException { - Parameter baseline = (Parameter) xo.getElementFirstChild(BASELINE); - Parameter increments = (Parameter) xo.getElementFirstChild(INCREMENTS); - RandomField.WeightProvider weights = parseWeightProvider(xo, increments.getDimension() + 1); + Distribution baseline = (Distribution) xo.getElementFirstChild(BASELINE); + Distribution increments = (Distribution) xo.getElementFirstChild(INCREMENTS); + + RandomField.WeightProvider weights = parseWeightProvider(xo, 0); String id = xo.hasId() ? xo.getId() : PARSER_NAME; diff --git a/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldLikelihoodParser.java b/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldLikelihoodParser.java index c42cae5820..b8f4e58f00 100644 --- a/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldLikelihoodParser.java +++ b/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldLikelihoodParser.java @@ -29,9 +29,6 @@ import dr.inference.distribution.ParametricDistributionModel; import dr.inference.distribution.shrinkage.*; import dr.inference.model.Parameter; -import dr.inference.model.ParameterParser; -import dr.inferencexml.distribution.shrinkage.BayesianBridgeDistributionModelParser; -import dr.util.FirstOrderFiniteDifferenceTransform; import dr.util.InverseFirstOrderFiniteDifferenceTransform; import dr.util.Transform; import dr.xml.*; @@ -51,8 +48,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Parameter variables = (Parameter) xo.getChild(Parameter.class); -// BayesianBridgeDistributionModelParser bridgeParser = new BayesianBridgeDistributionModelParser(); -// BayesianBridgeDistributionModel bridge = (BayesianBridgeDistributionModel) bridgeParser.parseXMLObject(xo.getChild(BAYESIAN_BRIDGE_DISTRIBUTION).getChild(0)); BayesianBridgeDistributionModel bridge = (BayesianBridgeDistributionModel) xo.getChild(BAYESIAN_BRIDGE_DISTRIBUTION).getChild(0); ParametricDistributionModel firstElementDistribution = (ParametricDistributionModel) xo.getChild(FIRST_ELEMENT_DISTRIBUTION).getChild(0); @@ -60,13 +55,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { double upper = xo.getAttribute("upper", 1.0); double lower = xo.getAttribute("lower", 0.0); - String ttype = (String) xo.getAttribute(INCREMENT_TRANSFORM, "none"); + String transformType = xo.getAttribute(INCREMENT_TRANSFORM, "none"); Transform.UnivariableTransform incrementTransform; - if ( ttype.equalsIgnoreCase("none") ) { + if ( transformType.equalsIgnoreCase("none") ) { incrementTransform = new Transform.NoTransform(); - } else if ( ttype.equalsIgnoreCase("log") ) { + } else if ( transformType.equalsIgnoreCase("log") ) { incrementTransform = new Transform.LogTransform(); - } else if ( ttype.equalsIgnoreCase("logit") ) { + } else if ( transformType.equalsIgnoreCase("logit") ) { incrementTransform = new Transform.ScaledLogitTransform(lower, upper); } else { throw new RuntimeException("Invalid option for "+ INCREMENT_TRANSFORM); diff --git a/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldParser.java b/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldParser.java index fb55a6262c..466a17d55a 100644 --- a/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldParser.java +++ b/src/dr/inferencexml/distribution/BayesianBridgeMarkovRandomFieldParser.java @@ -30,7 +30,6 @@ import dr.inference.distribution.shrinkage.JointBayesianBridgeDistributionModel; import dr.inference.model.Parameter; import dr.math.distributions.BayesianBridgeMarkovRandomField; -import dr.math.distributions.GaussianMarkovRandomField; import dr.xml.*; import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE; diff --git a/src/dr/math/distributions/BaselineIncrementField.java b/src/dr/math/distributions/BaselineIncrementField.java index 201e75cd4d..cffbdda7dc 100644 --- a/src/dr/math/distributions/BaselineIncrementField.java +++ b/src/dr/math/distributions/BaselineIncrementField.java @@ -26,6 +26,7 @@ package dr.math.distributions; import dr.inference.distribution.RandomField; +import dr.inference.distribution.shrinkage.BayesianBridgeStatisticsProvider; import dr.inference.model.GradientProvider; import dr.inference.model.Model; import dr.inference.model.Parameter; @@ -36,29 +37,45 @@ * @author Yucai Shao * @author Andy Magee */ -public class BaselineIncrementField extends RandomFieldDistribution { +public class BaselineIncrementField extends RandomFieldDistribution + implements BayesianBridgeStatisticsProvider { public static final String TYPE = "BaselineIncrementField"; - private final Parameter baseline; - private final Parameter increments; - private final RandomField.WeightProvider weights; + private final Distribution baseline; + private final Distribution increments; + + private final GradientProvider baselineGradient; + private final GradientProvider incrementGradient; + + private final BayesianBridgeStatisticsProvider bayesianBridge; public BaselineIncrementField(String name, - Parameter baseline, - Parameter increments, + Distribution baseline, + Distribution increments, RandomField.WeightProvider weights) { super(name); this.baseline = baseline; this.increments = increments; - this.weights = weights; - addVariable(baseline); - addVariable(increments); + if (baseline instanceof Model) { + addModel((Model) baseline); + } + if (increments instanceof Model) { + addModel((Model) increments); + } + + baselineGradient = (baseline instanceof GradientProvider) ? + (GradientProvider) baseline : null; + + incrementGradient = (increments instanceof GradientProvider) ? + (GradientProvider) increments : null; + + bayesianBridge = (increments instanceof BayesianBridgeStatisticsProvider) ? + (BayesianBridgeStatisticsProvider) baseline : null; if (weights != null) { - addModel(weights); throw new IllegalArgumentException("Unsure how weights influence this field"); } } @@ -75,37 +92,91 @@ public double[] nextRandom() { @Override protected void handleModelChangedEvent(Model model, Object object, int index) { - +// if (model == baseline || model == increments) { + // Do nothing + // TODO do we need a fireModelChangedEvent()? +// } } @Override - protected void storeState() { + protected void storeState() { } - } + @Override + protected void restoreState() { } @Override - protected void restoreState() { + protected void acceptState() { } + @Override + protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + throw new IllegalArgumentException("Unknown variable"); } @Override - protected void acceptState() { + public double getCoefficient(int i) { + throw new RuntimeException("Should not be called"); + } + @Override + public Parameter getGlobalScale() { + if (bayesianBridge != null) { + return bayesianBridge.getGlobalScale(); + } else { + throw new IllegalArgumentException("Not a Bayesian bridge"); + } } @Override - protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + public Parameter getLocalScale() { + if (bayesianBridge != null) { + return bayesianBridge.getLocalScale(); + } else { + throw new IllegalArgumentException("Not a Bayesian bridge"); + } + } + + @Override + public Parameter getExponent() { + if (bayesianBridge != null) { + return bayesianBridge.getExponent(); + } else { + throw new IllegalArgumentException("Not a Bayesian bridge"); + } + } + @Override + public Parameter getSlabWidth() { + if (bayesianBridge != null) { + return bayesianBridge.getSlabWidth(); + } else { + throw new IllegalArgumentException("Not a Bayesian bridge"); + } } @Override public int getDimension() { - throw new RuntimeException("Not yet implemented"); + if (bayesianBridge != null) { + return bayesianBridge.getDimension(); + } else { + throw new IllegalArgumentException("Not a Bayesian bridge"); + } } @Override public double[] getGradientLogDensity(Object x) { - throw new RuntimeException("Not yet implemented"); + double[] field = (double[]) x; + + double[] sub = new double[field.length - 1]; + System.arraycopy(field, 1, sub, 0, field.length - 1); + + double[] baselineGrad = baselineGradient.getGradientLogDensity(new double[]{field[0]}); + double[] incrementGrad = incrementGradient.getGradientLogDensity(sub); + + double[] gradient = new double[field.length]; + gradient[0] = baselineGrad[0]; + System.arraycopy(incrementGrad, 0, gradient, 1, field.length - 1); + + return gradient; } @Override @@ -120,18 +191,20 @@ public double[][] getHessianLogDensity(Object x) { @Override public double logPdf(double[] x) { - throw new RuntimeException("Not yet implemented"); + + double logPdf = baseline.logPdf(x[0]); + for (int i = 1; i < x.length; ++i) { + logPdf += increments.logPdf(x[i]); + } + + return logPdf; } @Override - public double[][] getScaleMatrix() { - throw new RuntimeException("Not yet implemented"); - } + public double[][] getScaleMatrix() { throw new RuntimeException("Not yet implemented");} @Override - public double[] getMean() { - throw new RuntimeException("Not yet implemented"); - } + public double[] getMean() { throw new RuntimeException("Not yet implemented"); } @Override public String getType() { return TYPE; } @@ -140,4 +213,9 @@ public double[] getMean() { public GradientProvider getGradientWrt(Parameter parameter) { throw new RuntimeException("Not yet implemented"); } + + @Override + public double getIncrement(int i, Parameter field) { + return field.getParameterValue(i + 1); + } } diff --git a/src/dr/math/distributions/BayesianBridgeMarkovRandomField.java b/src/dr/math/distributions/BayesianBridgeMarkovRandomField.java index ada7e15e32..10067af016 100644 --- a/src/dr/math/distributions/BayesianBridgeMarkovRandomField.java +++ b/src/dr/math/distributions/BayesianBridgeMarkovRandomField.java @@ -84,6 +84,7 @@ protected SymmetricTriDiagonalMatrix getQ() { protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == bayesianBridge) { qKnown = false; + // TODO do we need a fireModelChangedEvent()? } else { throw new IllegalArgumentException("Unknown model"); } diff --git a/src/dr/math/distributions/GaussianMarkovRandomField.java b/src/dr/math/distributions/GaussianMarkovRandomField.java index b049af5e55..a7c8646ca0 100644 --- a/src/dr/math/distributions/GaussianMarkovRandomField.java +++ b/src/dr/math/distributions/GaussianMarkovRandomField.java @@ -156,6 +156,14 @@ private double[][] getPrecision() { return precision; } + @Override + public double getIncrement(int i, Parameter field) { + + double[] mean = getMean(); + return (field.getParameterValue(i) - mean[i]) - (field.getParameterValue(i + 1) - mean[i + 1]); + } + + @Override public GradientProvider getGradientWrt(Parameter parameter) { if (parameter == precisionParameter) { @@ -251,7 +259,7 @@ public double[][] getScaleMatrix() { @Override public Variable getLocationVariable() { - return null; + return meanParameter; } @Override diff --git a/src/dr/math/distributions/RandomFieldDistribution.java b/src/dr/math/distributions/RandomFieldDistribution.java index 3bf018678d..e559ca6e65 100644 --- a/src/dr/math/distributions/RandomFieldDistribution.java +++ b/src/dr/math/distributions/RandomFieldDistribution.java @@ -41,4 +41,6 @@ public RandomFieldDistribution(String name) { } public abstract GradientProvider getGradientWrt(Parameter parameter); + + public abstract double getIncrement(int i, Parameter field); }