Skip to content

Commit

Permalink
Extracted XGBoostTestBatch abstract base class
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jun 28, 2020
1 parent 650f702 commit 4f38978
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
26 changes: 3 additions & 23 deletions src/test/java/org/jpmml/xgboost/XGBoostTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
Expand All @@ -31,7 +30,6 @@
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.ArchiveBatch;
import org.jpmml.evaluator.testing.IntegrationTest;
import org.jpmml.evaluator.testing.IntegrationTestBatch;

public class XGBoostTest extends IntegrationTest {

Expand All @@ -41,10 +39,10 @@ public XGBoostTest(Equivalence<Object> equivalence){

@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate, equivalence){
ArchiveBatch result = new XGBoostTestBatch(name, dataset, predicate, equivalence){

@Override
public IntegrationTest getIntegrationTest(){
public XGBoostTest getIntegrationTest(){
return XGBoostTest.this;
}

Expand All @@ -64,14 +62,7 @@ public PMML getPMML() throws Exception {
featureMap = XGBoostUtil.loadFeatureMap(is);
}

Integer ntreeLimit = null;
if(dataset.length > 1){
ntreeLimit = new Integer(dataset[1]);
}

Map<String, Object> options = new LinkedHashMap<>();
options.put(HasXGBoostOptions.OPTION_COMPACT, ntreeLimit != null);
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, ntreeLimit);
Map<String, ?> options = getOptions();

PMML pmml = learner.encodePMML(null, null, featureMap, options);

Expand All @@ -91,17 +82,6 @@ public List<Map<FieldName, String>> getInput() throws IOException {
public List<Map<FieldName, String>> getOutput() throws IOException {
return loadRecords("/csv/" + (getName() + getDataset()) + ".csv");
}

private String[] parseDataset(){
String dataset = getDataset();

int index = dataset.indexOf('@');
if(index > -1){
return new String[]{dataset.substring(0, index), dataset.substring(index + 1)};
}

return new String[]{dataset};
}
};

return result;
Expand Down
65 changes: 65 additions & 0 deletions src/test/java/org/jpmml/xgboost/XGBoostTestBatch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2020 Villu Ruusmann
*
* This file is part of JPMML-XGBoost
*
* JPMML-XGBoost is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-XGBoost is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-XGBoost. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.xgboost;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Predicate;

import com.google.common.base.Equivalence;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.IntegrationTestBatch;

abstract
public class XGBoostTestBatch extends IntegrationTestBatch {

public XGBoostTestBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
super(name, dataset, predicate, equivalence);
}

@Override
abstract
public XGBoostTest getIntegrationTest();

public Map<String, Object> getOptions(){
String[] dataset = parseDataset();

Integer ntreeLimit = null;
if(dataset.length > 1){
ntreeLimit = new Integer(dataset[1]);
}

Map<String, Object> options = new LinkedHashMap<>();
options.put(HasXGBoostOptions.OPTION_COMPACT, ntreeLimit != null);
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, ntreeLimit);

return options;
}

protected String[] parseDataset(){
String dataset = getDataset();

int index = dataset.indexOf('@');
if(index > -1){
return new String[]{dataset.substring(0, index), dataset.substring(index + 1)};
}

return new String[]{dataset};
}
}

0 comments on commit 4f38978

Please sign in to comment.