From 800293ac4214ae9699f1647d2bed13a10ff04f61 Mon Sep 17 00:00:00 2001 From: xji3 Date: Thu, 21 Nov 2024 15:43:34 -0600 Subject: [PATCH] calculate Hessian wrt transformed space --- .../ApproximateTreeDataLikelihood.java | 52 +++++-------------- .../discrete/NodeHeightProxyParameter.java | 7 +++ 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java b/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java index cf2ab072f0..fd2b020e50 100644 --- a/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java @@ -70,7 +70,9 @@ public ApproximateTreeDataLikelihood(MaximizerWrtParameter maximizer) { this.parameter = gradient.getParameter(); this.marginalLikelihoodConst = (parameter.getDimension() - 1) * Math.log(2 * Math.PI); // todo: get Numerical Hessian. - if (isGradientProvidingHessian(gradient)) { + if (maximizer.getTransform() != null) { + this.hessianWrtParameterProvider = constructHessian(); + } else if (isGradientProvidingHessian(gradient)) { this.hessianWrtParameterProvider = (HessianWrtParameterProvider) gradient; } else { this.hessianWrtParameterProvider = new NumericalHessianFromGradient(gradient); @@ -100,63 +102,33 @@ private boolean isGradientProvidingHessian(GradientWrtParameterProvider gradient } private HessianWrtParameterProvider constructHessian() { + GradientWrtParameterProvider gradientWrtParameterProvider = new GradientWrtParameterProvider() { - final MultivariateFunction function = new MultivariateFunction() { - @Override - public double evaluate(double[] argument) { - - setParameter(new WrappedVector.Raw(argument), parameter); - return getLogLikelihood(); - } - - @Override - public int getNumArguments() { - return parameter.getDimension(); - } - - @Override - public double getLowerBound(int n) { - return Double.NEGATIVE_INFINITY; - } - - @Override - public double getUpperBound(int n) { - return Double.POSITIVE_INFINITY; - } - }; - - return new HessianWrtParameterProvider() { + private TransformedMultivariateParameter transformedParameter = new TransformedMultivariateParameter(parameter, (Transform.MultivariableTransform) maximizer.getTransform()); @Override public Likelihood getLikelihood() { - return likelihood; + throw new RuntimeException("should not be called"); } @Override public Parameter getParameter() { - return parameter; + return transformedParameter; } @Override public int getDimension() { - return parameter.getDimension(); + return transformedParameter.getDimension(); } @Override public double[] getGradientLogDensity() { - return getGradientLogDensity(); - } - - @Override - public double[] getDiagonalHessianLogDensity() { - return NumericalDerivative.diagonalHessian(function, parameter.getParameterValues()); - } - - @Override - public double[][] getHessianLogDensity() { - return NumericalDerivative.getNumericalHessian(function, parameter.getParameterValues()); + double[] untransformedGradient = maximizer.getGradient().getGradientLogDensity(); + return maximizer.getTransform().updateGradientLogDensity(untransformedGradient, parameter.getParameterValues(), 0, parameter.getDimension()); } }; + + return new NumericalHessianFromGradient(gradientWrtParameterProvider); } private void updateMarginalLikelihood() { diff --git a/src/dr/evomodel/treedatalikelihood/discrete/NodeHeightProxyParameter.java b/src/dr/evomodel/treedatalikelihood/discrete/NodeHeightProxyParameter.java index 83a509fb3a..10982b67f9 100644 --- a/src/dr/evomodel/treedatalikelihood/discrete/NodeHeightProxyParameter.java +++ b/src/dr/evomodel/treedatalikelihood/discrete/NodeHeightProxyParameter.java @@ -27,6 +27,8 @@ package dr.evomodel.treedatalikelihood.discrete; +import dr.evolution.tree.NodeRef; +import dr.evomodel.tree.DefaultTreeModel; import dr.evomodel.tree.TreeChangedEvent; import dr.evomodel.tree.TreeModel; import dr.evomodel.tree.TreeParameterModel; @@ -53,6 +55,11 @@ public NodeHeightProxyParameter(String name, includeRoot); } + @Override + public Bounds getBounds() { + return null; + } + public TreeModel getTree() { return tree; }