Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into beauti-tb
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Jan 14, 2025
2 parents 41c40e7 + f37ecb6 commit 3d5c850
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/dr/app/beauti/generator/ClockModelGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ
writer.writeIDref(DefaultTreeModel.TREE_MODEL, treePrefix + DefaultTreeModel.TREE_MODEL);
writer.writeCloseTag(CTMCScalePriorParser.MODEL_NAME);

}

if (generateScaleGradient){
//location gradient
writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{
new Attribute.Default<>(XMLParser.ID, prefix + LocationGradient.LOCATION_GRADIENT),
Expand All @@ -352,9 +355,6 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ
writer.writeCloseTag(LocationScaleGradientParser.LOCATION);
writer.writeCloseTag(LocationScaleGradientParser.NAME);

}

if (generateScaleGradient){
//scale gradient
writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{
new Attribute.Default<>(XMLParser.ID, prefix + ScaleGradient.SCALE_GRADIENT),
Expand Down
15 changes: 7 additions & 8 deletions src/dr/app/beauti/generator/OperatorsGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ private void writeOperator(Operator operator, XMLWriter writer) {
case SCALE_WITH_INDICATORS:
writeScaleWithIndicatorsOperator(operator, writer);
break;
case GMRF_GIBBS_OPERATOR:
writeGMRFGibbsOperator(operator, prefix, writer);
case GMRF_BLOCKUPDATE_OPERATOR:
writeGMRFBlockUpdateOperator(operator, prefix, writer);
break;
case SKY_GRID_GIBBS_OPERATOR:
writeSkyGridGibbsOperator(operator, prefix, writer);
case SKY_GRID_BLOCKUPDATE_OPERATOR:
writeSkyGridBlockUpdateOperator(operator, prefix, writer);
break;
case SKY_GRID_HMC_OPERATOR:
writeSkyGridHMCOperator(operator, prefix, writer);
Expand Down Expand Up @@ -529,12 +529,11 @@ private void writeSampleNonActiveOperator(Operator operator, XMLWriter writer) {
writer.writeCloseTag(SampleNonActiveGibbsOperatorParser.SAMPLE_NONACTIVE_GIBBS_OPERATOR);
}

private void writeSkyGridGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
private void writeSkyGridBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
writer.writeOpenTag(
GMRFSkyrideBlockUpdateOperatorParser.GRID_BLOCK_UPDATE_OPERATOR,
new Attribute[] {
// This is a Gibbs operator so shouldn't have a tuning parameter?
// new Attribute.Default<Double>(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()),
new Attribute.Default<Double>(GMRFSkyrideBlockUpdateOperatorParser.SCALE_FACTOR, operator.getTuning()),
getWeightAttribute(operator.getWeight())
}
);
Expand Down Expand Up @@ -699,7 +698,7 @@ private void writeShrinkageClockHMCOperator(Operator operator, String prefix, XM
writer.writeCloseTag(HamiltonianMonteCarloOperatorParser.HMC_OPERATOR);
}

private void writeGMRFGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
private void writeGMRFBlockUpdateOperator(Operator operator, String treePriorPrefix, XMLWriter writer) {
writer.writeOpenTag(
GMRFSkyrideBlockUpdateOperatorParser.BLOCK_UPDATE_OPERATOR,
new Attribute[]{
Expand Down
7 changes: 3 additions & 4 deletions src/dr/app/beauti/options/PartitionTreePrior.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
package dr.app.beauti.options;

import dr.app.beauti.types.*;
import dr.evomodel.coalescent.VariableDemographicModel;
import dr.evomodel.speciation.CalibrationPoints;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.evomodelxml.speciation.BirthDeathEpidemiologyModelParser;
Expand Down Expand Up @@ -271,9 +270,9 @@ public void initModelParametersAndOpererators() {
// "demographic.indicators", OperatorType.SCALE_WITH_INDICATORS, 0.5, 2 * demoWeights);

createOperatorUsing2Parameters("gmrfGibbsOperator", "gmrfGibbsOperator", "Gibbs sampler for GMRF Skyride", "skyride.logPopSize",
"skyride.precision", OperatorType.GMRF_GIBBS_OPERATOR, -1, 2);
"skyride.precision", OperatorType.GMRF_BLOCKUPDATE_OPERATOR, 1, 2);
createOperatorUsing2Parameters("gmrfSkyGridGibbsOperator", "skygrid.logPopSize", "Gibbs sampler for Bayesian SkyGrid", "skygrid.logPopSize",
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_GIBBS_OPERATOR, -1, 2);
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_BLOCKUPDATE_OPERATOR, 1, 2);
createScaleOperator(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, "skygrid precision", 0.75, 1.0);
createOperatorUsing2Parameters("gmrfSkyGridHMCOperator", "Multiple", "HMC transition kernel for Bayesian SkyGrid", "skygrid.logPopSize",
GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, OperatorType.SKY_GRID_HMC_OPERATOR, -1, 2);
Expand All @@ -292,7 +291,7 @@ public void initModelParametersAndOpererators() {
createOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.RELATIVE_MU, OperatorType.RANDOM_WALK_LOGIT, demoTuning, 1);
createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random worl op ?
+ BirthDeathSerialSamplingModelParser.PSI, demoTuning, 1); // todo random walk op ?
createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
+ BirthDeathSerialSamplingModelParser.ORIGIN, demoTuning, 1);
// createScaleOperator(BirthDeathSerialSamplingModelParser.BDSS + "."
Expand Down
5 changes: 2 additions & 3 deletions src/dr/app/beauti/types/OperatorType.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
package dr.app.beauti.types;

import dr.evomodel.operators.BitFlipInSubstitutionModelOperator;
import dr.evomodelxml.operators.TreeNodeSlideParser;
import dr.inference.operators.RateBitExchangeOperator;
import dr.inferencexml.operators.ScaleOperatorParser;

Expand Down Expand Up @@ -67,8 +66,8 @@ public enum OperatorType {
NARROW_EXCHANGE("narrowExchange"),
WIDE_EXCHANGE("wideExchange"),
EMPIRICAL_TREE_SWAP("empiricalSwap"),
GMRF_GIBBS_OPERATOR("gmrfGibbsOperator"),
SKY_GRID_GIBBS_OPERATOR("gmrfGibbsOperator"),
GMRF_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"),
SKY_GRID_BLOCKUPDATE_OPERATOR("gmrfBlockUpdateOperator"),
SKY_GRID_HMC_OPERATOR("gmrfHMCOperator"),
// PRECISION_GMRF_OPERATOR("precisionGMRFOperator"),
WILSON_BALDING("wilsonBalding"),
Expand Down
31 changes: 30 additions & 1 deletion src/dr/evomodel/continuous/ContinuousDiffusionStatistic.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class ContinuousDiffusionStatistic extends Statistic.Abstract {
public static final String SPEARMAN = "spearman";
public static final String CORRELATION_COEFFICIENT = "correlationCoefficient";
public static final String DISTANCE_TIME_CORRELATION = "distanceTimeCorrelation";
public static final String SQUAREDDISTANCE_TIME4_CORRELATION = "squaredDistanceTimeFourCorrelation";
public static final String R_SQUARED = "Rsquared";
public static final String STATISTIC = "statistic";
public static final String TRAIT = "trait";
Expand Down Expand Up @@ -458,6 +459,17 @@ public double getStatisticValue(int dim) {
Regression r = new Regression(convertDoubles(times),convertDoubles(distances));
return r.getCorrelationCoefficient();
}
} else if (summaryStat == summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION) {
List<Double> squareddistances = squareElements(distances);
if (summaryMode == Mode.SPEARMAN) {
return getSpearmanRho(convertDoubles(times),convertDoubles(squareddistances));
} else if (summaryMode == Mode.R_SQUARED) {
Regression r = new Regression(convertDoubles(times), convertDoubles(squareddistances));
return r.getRSquared();
} else {
Regression r = new Regression(convertDoubles(times),convertDoubles(squareddistances));
return r.getCorrelationCoefficient();
}
} else {
return treeLength;
}
Expand Down Expand Up @@ -490,6 +502,14 @@ private double[] toArray(List<Double> list) {
return returnArray;
}

private static List<Double> squareElements(List<Double> inputList) {
List<Double> squaredList = new ArrayList<>();
for (Double number : inputList) {
squaredList.add(number * number);
}
return squaredList;
}

private double[] imputeValue(double[] nodeValue, double[] parentValue, double time, double nodeHeight, double parentHeight, double[] precisionArray, double rate, boolean trueNoise) {

final double scaledTimeChild = (time - nodeHeight) * rate;
Expand Down Expand Up @@ -932,7 +952,8 @@ enum summaryStatistic {
WAVEFRONT_DISTANCE,
WAVEFRONT_DISTANCE_PHYLO,
WAVEFRONT_RATE,
DISTANCE_TIME_CORRELATION
DISTANCE_TIME_CORRELATION,
SQUAREDDISTANCE_TIME4_CORRELATION
}

enum BranchSet {
Expand Down Expand Up @@ -1023,6 +1044,12 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
System.err.println(name+": mode = "+mode+" ignored for "+DISTANCE_TIME_CORRELATION+", reverting to correlation coefficient mode");
statMode = Mode.CORRELATION_COEFFICIENT;
}
} else if (statistic.equals(SQUAREDDISTANCE_TIME4_CORRELATION)) {
summaryStat = summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION;
if (mode.equals(AVERAGE) || mode.equals(WEIGHTED_AVERAGE) || mode.equals(COEFFICIENT_OF_VARIATION) || mode.equals(MEDIAN)){
System.err.println(name+": mode = "+mode+" ignored for "+SQUAREDDISTANCE_TIME4_CORRELATION+", reverting to correlation coefficient mode");
statMode = Mode.CORRELATION_COEFFICIENT;
}
} else if (statistic.equals(WAVEFRONT_DISTANCE)) {
summaryStat = summaryStatistic.WAVEFRONT_DISTANCE;
if (!mode.equals(WEIGHTED_AVERAGE)) {
Expand Down Expand Up @@ -1065,6 +1092,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
summaryStat = summaryStatistic.DIFFUSION_COEFFICIENT;
} else if (statistic.equals(DISTANCE_TIME_CORRELATION)) {
summaryStat = summaryStatistic.DISTANCE_TIME_CORRELATION;
} else if (statistic.equals(SQUAREDDISTANCE_TIME4_CORRELATION)) {
summaryStat = summaryStatistic.SQUAREDDISTANCE_TIME4_CORRELATION;
} else {
System.err.println(name+": unknown statistic: "+statistic+". Reverting to diffusion rate.");
summaryStat = summaryStatistic.DIFFUSION_RATE;
Expand Down

0 comments on commit 3d5c850

Please sign in to comment.