Skip to content

Commit

Permalink
Improved the handling of Not-a-Number (NaN) values. Fixes #28
Browse files (browse the repository at this point in the history)
See commit b791c22
  • Loading branch information
vruusmann committed Aug 31, 2022
1 parent 8d07a33 commit 57192fb
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 46 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,12 +88,6 @@ public class Main {
)
private String jsonPath = "$";

@Parameter (
names = {"--missing-value"},
description = "String representation of feature value(s) that should be regarded as missing"
)
private String missingValue = null;

@Parameter (
names = {"--target-name"},
description = "Target name. Defaults to \"_target\""
Expand All @@ -106,6 +100,12 @@ public class Main {
)
private List<String> targetCategories = null;

// Command-line option "--X-" + HasXGBoostOptions.OPTION_MISSING: the numeric value that
// should be regarded as missing. Initialized to Float.NaN when the option is not given;
// run() forwards it to the converter in the options map under HasXGBoostOptions.OPTION_MISSING.
@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_MISSING},
description = "Missing value. Defaults to Not-a-Number (NaN) value"
)
private Float missing = Float.NaN;

@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_COMPACT},
description = "Transform XGBoost-style trees to PMML-style trees",
Expand All @@ -127,12 +127,6 @@ public class Main {
)
private boolean prune = true;

@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_NAN_AS_MISSING},
description = "Treat Not-a-Number (NaN) values as missing values"
)
private boolean nanAsMissing = true;

@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_NTREE_LIMIT},
description = "Limit the number of trees. Defaults to all trees"
Expand Down Expand Up @@ -217,17 +211,13 @@ private void run() throws Exception {
logger.info("Parsing embedded feature map");

featureMap = learner.encodeFeatureMap();
} // End if

if(this.missingValue != null){
featureMap.addMissingValue(this.missingValue);
}

Map<String, Object> options = new LinkedHashMap<>();
options.put(HasXGBoostOptions.OPTION_MISSING, this.missing);
options.put(HasXGBoostOptions.OPTION_COMPACT, this.compact);
options.put(HasXGBoostOptions.OPTION_NUMERIC, this.numeric);
options.put(HasXGBoostOptions.OPTION_PRUNE, this.prune);
options.put(HasXGBoostOptions.OPTION_NAN_AS_MISSING, this.nanAsMissing);
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, this.ntreeLimit);

PMML pmml;
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface HasXGBoostOptions extends HasOptions, HasNativeConfiguration {

String OPTION_COMPACT = "compact";

String OPTION_NAN_AS_MISSING = "nan_as_missing";
String OPTION_MISSING = "missing";

String OPTION_NTREE_LIMIT = "ntree_limit";

Expand Down
80 changes: 53 additions & 27 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java
Original file line number | Diff line number | Diff line change
Expand Up @@ -62,7 +62,7 @@
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.TreeModelPruner;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

Expand Down Expand Up @@ -441,36 +441,65 @@ public int getSplitIndex(Feature feature){

@Override
public Feature transformNumerical(Feature feature){
ContinuousFeature continuousFeature = feature.toContinuousFeature();

Field<?> field = continuousFeature.getField();
if(feature instanceof BinaryFeature){
BinaryFeature binaryFeature = (BinaryFeature)feature;

if(field instanceof DataField){
DataField dataField = (DataField)field;
return binaryFeature;
} else

PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing));
if(feature instanceof MissingValueFeature){
MissingValueFeature missingValueFeature = (MissingValueFeature)feature;

return continuousFeature;
}
return missingValueFeature;
} else

{
ContinuousFeature continuousFeature = feature.toContinuousFeature();

Field<?> field = continuousFeature.getField();

if(field instanceof DataField){
DataField dataField = (DataField)field;

PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing));

return continuousFeature;
} // End if

// XXX
if(ValueUtil.isNaN(missing)){
return continuousFeature;
}

PMMLEncoder encoder = continuousFeature.getEncoder();
PMMLEncoder encoder = continuousFeature.getEncoder();

Expression expression = PMMLUtil.createApply(PMMLFunctions.IF,
PMMLUtil.createApply(PMMLFunctions.AND,
PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, continuousFeature.ref()),
PMMLUtil.createApply(PMMLFunctions.NOTEQUAL, continuousFeature.ref(), PMMLUtil.createConstant(missing))
),
continuousFeature.ref()
);
Expression expression = PMMLUtil.createApply(PMMLFunctions.IF,
PMMLUtil.createApply(PMMLFunctions.AND,
PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, continuousFeature.ref()),
PMMLUtil.createApply(PMMLFunctions.NOTEQUAL, continuousFeature.ref(), PMMLUtil.createConstant(missing))
),
continuousFeature.ref()
);

DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression);
DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression);

return new ContinuousFeature(encoder, derivedField);
return new ContinuousFeature(encoder, derivedField);
}
}

@Override
public Feature transformCategorical(Feature feature){
throw new IllegalArgumentException();

if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

return categoricalFeature;
} else

{
throw new IllegalArgumentException();
}
}
};

Expand All @@ -480,31 +509,28 @@ public Feature transformCategorical(Feature feature){
public PMML encodePMML(Map<String, ?> options, String targetName, List<String> targetCategories, FeatureMap featureMap){
XGBoostEncoder encoder = new XGBoostEncoder();

Boolean nanAsMissing = (Boolean)options.get(HasXGBoostOptions.OPTION_NAN_AS_MISSING);

Schema schema = encodeSchema(targetName, targetCategories, featureMap, encoder);

MiningModel miningModel = encodeMiningModel(options, schema);

PMML pmml = encoder.encodePMML(miningModel);

if((Boolean.TRUE).equals(nanAsMissing)){
Visitor visitor = new NaNAsMissingDecorator();

visitor.applyTo(pmml);
}

return pmml;
}

public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema){
Number missing = (Number)options.get(HasXGBoostOptions.OPTION_MISSING);
Boolean compact = (Boolean)options.get(HasXGBoostOptions.OPTION_COMPACT);
Boolean numeric = (Boolean)options.get(HasXGBoostOptions.OPTION_NUMERIC);
Boolean prune = (Boolean)options.get(HasXGBoostOptions.OPTION_PRUNE);
Integer ntreeLimit = (Integer)options.get(HasXGBoostOptions.OPTION_NTREE_LIMIT);

if(numeric == null){
numeric = Boolean.TRUE;
} // End if

if(missing != null){
schema = toValueFilteredSchema(missing, schema);
}

MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, numeric, schema)
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,9 +63,9 @@ public List<Map<String, Object>> getOptionsMatrix(){
}

Map<String, Object> options = new LinkedHashMap<>();
options.put(HasXGBoostOptions.OPTION_MISSING, Float.NaN);
options.put(HasXGBoostOptions.OPTION_COMPACT, new Boolean[]{false, true});
options.put(HasXGBoostOptions.OPTION_PRUNE, true);
options.put(HasXGBoostOptions.OPTION_NAN_AS_MISSING, true);
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, ntreeLimit);

return OptionsUtil.generateOptionsMatrix(options);
Expand Down

0 comments on commit 57192fb

Please sign in to comment.