diff --git a/src/dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.java b/src/dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.java index 10ea6c4dd2..c53ca99916 100644 --- a/src/dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.java +++ b/src/dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.java @@ -51,13 +51,15 @@ public class TreeTipGradient implements GradientWrtParameterProvider, Reportable private final int dimTrait; private final Parameter maskParameter; + private final int gradientOffset; public TreeTipGradient(String traitName, + Parameter specifiedParameter, TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate, Parameter maskParameter) { - assert(treeDataLikelihood != null); + assert (treeDataLikelihood != null); this.treeDataLikelihood = treeDataLikelihood; this.tree = treeDataLikelihood.getTree(); @@ -82,24 +84,56 @@ public TreeTipGradient(String traitName, if (treeDataLikelihood.getTreeTrait(nFcdName) == null) { likelihoodDelegate.addNewFullConditionalDensityTrait(traitName); } - + treeTraitProvider = treeDataLikelihood.getTreeTrait(name); assert (treeTraitProvider != null); nTaxa = treeDataLikelihood.getTree().getExternalNodeCount(); - nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount(); - dimTrait = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim(); +// nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount(); +// dimTrait = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim(); // PrecisionType precisionType = likelihoodDelegate.getPrecisionType(); // int dimPartial = precisionType.getMatrixLength(dimTrait); - + + + int offset = 0; + ContinuousTraitPartialsProvider dataModel = likelihoodDelegate.getDataModel(); + + if (specifiedParameter == null) { + specifiedParameter = dataModel.getParameter(); + if (dataModel.getDataDimension() != dataModel.getTraitDimension()) { + throw new RuntimeException("Not currently implemented with unspecified parameter and dimension " + + "reduction."); + } + } else { + if (specifiedParameter != dataModel.getParameter()) { + ContinuousTraitPartialsProvider[] childModels = dataModel.getChildModels(); + // TODO: recurse child models + for (int i = 0; i < childModels.length; i++) { + dataModel = childModels[i]; + if (dataModel.getParameter() == specifiedParameter) { + break; + } + offset += dataModel.getTraitDimension(); + } + } + } + + if (specifiedParameter != dataModel.getParameter()) { + throw new RuntimeException("Supplied parameter does not match the parameter in the data model" + + " or any of its submodels."); + } + + this.traitParameter = dataModel.getParameter(); + this.dimTrait = dataModel.getTraitDimension(); + this.nTraits = dataModel.getTraitCount(); + this.gradientOffset = offset; + if (nTraits != 1) { throw new RuntimeException("Not yet implemented for >1 traits"); } - this.traitParameter = likelihoodDelegate.getDataModel().getParameter(); - if (maskParameter != null && (maskParameter.getDimension() != traitParameter.getDimension())) { throw new RuntimeException("Trait and mask parameters must be the same size"); @@ -120,17 +154,17 @@ public Parameter getParameter() { public int getDimension() { return getParameter().getDimension(); } - + @Override public double[] getGradientLogDensity() { - double[] gradient = new double[nTaxa * dimTrait * nTraits]; + double[] gradient = new double[nTaxa * dimTrait * nTraits]; int offsetOutput = 0; for (int taxon = 0; taxon < nTaxa; ++taxon) { double[] taxonGradient = (double[]) treeTraitProvider.getTrait(tree, tree.getExternalNode(taxon)); - System.arraycopy(taxonGradient, 0, gradient, offsetOutput, taxonGradient.length); - offsetOutput += taxonGradient.length; + System.arraycopy(taxonGradient, gradientOffset, gradient, offsetOutput, dimTrait); + offsetOutput += dimTrait; } if (maskParameter != null) { diff --git a/src/dr/evomodelxml/continuous/hmc/FullyConjugateTreeTipsPotentialDerivativeParser.java b/src/dr/evomodelxml/continuous/hmc/FullyConjugateTreeTipsPotentialDerivativeParser.java index 2dd91d45b7..61840a84af 100644 --- a/src/dr/evomodelxml/continuous/hmc/FullyConjugateTreeTipsPotentialDerivativeParser.java +++ b/src/dr/evomodelxml/continuous/hmc/FullyConjugateTreeTipsPotentialDerivativeParser.java @@ -49,6 +49,7 @@ public class FullyConjugateTreeTipsPotentialDerivativeParser extends AbstractXML private final static String FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2 = "traitGradientOnTree"; public static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME; private static final String MASKING = MaskedParameterParser.MASKING; + private static final String TRAITS = "treeTraitInds"; @Override public String getParserName() { @@ -57,7 +58,7 @@ public String getParserName() { @Override public String[] getParserNames() { - return new String[] { FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE, FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2 }; + return new String[]{FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE, FULLY_CONJUGATE_TREE_TIPS_POTENTIAL_DERIVATIVE2}; } @Override @@ -75,11 +76,18 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { mask = (Parameter) xo.getElementFirstChild(MASKING); } + Parameter specifiedParameter = (Parameter) xo.getChild(Parameter.class); + if (fcTreeLikelihood != null) { + if (specifiedParameter != fcTreeLikelihood.getTraitParameter()) { + System.err.println("Warning: specified parameter and assumed parameter for '" + xo.getName() + + "' do not match."); //TODO: better warning + } + return new FullyConjugateTreeTipsPotentialDerivative(fcTreeLikelihood, mask); - } else if (treeDataLikelihood != null){ + } else if (treeDataLikelihood != null) { DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate(); @@ -89,7 +97,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate; - return new TreeTipGradient(traitName, treeDataLikelihood, continuousData, mask); + return new TreeTipGradient(traitName, specifiedParameter, treeDataLikelihood, continuousData, mask); } else { throw new XMLParseException("Must provide a tree likelihood"); } @@ -110,6 +118,7 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(Parameter.class) }, true), new ElementRule(Parameter.class, true), + AttributeRule.newIntegerArrayRule(TRAITS, true) }; @Override diff --git a/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java b/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java index 582bb9ba46..b257936932 100644 --- a/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java +++ b/src/test/dr/evomodel/treedatalikelihood/continuous/ContinuousDataLikelihoodDelegateTest.java @@ -1006,7 +1006,7 @@ private void testLikelihood(String message, IntegratedFactorAnalysisLikelihood d private void testConditionalMoments(TreeDataLikelihood dataLikelihood, ContinuousDataLikelihoodDelegate likelihoodDelegate) { new TreeTipGradient("" + - "trait", dataLikelihood, likelihoodDelegate, null); + "trait", null, dataLikelihood, likelihoodDelegate, null); TreeTraitLogger treeTraitLogger = new TreeTraitLogger(treeModel, new TreeTrait[]{dataLikelihood.getTreeTrait("fcd.trait")}, TreeTraitLogger.NodeRestriction.EXTERNAL, false);