Skip to content

Commit

Permalink
specifying parameter in 'traitGradientOnTree' actually does something
Browse files Browse the repository at this point in the history
  • Loading branch information
gabehassler committed Jan 17, 2024
1 parent 5132abf commit 4e7ed59
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
56 changes: 45 additions & 11 deletions src/dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ public class TreeTipGradient implements GradientWrtParameterProvider, Reportable
private final int dimTrait;

private final Parameter maskParameter;
private final int gradientOffset;

public TreeTipGradient(String traitName,
Parameter specifiedParameter,
TreeDataLikelihood treeDataLikelihood,
ContinuousDataLikelihoodDelegate likelihoodDelegate,
Parameter maskParameter) {

assert(treeDataLikelihood != null);
assert (treeDataLikelihood != null);

this.treeDataLikelihood = treeDataLikelihood;
this.tree = treeDataLikelihood.getTree();
Expand All @@ -82,24 +84,56 @@ public TreeTipGradient(String traitName,
if (treeDataLikelihood.getTreeTrait(nFcdName) == null) {
likelihoodDelegate.addNewFullConditionalDensityTrait(traitName);
}

treeTraitProvider = treeDataLikelihood.getTreeTrait(name);

assert (treeTraitProvider != null);

nTaxa = treeDataLikelihood.getTree().getExternalNodeCount();
nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
dimTrait = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
// nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
// dimTrait = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();

// PrecisionType precisionType = likelihoodDelegate.getPrecisionType();
// int dimPartial = precisionType.getMatrixLength(dimTrait);



int offset = 0;
ContinuousTraitPartialsProvider dataModel = likelihoodDelegate.getDataModel();

if (specifiedParameter == null) {
specifiedParameter = dataModel.getParameter();
if (dataModel.getDataDimension() != dataModel.getTraitDimension()) {
throw new RuntimeException("Not currently implemented with unspecified parameter and dimension " +
"reduction.");
}
} else {
if (specifiedParameter != dataModel.getParameter()) {
ContinuousTraitPartialsProvider[] childModels = dataModel.getChildModels();
// TODO: recurse child models
for (int i = 0; i < childModels.length; i++) {
dataModel = childModels[i];
if (dataModel.getParameter() == specifiedParameter) {
break;
}
offset += dataModel.getTraitDimension();
}
}
}

if (specifiedParameter != dataModel.getParameter()) {
throw new RuntimeException("Supplied parameter does not match the parameter in the data model" +
" or any of its submodels.");
}

this.traitParameter = dataModel.getParameter();
this.dimTrait = dataModel.getTraitDimension();
this.nTraits = dataModel.getTraitCount();
this.gradientOffset = offset;

if (nTraits != 1) {
throw new RuntimeException("Not yet implemented for >1 traits");
}

this.traitParameter = likelihoodDelegate.getDataModel().getParameter();

if (maskParameter != null &&
(maskParameter.getDimension() != traitParameter.getDimension())) {
throw new RuntimeException("Trait and mask parameters must be the same size");
Expand All @@ -120,17 +154,17 @@ public Parameter getParameter() {
public int getDimension() {
return getParameter().getDimension();
}

@Override
public double[] getGradientLogDensity() {

double[] gradient = new double[nTaxa * dimTrait * nTraits];
double[] gradient = new double[nTaxa * dimTrait * nTraits];

int offsetOutput = 0;
for (int taxon = 0; taxon < nTaxa; ++taxon) {
double[] taxonGradient = (double[]) treeTraitProvider.getTrait(tree, tree.getExternalNode(taxon));
System.arraycopy(taxonGradient, 0, gradient, offsetOutput, taxonGradient.length);
offsetOutput += taxonGradient.length;
System.arraycopy(taxonGradient, gradientOffset, gradient, offsetOutput, dimTrait);
offsetOutput += dimTrait;
}

if (maskParameter != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class FullyConjugateTreeTipsPotentialDerivativeParser extends AbstractXML
private final static String FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2 = "traitGradientOnTree";
public static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME;
private static final String MASKING = MaskedParameterParser.MASKING;
private static final String TRAITS = "treeTraitInds";

@Override
public String getParserName() {
Expand All @@ -57,7 +58,7 @@ public String getParserName() {

@Override
public String[] getParserNames() {
return new String[] { FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE, FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2 };
return new String[]{FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE, FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2};
}

@Override
Expand All @@ -75,11 +76,18 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
mask = (Parameter) xo.getElementFirstChild(MASKING);
}

Parameter specifiedParameter = (Parameter) xo.getChild(Parameter.class);

if (fcTreeLikelihood != null) {

if (specifiedParameter != fcTreeLikelihood.getTraitParameter()) {
System.err.println("Warning: specified parameter and assumed parameter for '" + xo.getName() +
"' do not match."); //TODO: better warning
}

return new FullyConjugateTreeTipsPotentialDerivative(fcTreeLikelihood, mask);

} else if (treeDataLikelihood != null){
} else if (treeDataLikelihood != null) {

DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();

Expand All @@ -89,7 +97,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

final ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;

return new TreeTipGradient(traitName, treeDataLikelihood, continuousData, mask);
return new TreeTipGradient(traitName, specifiedParameter, treeDataLikelihood, continuousData, mask);
} else {
throw new XMLParseException("Must provide a tree likelihood");
}
Expand All @@ -110,6 +118,7 @@ public XMLSyntaxRule[] getSyntaxRules() {
new ElementRule(Parameter.class)
}, true),
new ElementRule(Parameter.class, true),
AttributeRule.newIntegerArrayRule(TRAITS, true)
};

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ private void testLikelihood(String message, IntegratedFactorAnalysisLikelihood d

private void testConditionalMoments(TreeDataLikelihood dataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate) {
new TreeTipGradient("" +
"trait", dataLikelihood, likelihoodDelegate, null);
"trait", null, dataLikelihood, likelihoodDelegate, null);
TreeTraitLogger treeTraitLogger = new TreeTraitLogger(treeModel,
new TreeTrait[]{dataLikelihood.getTreeTrait("fcd.trait")},
TreeTraitLogger.NodeRestriction.EXTERNAL, false);
Expand Down

0 comments on commit 4e7ed59

Please sign in to comment.