Skip to content

Commit

Permalink
Optimized the encoding of Node elements
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 18, 2019
1 parent bc4c993 commit 5e0cdea
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/main/java/org/jpmml/xgboost/RegTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
Expand Down Expand Up @@ -89,10 +91,7 @@ public void load(XGBoostDataInput input) throws IOException {
}

public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema){
org.dmg.pmml.tree.Node root = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(new True());

encodeNode(root, predicateManager, 0, schema);
org.dmg.pmml.tree.Node root = encodeNode(new True(), predicateManager, 0, schema);

TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root)
.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
Expand All @@ -102,8 +101,8 @@ public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schem
return treeModel;
}

private void encodeNode(org.dmg.pmml.tree.Node parent, PredicateManager predicateManager, int index, Schema schema){
parent.setId(String.valueOf(index + 1));
private org.dmg.pmml.tree.Node encodeNode(Predicate predicate, PredicateManager predicateManager, int index, Schema schema){
String id = String.valueOf(index + 1);

Node node = this.nodes.get(index);

Expand Down Expand Up @@ -152,25 +151,28 @@ private void encodeNode(org.dmg.pmml.tree.Node parent, PredicateManager predicat
defaultLeft = node.default_left();
}

org.dmg.pmml.tree.Node leftChild = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(leftPredicate);

encodeNode(leftChild, predicateManager, node.cleft(), schema);

org.dmg.pmml.tree.Node rightChild = new org.dmg.pmml.tree.ComplexNode()
.setPredicate(rightPredicate);
org.dmg.pmml.tree.Node leftChild = encodeNode(leftPredicate, predicateManager, node.cleft(), schema);
org.dmg.pmml.tree.Node rightChild = encodeNode(rightPredicate, predicateManager, node.cright(), schema);

encodeNode(rightChild, predicateManager, node.cright(), schema);
org.dmg.pmml.tree.Node result = new BranchNode()
.setId(id)
.setScore(null) // XXX
.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId())
.setPredicate(predicate)
.addNodes(leftChild, rightChild);

parent.addNodes(leftChild, rightChild);

parent.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId());
return result;
} else

{
float value = node.leaf_value();

parent.setScore(ValueUtil.formatValue(value));
org.dmg.pmml.tree.Node result = new LeafNode()
.setId(id)
.setScore(value)
.setPredicate(predicate);

return result;
}
}

Expand Down

0 comments on commit 5e0cdea

Please sign in to comment.