Skip to content

Commit

Permalink
Improved the numerical stability of summation
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 29, 2024
1 parent 49ce1fc commit 14b816c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
Expand Down Expand Up @@ -66,6 +67,21 @@ public ModelTranslator<?> newModelTranslator(Model model){
return modelTranslatorFactory.newModelTranslator(pmml, model);
}

static
public Number extractIntercept(Target target){
Number rescaleFactor = target.getRescaleFactor();
Number rescaleConstant = target.getRescaleConstant();

if(rescaleFactor.doubleValue() == 1d && rescaleConstant.doubleValue() != 0d){
// XXX
target.setRescaleConstant(null);

return rescaleConstant;
}

return null;
}

static
public void checkMiningSchema(Model model){
MiningSchema miningSchema = model.requireMiningSchema();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Target;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
Expand All @@ -49,6 +50,7 @@
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.ProbabilityAggregator;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueAggregator;
import org.jpmml.evaluator.ValueFactory;
Expand All @@ -59,14 +61,12 @@
import org.jpmml.translator.IdentifierUtil;
import org.jpmml.translator.JBinaryFileInitializer;
import org.jpmml.translator.JDirectInitializer;
import org.jpmml.translator.JVarBuilder;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.Modifiers;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.Scope;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.ValueFactoryRef;
import org.jpmml.translator.tree.NodeScoreDistributionManager;
import org.jpmml.translator.tree.NodeScoreManager;
Expand Down Expand Up @@ -295,6 +295,25 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra

JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), modelFuncInterface, methods);

switch(multipleModelMethod){
case SUM:
{
TargetField targetField = getTargetField();

Target target = targetField.getTarget();
if(target != null){
Number intercept = extractIntercept(target);

if(intercept != null){
aggregatorBuilder.update("add", intercept);
}
}
}
break;
default:
break;
}

JBlock block = context.block();

try {
Expand Down Expand Up @@ -361,10 +380,7 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra
throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
}

JVarBuilder resultBuilder = new ValueBuilder(context)
.declare(context.getValueType(), "result", valueInit);

context._return(resultBuilder.getVariable());
context._return(valueInit);
}

private void translateProbabilityAggregatorSegmentation(Segmentation segmentation, TranslationContext context){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public JMethod translateRegressor(TranslationContext context){

Target target = Iterables.getOnlyElement(targets);

Number intercept = extractIntercept(target);
if(intercept == null){
intercept = 0;
}

ModelTranslator<?> modelTranslator = new TreeModelTranslator(pmml, treeModel);

Node root = treeModel.getNode();
Expand All @@ -173,10 +178,10 @@ public JMethod translateRegressor(TranslationContext context){

switch(mathContext){
case FLOAT:
resultVar = context.declare(float.class, "result", JExpr.lit(0f));
resultVar = context.declare(float.class, "result", JExpr.lit(intercept.floatValue()));
break;
case DOUBLE:
resultVar = context.declare(double.class, "result", JExpr.lit(0d));
resultVar = context.declare(double.class, "result", JExpr.lit(intercept.doubleValue()));
break;
default:
throw new UnsupportedAttributeException(miningModel, mathContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ public void evaluateSelectFirstAudit() throws Exception {

@Test
public void evaluateXGBoostAudit() throws Exception {
evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 48));
evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4));
}

@Test
public void evaluateXGBoostAuditNA() throws Exception {
evaluate(XGBOOST, AUDIT_NA, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 48));
evaluate(XGBOOST, AUDIT_NA, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4));
}

@Test
Expand All @@ -95,7 +95,7 @@ public void evaluateRandomForestSentiment() throws Exception {

@Test
public void evaluateXGBoostSentiment() throws Exception {
evaluate(XGBOOST, SENTIMENT, excludeFields(SENTIMENT_PROBABILITY_FALSE), new FloatEquivalence(24));
evaluate(XGBOOST, SENTIMENT, excludeFields(SENTIMENT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4));
}

@Test
Expand Down Expand Up @@ -125,6 +125,6 @@ public void evaluateRandomForestIris() throws Exception {

@Test
public void evaluateXGBoostIris() throws Exception {
evaluate(XGBOOST, IRIS, new FloatEquivalence(10));
evaluate(XGBOOST, IRIS, new FloatEquivalence(8 + 2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ public void evaluateVotingEnsembleAuto() throws Exception {

@Test
public void evaluateXGBoostAuto() throws Exception {
evaluate(XGBOOST, AUTO, new FloatEquivalence(8));
evaluate(XGBOOST, AUTO, new FloatEquivalence(2));
}

@Test
public void evaluateXGBoostAutoNA() throws Exception {
evaluate(XGBOOST, AUTO_NA, new FloatEquivalence(8 + 4));
evaluate(XGBOOST, AUTO_NA, new FloatEquivalence(2));
}
}

0 comments on commit 14b816c

Please sign in to comment.