diff --git a/pom.xml b/pom.xml
index 431ef3e..95aca72 100644
--- a/pom.xml
+++ b/pom.xml
@@ -52,7 +52,7 @@
org.jpmml
jpmml-converter
- 1.3.4
+ 1.3.5
@@ -65,13 +65,13 @@
org.jpmml
pmml-evaluator
- 1.4.5
+ 1.4.6
test
org.jpmml
pmml-evaluator-test
- 1.4.5
+ 1.4.6
test
diff --git a/src/main/java/org/jpmml/xgboost/RegTree.java b/src/main/java/org/jpmml/xgboost/RegTree.java
index 713bc1c..346625c 100644
--- a/src/main/java/org/jpmml/xgboost/RegTree.java
+++ b/src/main/java/org/jpmml/xgboost/RegTree.java
@@ -89,7 +89,7 @@ public void load(XGBoostDataInput input) throws IOException {
}
public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
- org.dmg.pmml.tree.Node root = new org.dmg.pmml.tree.Node()
+ org.dmg.pmml.tree.Node root = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(new True());
encodeNode(root, predicateManager, 0, schema);
@@ -152,12 +152,12 @@ private void encodeNode(org.dmg.pmml.tree.Node parent, PredicateManager predicat
defaultLeft = node.default_left();
}
- org.dmg.pmml.tree.Node leftChild = new org.dmg.pmml.tree.Node()
+ org.dmg.pmml.tree.Node leftChild = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(leftPredicate);
encodeNode(leftChild, predicateManager, node.cleft(), schema);
- org.dmg.pmml.tree.Node rightChild = new org.dmg.pmml.tree.Node()
+ org.dmg.pmml.tree.Node rightChild = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(rightPredicate);
encodeNode(rightChild, predicateManager, node.cright(), schema);
diff --git a/src/main/java/org/jpmml/xgboost/visitors/TreeModelCompactor.java b/src/main/java/org/jpmml/xgboost/visitors/TreeModelCompactor.java
index 86c8dad..1210981 100644
--- a/src/main/java/org/jpmml/xgboost/visitors/TreeModelCompactor.java
+++ b/src/main/java/org/jpmml/xgboost/visitors/TreeModelCompactor.java
@@ -30,9 +30,9 @@ public class TreeModelCompactor extends AbstractTreeModelTransformer {
@Override
public void enterNode(Node node){
- String defaultChild = node.getDefaultChild();
String id = node.getId();
- String score = node.getScore();
+ Object score = node.getScore();
+ String defaultChild = node.getDefaultChild();
if(id == null){
throw new IllegalArgumentException();
@@ -41,7 +41,7 @@ public void enterNode(Node node){
if(node.hasNodes()){
List children = node.getNodes();
- if(children.size() != 2 || defaultChild == null || score != null){
+ if(children.size() != 2 || score != null || defaultChild == null){
throw new IllegalArgumentException();
}
@@ -69,7 +69,7 @@ public void enterNode(Node node){
} else
{
- if(defaultChild != null || score == null){
+ if(score == null || defaultChild != null){
throw new IllegalArgumentException();
}
}