diff --git a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java index 21f59d727a..8b19fdcfde 100644 --- a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java +++ b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java @@ -37,6 +37,8 @@ import dr.evolution.util.Taxa; import dr.evolution.util.Units; import dr.evomodel.branchratemodel.BranchRateModel; +import dr.evomodel.branchratemodel.BranchSpecificFixedEffects; +import dr.inference.distribution.DistributionLikelihood; import dr.evomodel.tree.DefaultTreeModel; import dr.evomodelxml.TreeWorkingPriorParsers; import dr.evomodelxml.branchratemodel.*; @@ -1026,6 +1028,20 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio break; case MIXED_EFFECTS_CLOCK: + + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.RATES_PRIOR); + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.SCALE_PRIOR); + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.INTERCEPT_PRIOR); + + String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + int number = 1; + String concat = coeff + number; + while (model.hasParameter(concat)) { + writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); + number++; + concat = coeff + number; + } + break; default: diff --git a/src/dr/app/beauti/generator/ClockModelGenerator.java b/src/dr/app/beauti/generator/ClockModelGenerator.java index 093411b4e5..c19446180c 100644 --- a/src/dr/app/beauti/generator/ClockModelGenerator.java +++ b/src/dr/app/beauti/generator/ClockModelGenerator.java @@ -30,7 +30,6 @@ import dr.app.beauti.components.ComponentFactory; import dr.app.beauti.options.*; import dr.app.beauti.types.ClockType; -import dr.app.beauti.types.OperatorType; import dr.app.beauti.util.XMLWriter; import dr.evolution.util.Taxa; import dr.evomodel.branchratemodel.ArbitraryBranchRates; @@ -66,8 +65,6 @@ import dr.util.Attribute; import dr.xml.XMLParser; -import java.util.Map; - import static dr.inference.model.ParameterParser.PARAMETER; import static dr.inferencexml.distribution.PriorParsers.*; import static dr.inferencexml.distribution.shrinkage.BayesianBridgeLikelihoodParser.*; @@ -301,16 +298,19 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writeCovarianceStatistic(writer, tag, prefix, treePrefix); - //TODO add more String constants for this type of code + boolean generateRatesGradient = false; boolean generateScaleGradient = false; for (Operator operator : options.selectOperators()) { - if (operator.getName().equals("HMC relaxed clock location and scale") && operator.isUsed()) { + if (operator.getName().equals(ClockType.HMC_CLOCK_RATES_DESCRIPTION) && operator.isUsed()) { + generateRatesGradient = true; + } + if (operator.getName().equals(ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION) && operator.isUsed()) { generateScaleGradient = true; } } - if (generateScaleGradient) { + if (generateRatesGradient) { //scale prior writer.writeOpenTag(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, @@ -352,6 +352,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ writer.writeCloseTag(LocationScaleGradientParser.LOCATION); writer.writeCloseTag(LocationScaleGradientParser.NAME); + } + + if (generateScaleGradient){ //scale gradient writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{ new Attribute.Default<>(XMLParser.ID, prefix + ScaleGradient.SCALE_GRADIENT), @@ -958,18 +961,18 @@ public static void writeBranchRatesModelRef(PartitionClockModel model, XMLWriter case MIXED_EFFECTS_CLOCK: //always write distribution likelihoods for rate, scale and intercept - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR); - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR); - writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR); + //writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR); //check for coefficients - String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + /*String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; int number = 1; String concat = coeff + number; while (model.hasParameter(concat)) { writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); number++; concat = coeff + number; - } + }*/ tag = ArbitraryBranchRatesParser.ARBITRARY_BRANCH_RATES; id = model.getPrefix() + ArbitraryBranchRates.BRANCH_RATES; break; diff --git a/src/dr/app/beauti/generator/ParameterPriorGenerator.java b/src/dr/app/beauti/generator/ParameterPriorGenerator.java index 339932202e..aa8c0613bb 100644 --- a/src/dr/app/beauti/generator/ParameterPriorGenerator.java +++ b/src/dr/app/beauti/generator/ParameterPriorGenerator.java @@ -35,13 +35,13 @@ import dr.evolution.util.Taxa; import dr.evomodel.branchratemodel.BranchSpecificFixedEffects; import dr.evomodel.tree.DefaultTreeModel; +import dr.evomodelxml.branchratemodel.BranchSpecificFixedEffectsParser; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.evomodelxml.tree.CTMCScalePriorParser; import dr.evomodelxml.tree.MonophylyStatisticParser; import dr.inference.distribution.DistributionLikelihood; import dr.inference.model.ParameterParser; import dr.inferencexml.distribution.CachedDistributionLikelihoodParser; -import dr.inferencexml.distribution.DistributionLikelihoodParser; import dr.inferencexml.distribution.PriorParsers; import dr.inferencexml.model.BooleanLikelihoodParser; import dr.inferencexml.model.OneOnXPriorParser; @@ -58,19 +58,48 @@ */ public class ParameterPriorGenerator extends Generator { - //map parameters to prior IDs, for use with HMC - private HashMap mapParameterToPrior; + //map parameters to prior IDs, for use with HMC or other approaches that define their prior befor the XML block + private final HashMap mapParameterToPrior; public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) { super(options, components); //TODO don't like this being here, but will see how things pan out as more HMC approaches are added mapParameterToPrior = new HashMap(); + } + + /** + * Add all possibly previously defined priors to a HashMap + * Cannot be done in constructor as the models have not been defined by the user at that point + */ + public void addParametersToPrior() { + int totalModels = options.getPartitionClockModels().size(); + List partitionClockModels = options.getPartitionClockModels(); //HMC skygrid mapParameterToPrior.put(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION_PRIOR); //HMC relaxed clock - mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, BranchSpecificFixedEffects.LOCATION_PRIOR); - mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, BranchSpecificFixedEffects.RATES_PRIOR); - mapParameterToPrior.put(ClockType.HMCLN_SCALE, BranchSpecificFixedEffects.SCALE_PRIOR); + for (int i = 0; i < totalModels; i++) { + String prefix = partitionClockModels.get(i).getPrefix(); + mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.LOCATION_PRIOR); + mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, prefix + BranchSpecificFixedEffects.RATES_PRIOR); + mapParameterToPrior.put(ClockType.HMCLN_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR); + } + //mixed effects clock + //always write distribution likelihoods for rate, scale and intercept + for (int i = 0; i < totalModels; i++) { + String prefix = partitionClockModels.get(i).getPrefix(); + mapParameterToPrior.put(ClockType.ME_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.RATES_PRIOR); + mapParameterToPrior.put(ClockType.ME_CLOCK_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR); + mapParameterToPrior.put(BranchSpecificFixedEffectsParser.INTERCEPT, prefix + BranchSpecificFixedEffects.INTERCEPT_PRIOR); + //check for coefficients + String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT; + int number = 1; + String concat = coeff + number; + while (partitionClockModels.get(i).hasParameter(concat)) { + mapParameterToPrior.put(concat, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number); + number++; + concat = coeff + number; + } + } } /** @@ -79,6 +108,10 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone * @param writer the writer */ public void writeParameterPriors(XMLWriter writer) { + + //first make sure that all possibly previously defined priors are part of the HashMap + addParametersToPrior(); + boolean first = true; for (Map.Entry taxaBooleanEntry : options.taxonSetsMono.entrySet()) { diff --git a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java index 368448d91f..b972cec104 100644 --- a/src/dr/app/beauti/generator/SubstitutionModelGenerator.java +++ b/src/dr/app/beauti/generator/SubstitutionModelGenerator.java @@ -753,7 +753,7 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); String prefix1 = options.getPrefix(); - if (!parameter.getSubParameters().isEmpty()) { + if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) { writeNuRelativeRateBlock(writer, prefix1, parameter); } } else { @@ -802,7 +802,9 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model if (options.useNuRelativeRates()) { Parameter parameter = model.getParameter("nu"); - writeNuRelativeRateBlock(writer, prefix, parameter); + if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) { + writeNuRelativeRateBlock(writer, prefix, parameter); + } } else { writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer); } diff --git a/src/dr/app/beauti/options/PartitionClockModel.java b/src/dr/app/beauti/options/PartitionClockModel.java index 7079014a2b..8ade66566c 100644 --- a/src/dr/app/beauti/options/PartitionClockModel.java +++ b/src/dr/app/beauti/options/PartitionClockModel.java @@ -130,7 +130,7 @@ public void initModelParametersAndOpererators() { .initial(1.0).mean(1.0).offset(0.0).partitionOptions(this).isPriorFixed(true) .isAdaptiveMultivariateCompatible(false).build(parameters); - new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, "HMC relaxed clock branch rates") + new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, ClockType.HMC_CLOCK_RATES_DESCRIPTION) .prior(PriorType.LOGNORMAL_HPM_PRIOR).initial(0.001).isNonNegative(true) .partitionOptions(this).isPriorFixed(true) .isAdaptiveMultivariateCompatible(false).build(parameters); @@ -216,11 +216,11 @@ public void initModelParametersAndOpererators() { createScaleOperator(ClockType.UCGD_SHAPE, demoTuning, rateWeights); //HMC relaxed clock - createOperator("HMCRCR", "HMC relaxed clock branch rates", + createOperator("HMCRCR", ClockType.HMC_CLOCK_RATES_DESCRIPTION, "Hamiltonian Monte Carlo relaxed clock branch rates operator", null, OperatorType.RELAXED_CLOCK_HMC_RATE_OPERATOR,-1 , 1.0); - createOperator("HMCRCS", "HMC relaxed clock location and scale", + createOperator("HMCRCS", ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION, "Hamiltonian Monte Carlo relaxed clock scale operator", null, OperatorType.RELAXED_CLOCK_HMC_SCALE_OPERATOR,-1 , 0.5); - //for the time being turn off the HMC relaxed clock scale kernel + //turn off the HMC relaxed clock scale kernel by default getOperator("HMCRCS").setUsed(false); createScaleOperator(ClockType.HMC_CLOCK_LOCATION, demoTuning, rateWeights); createScaleOperator(ClockType.HMCLN_SCALE, demoTuning, rateWeights); diff --git a/src/dr/app/beauti/types/ClockType.java b/src/dr/app/beauti/types/ClockType.java index 09029542d9..e88d521d85 100644 --- a/src/dr/app/beauti/types/ClockType.java +++ b/src/dr/app/beauti/types/ClockType.java @@ -68,4 +68,7 @@ public String toString() { final public static String ACLD_MEAN = "acld.mean"; final public static String ACLD_STDEV = "acld.stdev"; + + final public static String HMC_CLOCK_RATES_DESCRIPTION = "HMC relaxed clock branch rates"; + final public static String HMC_CLOCK_LOCATION_SCALE_DESCRIPTION = "HMC relaxed clock location and scale"; } \ No newline at end of file