Skip to content

Commit

Permalink
prelim impl of generic gradient after chain-rule
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 4, 2023
1 parent bd6afdc commit 31a43e3
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
33 changes: 25 additions & 8 deletions src/dr/inference/hmc/TransformedGradientWrtParameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedParameter;
import dr.util.Transform;
import dr.xml.Reportable;

Expand All @@ -38,35 +39,51 @@
public class TransformedGradientWrtParameter implements GradientWrtParameterProvider, Reportable {

private final GradientWrtParameterProvider gradient;
private final Transform transform;
private final TransformedParameter parameter;

public TransformedGradientWrtParameter(GradientWrtParameterProvider gradient,
Transform transform) {
TransformedParameter parameter) {
this.gradient = gradient;
this.transform = transform;
this.parameter = parameter;
}
@Override
public Likelihood getLikelihood() {
return null;
return gradient.getLikelihood();
}

@Override
public Parameter getParameter() {
return null;
return parameter.getUntransformedParameter();
}

@Override
public int getDimension() {
return 0;
return parameter.getUntransformedParameter().getDimension();
}

@Override
public double[] getGradientLogDensity() {
return new double[0];

double[] transformedGradient = gradient.getGradientLogDensity();
double[] untransformedValues = parameter.getParameterUntransformedValues();

Transform transform = parameter.getTransform();

double[] untransformedGradient;
if (transform instanceof Transform.MultivariableTransform) {
Transform.MultivariableTransform multivariableTransform = (Transform.MultivariableTransform) transform;
untransformedGradient = multivariableTransform.updateGradientLogDensity(transformedGradient, untransformedValues,
0, untransformedValues.length);
} else {
throw new RuntimeException("Not yet implemented");
}

return untransformedGradient;
}

@Override
public String getReport() {
return null;
return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0,
Double.POSITIVE_INFINITY, null);
}
}
23 changes: 19 additions & 4 deletions src/dr/inferencexml/hmc/TransformedGradientWrtParameterParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.TransformedGradientWrtParameter;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedParameter;
import dr.util.Transform;
import dr.xml.*;

Expand All @@ -37,24 +39,37 @@

public class TransformedGradientWrtParameterParser extends AbstractXMLObjectParser {

public static final String PARSER_NAME = "transformedGradient";
private static final String PARSER_NAME = "transformedGradient";
private static final String WRT = "wrt";

@Override
public Object parseXMLObject(XMLObject xo) throws XMLParseException {

GradientWrtParameterProvider gradient = (GradientWrtParameterProvider)
xo.getChild(GradientWrtParameterProvider.class);

Transform transform = (Transform) xo.getChild(Transform.class);
TransformedParameter parameter = (TransformedParameter) xo.getChild(TransformedParameter.class);

return new TransformedGradientWrtParameter(gradient, transform);
if (xo.hasChildNamed(WRT)) {

Parameter wrt = (Parameter) xo.getElementFirstChild(WRT);

if (wrt != parameter.getUntransformedParameter()) {
throw new XMLParseException("Mismatch between transformed and untransformed parameters");
}
}

return new TransformedGradientWrtParameter(gradient, parameter);
}

@Override
public XMLSyntaxRule[] getSyntaxRules() {
return new XMLSyntaxRule[] {
new ElementRule(GradientWrtParameterProvider.class),
new ElementRule(Transform.class),
new ElementRule(TransformedParameter.class),
new ElementRule(WRT, new XMLSyntaxRule[]{
new ElementRule(Parameter.class),
}, true),
};
}

Expand Down
35 changes: 27 additions & 8 deletions src/dr/util/TransformedVectorSumTransform.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
public class TransformedVectorSumTransform extends Transform.MultivariateTransform {

private static final String NAME = "transformedVectorSumTransform";
private static final String PARSER_NAME2 = "vectorScanTransformedParameter";
private final Transform incrementTransform;
private final int dim;

public TransformedVectorSumTransform(int dim, Transform incrementTransform) {
super(dim);
this.dim = dim;
this.incrementTransform = incrementTransform;
}

Expand All @@ -39,6 +38,24 @@ public String getTransformName() {
return NAME;
}

@Override
protected double[] updateGradientLogDensity(double[] transformedGradient, double[] untransformedValues) {

final int dim = untransformedValues.length;

double[] transformedValues = transform(untransformedValues); // TODO This seem unnecessary; maybe change interface to pass these values?
double[] untransformedGradient = new double[dim];

untransformedGradient[dim - 1] = transformedGradient[dim - 1] *
incrementTransform.gradient(transformedValues[dim - 1]);
for (int i = dim - 2; i >= 0; --i) {
untransformedGradient[i] = transformedGradient[i] *
incrementTransform.gradient(transformedValues[i]) + untransformedGradient[i + 1];
}

return untransformedGradient;
}

@Override
protected double[] transform(double[] values) {
double[] fx = new double[values.length];
Expand Down Expand Up @@ -85,7 +102,6 @@ protected boolean isInInteriorDomain(double[] values) {

@Override
public TransformedMultivariateParameter parseXMLObject(XMLObject xo) throws XMLParseException {
final String name = xo.hasId() ? xo.getId() : null;

Parameter param = (Parameter) xo.getChild(Parameter.class);

Expand All @@ -98,13 +114,13 @@ public TransformedMultivariateParameter parseXMLObject(XMLObject xo) throws XMLP
lower = xo.getDoubleAttribute("lower");
}

Transform incrementTransform = null;
String ttype = (String) xo.getAttribute(INCREMENT_TRANSFORM);
if (ttype.equalsIgnoreCase("log")) {
Transform incrementTransform;
String transformType = (String) xo.getAttribute(INCREMENT_TRANSFORM);
if (transformType.equalsIgnoreCase("log")) {
incrementTransform = Transform.LOG;
} else if (ttype.equalsIgnoreCase("logit")) {
} else if (transformType.equalsIgnoreCase("logit")) {
incrementTransform = new Transform.ScaledLogitTransform(upper, lower);
} else if (ttype.equalsIgnoreCase("none")) {
} else if (transformType.equalsIgnoreCase("none")) {
incrementTransform = new Transform.NoTransform();
} else {
throw new RuntimeException("Invalid transform type");
Expand Down Expand Up @@ -134,5 +150,8 @@ public Class getReturnType() {
public String getParserName() {
return NAME;
}

@Override
public String[] getParserNames() { return new String[]{NAME, PARSER_NAME2}; }
};
}

0 comments on commit 31a43e3

Please sign in to comment.