diff --git a/pmml-xgboost-example/src/main/java/org/jpmml/xgboost/example/Main.java b/pmml-xgboost-example/src/main/java/org/jpmml/xgboost/example/Main.java index b965eb4..9bd03ee 100644 --- a/pmml-xgboost-example/src/main/java/org/jpmml/xgboost/example/Main.java +++ b/pmml-xgboost-example/src/main/java/org/jpmml/xgboost/example/Main.java @@ -88,12 +88,6 @@ public class Main { ) private String jsonPath = "$"; - @Parameter ( - names = {"--missing-value"}, - description = "String representation of feature value(s) that should be regarded as missing" - ) - private String missingValue = null; - @Parameter ( names = {"--target-name"}, description = "Target name. Defaults to \"_target\"" @@ -106,6 +100,12 @@ public class Main { ) private List targetCategories = null; + @Parameter ( + names = {"--X-" + HasXGBoostOptions.OPTION_MISSING}, + description = "Missing value. Defaults to Not-a-Number (NaN) value" + ) + private Float missing = Float.NaN; + @Parameter ( names = {"--X-" + HasXGBoostOptions.OPTION_COMPACT}, description = "Transform XGBoost-style trees to PMML-style trees", @@ -127,12 +127,6 @@ public class Main { ) private boolean prune = true; - @Parameter ( - names = {"--X-" + HasXGBoostOptions.OPTION_NAN_AS_MISSING}, - description = "Treat Not-a-Number (NaN) values as missing values" - ) - private boolean nanAsMissing = true; - @Parameter ( names = {"--X-" + HasXGBoostOptions.OPTION_NTREE_LIMIT}, description = "Limit the number of trees. Defaults to all trees" @@ -217,17 +211,13 @@ private void run() throws Exception { logger.info("Parsing embedded feature map"); featureMap = learner.encodeFeatureMap(); - } // End if - - if(this.missingValue != null){ - featureMap.addMissingValue(this.missingValue); } Map options = new LinkedHashMap<>(); + options.put(HasXGBoostOptions.OPTION_MISSING, this.missing); options.put(HasXGBoostOptions.OPTION_COMPACT, this.compact); options.put(HasXGBoostOptions.OPTION_NUMERIC, this.numeric); options.put(HasXGBoostOptions.OPTION_PRUNE, this.prune); - options.put(HasXGBoostOptions.OPTION_NAN_AS_MISSING, this.nanAsMissing); options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, this.ntreeLimit); PMML pmml; diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/HasXGBoostOptions.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/HasXGBoostOptions.java index a973fa2..d033046 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/HasXGBoostOptions.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/HasXGBoostOptions.java @@ -32,7 +32,7 @@ public interface HasXGBoostOptions extends HasOptions, HasNativeConfiguration { String OPTION_COMPACT = "compact"; - String OPTION_NAN_AS_MISSING = "nan_as_missing"; + String OPTION_MISSING = "missing"; String OPTION_NTREE_LIMIT = "ntree_limit"; diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java index ac63fd2..870fa27 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java @@ -62,7 +62,7 @@ import org.jpmml.converter.PMMLUtil; import org.jpmml.converter.Schema; import org.jpmml.converter.ThresholdFeature; -import org.jpmml.converter.visitors.NaNAsMissingDecorator; +import org.jpmml.converter.ValueUtil; import org.jpmml.converter.visitors.TreeModelPruner; import org.jpmml.xgboost.visitors.TreeModelCompactor; @@ -441,36 +441,65 @@ public int getSplitIndex(Feature feature){ @Override public Feature transformNumerical(Feature feature){ - ContinuousFeature continuousFeature = feature.toContinuousFeature(); - Field field = continuousFeature.getField(); + if(feature instanceof BinaryFeature){ + BinaryFeature binaryFeature = (BinaryFeature)feature; - if(field instanceof DataField){ - DataField dataField = (DataField)field; + return binaryFeature; + } else - PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing)); + if(feature instanceof MissingValueFeature){ + MissingValueFeature missingValueFeature = (MissingValueFeature)feature; - return continuousFeature; - } + return missingValueFeature; + } else + + { + ContinuousFeature continuousFeature = feature.toContinuousFeature(); + + Field field = continuousFeature.getField(); + + if(field instanceof DataField){ + DataField dataField = (DataField)field; + + PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing)); + + return continuousFeature; + } // End if + + // XXX + if(ValueUtil.isNaN(missing)){ + return continuousFeature; + } - PMMLEncoder encoder = continuousFeature.getEncoder(); + PMMLEncoder encoder = continuousFeature.getEncoder(); - Expression expression = PMMLUtil.createApply(PMMLFunctions.IF, - PMMLUtil.createApply(PMMLFunctions.AND, - PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, continuousFeature.ref()), - PMMLUtil.createApply(PMMLFunctions.NOTEQUAL, continuousFeature.ref(), PMMLUtil.createConstant(missing)) - ), - continuousFeature.ref() - ); + Expression expression = PMMLUtil.createApply(PMMLFunctions.IF, + PMMLUtil.createApply(PMMLFunctions.AND, + PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, continuousFeature.ref()), + PMMLUtil.createApply(PMMLFunctions.NOTEQUAL, continuousFeature.ref(), PMMLUtil.createConstant(missing)) + ), + continuousFeature.ref() + ); - DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression); + DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression); - return new ContinuousFeature(encoder, derivedField); + return new ContinuousFeature(encoder, derivedField); + } } @Override public Feature transformCategorical(Feature feature){ - throw new IllegalArgumentException(); + + if(feature instanceof CategoricalFeature){ + CategoricalFeature categoricalFeature = (CategoricalFeature)feature; + + return categoricalFeature; + } else + + { + throw new IllegalArgumentException(); + } } }; @@ -480,24 +509,17 @@ public Feature transformCategorical(Feature feature){ public PMML encodePMML(Map options, String targetName, List targetCategories, FeatureMap featureMap){ XGBoostEncoder encoder = new XGBoostEncoder(); - Boolean nanAsMissing = (Boolean)options.get(HasXGBoostOptions.OPTION_NAN_AS_MISSING); - Schema schema = encodeSchema(targetName, targetCategories, featureMap, encoder); MiningModel miningModel = encodeMiningModel(options, schema); PMML pmml = encoder.encodePMML(miningModel); - if((Boolean.TRUE).equals(nanAsMissing)){ - Visitor visitor = new NaNAsMissingDecorator(); - - visitor.applyTo(pmml); - } - return pmml; } public MiningModel encodeMiningModel(Map options, Schema schema){ + Number missing = (Number)options.get(HasXGBoostOptions.OPTION_MISSING); Boolean compact = (Boolean)options.get(HasXGBoostOptions.OPTION_COMPACT); Boolean numeric = (Boolean)options.get(HasXGBoostOptions.OPTION_NUMERIC); Boolean prune = (Boolean)options.get(HasXGBoostOptions.OPTION_PRUNE); @@ -505,6 +527,10 @@ public MiningModel encodeMiningModel(Map options, Schema schema){ if(numeric == null){ numeric = Boolean.TRUE; + } // End if + + if(missing != null){ + schema = toValueFilteredSchema(missing, schema); } MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, numeric, schema) diff --git a/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java index 0467cdc..3b9d40d 100644 --- a/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java +++ b/pmml-xgboost/src/main/java/org/jpmml/xgboost/testing/XGBoostEncoderBatch.java @@ -63,9 +63,9 @@ public List> getOptionsMatrix(){ } Map options = new LinkedHashMap<>(); + options.put(HasXGBoostOptions.OPTION_MISSING, Float.NaN); options.put(HasXGBoostOptions.OPTION_COMPACT, new Boolean[]{false, true}); options.put(HasXGBoostOptions.OPTION_PRUNE, true); - options.put(HasXGBoostOptions.OPTION_NAN_AS_MISSING, true); options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, ntreeLimit); return OptionsUtil.generateOptionsMatrix(options);