From 305b0951e1d85c4f6ffe11883e446f391c4991b1 Mon Sep 17 00:00:00 2001
From: mdhall272 <mdhall272@gmail.com>
Date: Wed, 9 Oct 2024 10:20:53 +0100
Subject: [PATCH] Stopped BeagleTreeLikelihoodParser from requiring a
 GammaSiteRateModel (only actually relevant for codon models)

---
 src/dr/evomodel/siteratemodel/SiteRateModel.java      |  6 ++++++
 .../siteratemodel/FreeRateSiteRateModelParser.java    | 11 +++++------
 .../treelikelihood/BeagleTreeLikelihoodParser.java    |  7 ++++---
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/src/dr/evomodel/siteratemodel/SiteRateModel.java b/src/dr/evomodel/siteratemodel/SiteRateModel.java
index 56e1d9bf7e..afb331025f 100644
--- a/src/dr/evomodel/siteratemodel/SiteRateModel.java
+++ b/src/dr/evomodel/siteratemodel/SiteRateModel.java
@@ -27,6 +27,7 @@
 
 package dr.evomodel.siteratemodel;
 
+import dr.evomodel.substmodel.SubstitutionModel;
 import dr.inference.model.Model;
 
 /**
@@ -73,4 +74,9 @@ public interface SiteRateModel extends Model {
      */
     double getProportionForCategory(int category);
 
+    // Added this because some classes still had gamma as a substitution model hard-coded. There are probably better
+    // ways
+
+    SubstitutionModel getSubstitutionModel();
+
 }
diff --git a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java
index ed8c0895f8..395c8c9b86 100644
--- a/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java
+++ b/src/dr/evomodelxml/siteratemodel/FreeRateSiteRateModelParser.java
@@ -88,10 +88,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
             substitutionModel = (SubstitutionModel)xo.getElementFirstChild(SUBSTITUTION_MODEL);
         }
 
-        if(xo.hasChildNamed(BRANCH_SUBSTITUTION_MODEL)){
-            substitutionModel = (SubstitutionModel)xo.getElementFirstChild(BRANCH_SUBSTITUTION_MODEL);
-        }
-
         int catCount = 4;
         catCount = xo.getIntegerAttribute(CATEGORIES);
 
@@ -122,8 +118,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
 
         DiscretizedSiteRateModel siteRateModel =  new DiscretizedSiteRateModel(SiteModel.SITE_MODEL, muParam, muWeight, delegate);
 
-        siteRateModel.setSubstitutionModel(substitutionModel);
-        siteRateModel.addModel(substitutionModel);
+        if(substitutionModel!=null){
+            siteRateModel.setSubstitutionModel(substitutionModel);
+            siteRateModel.addModel(substitutionModel);
+        }
+
 
         return siteRateModel;
     }
diff --git a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java
index d4974f77a2..81e000b118 100644
--- a/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java
+++ b/src/dr/evomodelxml/treelikelihood/BeagleTreeLikelihoodParser.java
@@ -33,6 +33,7 @@
 import dr.evomodel.branchmodel.EpochBranchModel;
 import dr.evomodel.branchmodel.HomogeneousBranchModel;
 import dr.evomodel.siteratemodel.GammaSiteRateModel;
+import dr.evomodel.siteratemodel.SiteRateModel;
 import dr.evomodel.substmodel.FrequencyModel;
 import dr.evomodel.substmodel.SubstitutionModel;
 import dr.evomodel.treelikelihood.AbstractTreeLikelihood;
@@ -76,7 +77,7 @@ public String getParserName() {
 
     protected BeagleTreeLikelihood createTreeLikelihood(PatternList patternList, MutableTreeModel treeModel,
                                                         BranchModel branchModel,
-                                                        GammaSiteRateModel siteRateModel,
+                                                        SiteRateModel siteRateModel,
                                                         BranchRateModel branchRateModel,
                                                         TipStatesModel tipStatesModel,
                                                         boolean useAmbiguities, PartialsRescalingScheme scalingScheme,
@@ -103,7 +104,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
 
         PatternList patternList = (PatternList) xo.getChild(PatternList.class);
         MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(MutableTreeModel.class);
-        GammaSiteRateModel siteRateModel = (GammaSiteRateModel) xo.getChild(GammaSiteRateModel.class);
+        SiteRateModel siteRateModel = (SiteRateModel) xo.getChild(SiteRateModel.class);
 
         FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class);
 
@@ -204,7 +205,7 @@ public Class getReturnType() {
             AttributeRule.newBooleanRule(USE_AMBIGUITIES, true),
             new ElementRule(PatternList.class),
             new ElementRule(TreeModel.class),
-            new ElementRule(GammaSiteRateModel.class),
+            new ElementRule(SiteRateModel.class),
             new ElementRule(BranchModel.class, true),
             new ElementRule(SubstitutionModel.class, true),
             new ElementRule(BranchRateModel.class, true),