Skip to content

Commit

Permalink
fix coalescent gradient by using mapping function in BigFastTreeInter…
Browse files Browse the repository at this point in the history
…vals
  • Loading branch information
xji3 committed Nov 6, 2023
1 parent 491f3ad commit dfb2cd3
Showing 1 changed file with 41 additions and 13 deletions.
54 changes: 41 additions & 13 deletions src/dr/evolution/coalescent/CoalescentGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
package dr.evolution.coalescent;


import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.coalescent.CoalescentLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
Expand Down Expand Up @@ -97,29 +97,57 @@ public double[] getGradientLogDensity() {
getIntervalIndexForInternalNodes(intervalIndices, nodeIndices, sortedValues);

IntervalList intervals = likelihood.getIntervalList();
BigFastTreeIntervals bigFastTreeIntervals = (BigFastTreeIntervals) intervals;

DemographicFunction demographicFunction = likelihood.getDemoModel().getDemographicFunction();

for (int i = 0; i < tree.getInternalNodeCount(); i++) {
NodeRef node = tree.getNode(tree.getExternalNodeCount() + nodeIndices[i]);
final double time = tree.getNodeHeight(node);
final double intensityGradient = demographicFunction.getIntensityGradient(time);
final double kChoose2 = Binomial.choose2(intervals.getLineageCount(intervalIndices[nodeIndices[i]]));
gradient[i] -= demographicFunction.getLogDemographicGradient(time);

if (intervals.getInterval(intervalIndices[nodeIndices[i]]) != 0) {
gradient[i] -= kChoose2 * intensityGradient;
int numSameHeightNodes = 1;
double thisGradient = 0;
for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); i++) {
if (bigFastTreeIntervals.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = bigFastTreeIntervals.getIntervalTime(i + 1);
final int lineageCount = bigFastTreeIntervals.getLineageCount(i);
final double kChoose2 = Binomial.choose2(lineageCount);
final double intensityGradient = demographicFunction.getIntensityGradient(time);
thisGradient += demographicFunction.getLogDemographicGradient(time);

if (bigFastTreeIntervals.getInterval(i) != 0) {
thisGradient -= kChoose2 * intensityGradient;
} else {
numSameHeightNodes++;
}

if ( i < bigFastTreeIntervals.getIntervalCount() - 1
&& bigFastTreeIntervals.getInterval(i + 1) != 0) {

final int nextLineageCount = bigFastTreeIntervals.getLineageCount(i + 1);
thisGradient += Binomial.choose2(nextLineageCount) * intensityGradient;

for (int j = 0; j < numSameHeightNodes; j++) {
final int nodeIndex = bigFastTreeIntervals.getNodeNumbersForInterval(i - j)[1];
gradient[nodeIndex - tree.getExternalNodeCount()] = thisGradient / (double) numSameHeightNodes;
}

thisGradient = 0;
numSameHeightNodes = 1;
}
}
}

if (!tree.isRoot(node) && intervals.getInterval(intervalIndices[nodeIndices[i]] + 1) != 0.0) {
final int nextLineageCount = intervals.getLineageCount(intervalIndices[nodeIndices[i]] + 1);
gradient[i] += Binomial.choose2(nextLineageCount) * intensityGradient;
int j = numSameHeightNodes;
int v = bigFastTreeIntervals.getIntervalCount() - 1;
while(j > 0) {
if (bigFastTreeIntervals.getIntervalType(v) == IntervalType.COALESCENT) {
gradient[bigFastTreeIntervals.getNodeNumbersForInterval(v)[1] - tree.getExternalNodeCount()] = thisGradient / (double) numSameHeightNodes;
j--;
}
v--;
}

return gradient;
}

@Deprecated
private void getIntervalIndexForInternalNodes(int[] intervalIndices, int[] nodeIndices, double[] sortedValues) {
double[] nodeHeights = new double[tree.getInternalNodeCount()];
ArrayList<ComparableDouble> sortedInternalNodes = new ArrayList<ComparableDouble>();
Expand Down

0 comments on commit dfb2cd3

Please sign in to comment.