Skip to content

Commit

Permalink
Imported Java code from the JPMML-SkLearn project
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Aug 31, 2022
1 parent 6575f56 commit 8d07a33
Showing 1 changed file with 102 additions and 32 deletions.
134 changes: 102 additions & 32 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -39,16 +40,26 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.Value;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
Expand Down Expand Up @@ -349,28 +360,18 @@ public Schema encodeSchema(String targetName, List<String> targetCategories, Fea
}

public Schema toXGBoostSchema(boolean numeric, Schema schema){
GBTree gbtree = this.gbtree;

Function<Feature, Feature> function = new Function<Feature, Feature>(){
FeatureTransformer function = new FeatureTransformer(){

private List<? extends Feature> features = schema.getFeatures();


@Override
public Feature apply(Feature feature){
int splitType = getSplitType(feature);

switch(splitType){
case Node.SPLIT_NUMERICAL:
return applyNumerical(feature);
case Node.SPLIT_CATEGORICAL:
return applyCategorical(feature);
default:
throw new IllegalArgumentException();
}
public int getSplitIndex(Feature feature){
return this.features.indexOf(feature);
}

private Feature applyNumerical(Feature feature){
@Override
public Feature transformNumerical(Feature feature){

if(feature instanceof BinaryFeature){
BinaryFeature binaryFeature = (BinaryFeature)feature;
Expand Down Expand Up @@ -409,7 +410,8 @@ private Feature applyNumerical(Feature feature){
}
}

private Feature applyCategorical(Feature feature){
@Override
public Feature transformCategorical(Feature feature){

if(feature instanceof CategoricalFeature){
CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
Expand All @@ -421,30 +423,54 @@ private Feature applyCategorical(Feature feature){
throw new IllegalArgumentException();
}
}
};

private int getSplitType(Feature feature){
int splitIndex = this.features.indexOf(feature);
if(splitIndex < 0){
throw new IllegalArgumentException();
}
return schema.toTransformedSchema(function);
}

public Schema toValueFilteredSchema(Number missing, Schema schema){
FeatureTransformer function = new FeatureTransformer(){

private List<? extends Feature> features = schema.getFeatures();

return getSplitType(splitIndex);

@Override
public int getSplitIndex(Feature feature){
return this.features.indexOf(feature);
}

private int getSplitType(int splitIndex){
Set<Integer> splitTypes = gbtree.getSplitType(splitIndex);
@Override
public Feature transformNumerical(Feature feature){
ContinuousFeature continuousFeature = feature.toContinuousFeature();

if(splitTypes.size() == 0){
return Node.SPLIT_NUMERICAL;
} else
Field<?> field = continuousFeature.getField();

if(splitTypes.size() == 1){
return Iterables.getOnlyElement(splitTypes);
} else
if(field instanceof DataField){
DataField dataField = (DataField)field;

{
throw new IllegalArgumentException();
PMMLUtil.addValues(dataField, Value.Property.MISSING, Collections.singletonList(missing));

return continuousFeature;
}

PMMLEncoder encoder = continuousFeature.getEncoder();

Expression expression = PMMLUtil.createApply(PMMLFunctions.IF,
PMMLUtil.createApply(PMMLFunctions.AND,
PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, continuousFeature.ref()),
PMMLUtil.createApply(PMMLFunctions.NOTEQUAL, continuousFeature.ref(), PMMLUtil.createConstant(missing))
),
continuousFeature.ref()
);

DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create("filter", continuousFeature, missing), OpType.CONTINUOUS, continuousFeature.getDataType(), expression);

return new ContinuousFeature(encoder, derivedField);
}

@Override
public Feature transformCategorical(Feature feature){
throw new IllegalArgumentException();
}
};

Expand Down Expand Up @@ -576,4 +602,48 @@ private <DIS extends InputStream & DataInput> boolean consumeHeader(DIS is, Stri

return equals;
}

/**
 * Base class for feature transformations that dispatch on the split type
 * that the enclosing learner's booster reports for a feature.
 *
 * Subclasses resolve a feature to its split index and provide the
 * numerical and categorical transformations; {@link #apply(Feature)}
 * chooses between them.
 */
abstract
private class FeatureTransformer implements Function<Feature, Feature> {

	/**
	 * Resolves the split index of a feature.
	 *
	 * @param feature The feature to look up.
	 * @return The split index. A negative value signals that the feature
	 * could not be resolved, and makes {@link #apply(Feature)} fail.
	 */
	abstract
	public int getSplitIndex(Feature feature);

	/**
	 * Transforms a feature whose split type is {@code Node.SPLIT_NUMERICAL}.
	 *
	 * @param feature The feature to transform.
	 * @return The transformed feature.
	 */
	abstract
	public Feature transformNumerical(Feature feature);

	/**
	 * Transforms a feature whose split type is {@code Node.SPLIT_CATEGORICAL}.
	 *
	 * @param feature The feature to transform.
	 * @return The transformed feature.
	 */
	abstract
	public Feature transformCategorical(Feature feature);

	/**
	 * Transforms a feature according to its split type.
	 *
	 * @param feature The feature to transform.
	 * @return The result of {@link #transformNumerical(Feature)} or
	 * {@link #transformCategorical(Feature)}.
	 * @throws IllegalArgumentException If the split index is negative,
	 * or if the split type is ambiguous or not supported.
	 */
	@Override
	public Feature apply(Feature feature){
		int splitIndex = getSplitIndex(feature);

		// Fail fast on an unresolved feature, rather than passing a negative index on to the booster,
		// where it could be silently treated as a feature without any splits
		if(splitIndex < 0){
			throw new IllegalArgumentException("Feature " + feature + " does not have a valid split index (" + splitIndex + ")");
		}

		int splitType = getSplitType(splitIndex);
		switch(splitType){
			case Node.SPLIT_NUMERICAL:
				return transformNumerical(feature);
			case Node.SPLIT_CATEGORICAL:
				return transformCategorical(feature);
			default:
				throw new IllegalArgumentException("Split type " + splitType + " at split index " + splitIndex + " is not supported");
		}
	}

	/**
	 * Determines the single split type associated with a split index.
	 *
	 * @param splitIndex The non-negative split index.
	 * @return {@code Node.SPLIT_NUMERICAL} if no split type is reported
	 * for the index, otherwise the only reported split type.
	 * @throws IllegalArgumentException If more than one split type is reported.
	 */
	private int getSplitType(int splitIndex){
		Set<Integer> splitTypes = Learner.this.gbtree.getSplitType(splitIndex);

		// No reported split type falls back to the numerical split type
		if(splitTypes.isEmpty()){
			return Node.SPLIT_NUMERICAL;
		} else

		if(splitTypes.size() == 1){
			return Iterables.getOnlyElement(splitTypes);
		} else

		{
			throw new IllegalArgumentException("Split index " + splitIndex + " has multiple split types " + splitTypes);
		}
	}
}
}

0 comments on commit 8d07a33

Please sign in to comment.