diff --git a/src/main/java/org/jpmml/xgboost/RegTree.java b/src/main/java/org/jpmml/xgboost/RegTree.java index 346625c..c658b16 100644 --- a/src/main/java/org/jpmml/xgboost/RegTree.java +++ b/src/main/java/org/jpmml/xgboost/RegTree.java @@ -28,6 +28,8 @@ import org.dmg.pmml.Predicate; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; +import org.dmg.pmml.tree.BranchNode; +import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.BinaryFeature; import org.jpmml.converter.ContinuousFeature; @@ -89,10 +91,7 @@ public void load(XGBoostDataInput input) throws IOException { } public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){ - org.dmg.pmml.tree.Node root = new org.dmg.pmml.tree.ComplexNode() - .setPredicate(new True()); - - encodeNode(root, predicateManager, 0, schema); + org.dmg.pmml.tree.Node root = encodeNode(new True(), predicateManager, 0, schema); TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) @@ -102,8 +101,8 @@ public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schem return treeModel; } - private void encodeNode(org.dmg.pmml.tree.Node parent, PredicateManager predicateManager, int index, Schema schema){ - parent.setId(String.valueOf(index + 1)); + private org.dmg.pmml.tree.Node encodeNode(Predicate predicate, PredicateManager predicateManager, int index, Schema schema){ + String id = String.valueOf(index + 1); Node node = this.nodes.get(index); @@ -152,25 +151,28 @@ private void encodeNode(org.dmg.pmml.tree.Node parent, PredicateManager predicat defaultLeft = node.default_left(); } - org.dmg.pmml.tree.Node leftChild = new org.dmg.pmml.tree.ComplexNode() - .setPredicate(leftPredicate); - - encodeNode(leftChild, predicateManager, node.cleft(), schema); - - org.dmg.pmml.tree.Node rightChild = new org.dmg.pmml.tree.ComplexNode() - .setPredicate(rightPredicate); + org.dmg.pmml.tree.Node leftChild = encodeNode(leftPredicate, predicateManager, node.cleft(), schema); + org.dmg.pmml.tree.Node rightChild = encodeNode(rightPredicate, predicateManager, node.cright(), schema); - encodeNode(rightChild, predicateManager, node.cright(), schema); + org.dmg.pmml.tree.Node result = new BranchNode() + .setId(id) + .setScore(null) // XXX + .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()) + .setPredicate(predicate) + .addNodes(leftChild, rightChild); - parent.addNodes(leftChild, rightChild); - - parent.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()); + return result; } else { float value = node.leaf_value(); - parent.setScore(ValueUtil.formatValue(value)); + org.dmg.pmml.tree.Node result = new LeafNode() + .setId(id) + .setScore(value) + .setPredicate(predicate); + + return result; } }