Skip to content

Commit

Permalink
calculate Hessian wrt transformed space
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Nov 21, 2024
1 parent 2ab57d5 commit 800293a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ public ApproximateTreeDataLikelihood(MaximizerWrtParameter maximizer) {
this.parameter = gradient.getParameter();
this.marginalLikelihoodConst = (parameter.getDimension() - 1) * Math.log(2 * Math.PI);
// todo: get Numerical Hessian.
if (isGradientProvidingHessian(gradient)) {
if (maximizer.getTransform() != null) {
this.hessianWrtParameterProvider = constructHessian();
} else if (isGradientProvidingHessian(gradient)) {
this.hessianWrtParameterProvider = (HessianWrtParameterProvider) gradient;
} else {
this.hessianWrtParameterProvider = new NumericalHessianFromGradient(gradient);
Expand Down Expand Up @@ -100,63 +102,33 @@ private boolean isGradientProvidingHessian(GradientWrtParameterProvider gradient
}

private HessianWrtParameterProvider constructHessian() {
GradientWrtParameterProvider gradientWrtParameterProvider = new GradientWrtParameterProvider() {

final MultivariateFunction function = new MultivariateFunction() {
@Override
public double evaluate(double[] argument) {

setParameter(new WrappedVector.Raw(argument), parameter);
return getLogLikelihood();
}

@Override
public int getNumArguments() {
return parameter.getDimension();
}

@Override
public double getLowerBound(int n) {
return Double.NEGATIVE_INFINITY;
}

@Override
public double getUpperBound(int n) {
return Double.POSITIVE_INFINITY;
}
};

return new HessianWrtParameterProvider() {
private TransformedMultivariateParameter transformedParameter = new TransformedMultivariateParameter(parameter, (Transform.MultivariableTransform) maximizer.getTransform());

@Override
public Likelihood getLikelihood() {
return likelihood;
throw new RuntimeException("should not be called");
}

@Override
public Parameter getParameter() {
return parameter;
return transformedParameter;
}

@Override
public int getDimension() {
return parameter.getDimension();
return transformedParameter.getDimension();
}

@Override
public double[] getGradientLogDensity() {
return getGradientLogDensity();
}

@Override
public double[] getDiagonalHessianLogDensity() {
return NumericalDerivative.diagonalHessian(function, parameter.getParameterValues());
}

@Override
public double[][] getHessianLogDensity() {
return NumericalDerivative.getNumericalHessian(function, parameter.getParameterValues());
double[] untransformedGradient = maximizer.getGradient().getGradientLogDensity();
return maximizer.getTransform().updateGradientLogDensity(untransformedGradient, parameter.getParameterValues(), 0, parameter.getDimension());
}
};

return new NumericalHessianFromGradient(gradientWrtParameterProvider);
}

private void updateMarginalLikelihood() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
Expand All @@ -53,6 +55,11 @@ public NodeHeightProxyParameter(String name,
includeRoot);
}

@Override
public Bounds<Double> getBounds() {
return null;
}

public TreeModel getTree() {
return tree;
}
Expand Down

0 comments on commit 800293a

Please sign in to comment.