Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump catboost-prediction version to 1.1 #12

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions apply_model/java/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Created by https://www.toptal.com/developers/gitignore/api/java,intellij+all,maven
# Edit at https://www.toptal.com/developers/gitignore?templates=java,intellij+all,maven

### Intellij+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### Intellij+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.

.idea/*

!.idea/codeStyles
!.idea/runConfigurations

### Java ###
# Compiled class file
*.class

# Log file
*.log

# BlueJ files
*.ctxt

# Mobile Tools for Java (J2ME)
.mtj.tmp/

# Package Files #
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar

# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*

### Maven ###
target/
pom.xml.tag
pom.xml.releaseBackup
pom.xml.versionsBackup
pom.xml.next
release.properties
dependency-reduced-pom.xml
buildNumber.properties
.mvn/timing.properties
# https://github.com/takari/maven-wrapper#usage-without-binary-jar
.mvn/wrapper/maven-wrapper.jar

# Eclipse m2e generated files
# Eclipse Core
.project
# JDT-specific (Eclipse Java Development Tools)
.classpath

# End of https://www.toptal.com/developers/gitignore/api/java,intellij+all,maven
11 changes: 10 additions & 1 deletion apply_model/java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
<dependency>
<groupId>ai.catboost</groupId>
<artifactId>catboost-prediction</artifactId>
<version>0.2.4</version>
<version>1.1</version>
</dependency>
<dependency>
<groupId>ai.catboost</groupId>
<artifactId>catboost-common</artifactId>
<version>1.1</version>
</dependency>
</dependencies>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
</project>
28 changes: 15 additions & 13 deletions apply_model/java/src/main/java/CatBoost4jPredictionTutorial.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import java.io.IOException;

import static java.util.Objects.requireNonNull;

public class CatBoost4jPredictionTutorial {
private CatBoostModel adultModel = null;

public CatBoost4jPredictionTutorial() throws CatBoostError, IOException {
// Load the "adult.cbm" model that we trained within the Jupyter Notebook
adultModel = CatBoostModel.loadModel(ClassLoader.getSystemResourceAsStream("models/adult.cbm"));
adultModel = CatBoostModel.loadModel(requireNonNull(ClassLoader.getSystemResourceAsStream("models/adult.cbm")));

// You can also try to load your own model: just comment out the line above and uncomment the two lines below,
// replacing "foo/bar" with the path to your model that classifies data from the UCI Adult Dataset.
Expand All @@ -35,7 +37,7 @@ public void playWithModelForAdultDataset() throws CatBoostError {

System.out.print("Adult dataset model metainformation\n");

System.out.print(String.format("tree count: %d\n", adultModel.getTreeCount()));
System.out.printf("tree count: %d\n", adultModel.getTreeCount());

// In our case we were solving a binary classification problem (whether a person makes over 50K a year), so the
// dimension of the prediction will be 1, it will return probability of the object to belong to the positive
Expand All @@ -46,16 +48,16 @@ public void playWithModelForAdultDataset() throws CatBoostError {
//
// In most cases the prediction dimension will be 1 (for regression and for ranking); it can be N for
// multiclassification, where N is the number of classes.
System.out.print(String.format("prediction dimension: %d\n",adultModel.getPredictionDimension()));
System.out.printf("prediction dimension: %d\n",adultModel.getPredictionDimension());

// Take note: the number of numeric features used by the model may be less than the number of numeric features
// that were present in the training dataset. This may happen, for example, when the training dataset contained
// constant features: they do not carry any information for the classifier, so the training process ignores them.
System.out.print(String.format("used numeric feature count: %d\n", adultModel.getUsedNumericFeatureCount()));
System.out.printf("used numeric feature count: %d\n", adultModel.getUsedNumericFeatureCount());

// The number of categoric features used by the classifier may also be less than the number of categoric features
// present in the training dataset, for the same reasons as for numeric features.
System.out.print(String.format("used categoric feature count: %d\n", adultModel.getUsedCategoricFeatureCount()));
System.out.printf("used categoric feature count: %d\n", adultModel.getUsedCategoricFeatureCount());

// Ok, now let's try to use our model for prediction. We'll look at the test part of the Adult dataset. You will need
// to download it [1] from UCI repository. Look for "adult.test", "adult.name" will also be useful because it
Expand Down Expand Up @@ -91,15 +93,15 @@ public void playWithModelForAdultDataset() throws CatBoostError {
// CatBoost doesn't compute a "probability" directly; to turn a CatBoost prediction into a probability we need to
// apply the sigmoid function.
final double personAMakesOver50KProbability = sigmoid(personAPrediction.get(0, 0));
System.out.print(String.format("Person A make over 50K a year with probability %f\n", personAMakesOver50KProbability));
System.out.printf("Person A make over 50K a year with probability %f\n", personAMakesOver50KProbability);

// When we were training CatBoost we used a default classification threshold for AUC, which is equal to 0.5.
// This means that our formula is optimized for this threshold; we could change the threshold to optimize some
// other metric on a different dataset, but we won't do it in this tutorial.
final double classificationThreshold = 0.5;

final boolean personAMakesOver50K = personAMakesOver50KProbability > classificationThreshold;
System.out.print(String.format("Person A %s\n", answer(personAMakesOver50K)));
System.out.printf("Person A %s\n", answer(personAMakesOver50K));

// Now let's find an example with missing features and income greater than 50K a year. At line 40 of "adult.test"
// we can find following line:
Expand All @@ -124,8 +126,8 @@ public void playWithModelForAdultDataset() throws CatBoostError {
final CatBoostPredictions personBPrediction = adultModel.predict(personBNumericFeatures, personBCategoricFeatures);
final double personBMakeOver50KProbability = sigmoid(personBPrediction.get(0, 0));
final boolean personBMakesOver50K = personBMakeOver50KProbability > classificationThreshold;
System.out.print(String.format("Person B make over 50K a year with probability %f\n", personBMakeOver50KProbability));
System.out.print(String.format("Person B %s\n", answer(personBMakesOver50K)));
System.out.printf("Person B make over 50K a year with probability %f\n", personBMakeOver50KProbability);
System.out.printf("Person B %s\n", answer(personBMakesOver50K));

// There is also a batch interface for model application, e.g. you can apply model to multiple objects at once.
//
Expand All @@ -151,10 +153,10 @@ public void playWithModelForAdultDataset() throws CatBoostError {
System.out.print("Using batch interface\n");

// Predictions should be the same as above
System.out.print(String.format("Person A make over 50K a year with probability %f\n", personsABMakeOver50KProbabilities[0]));
System.out.print(String.format("Person A %s\n", answer(personsABMakeOver50K[0])));
System.out.print(String.format("Person B make over 50K a year with probability %f\n", personsABMakeOver50KProbabilities[1]));
System.out.print(String.format("Person B %s\n", answer(personsABMakeOver50K[1])));
System.out.printf("Person A make over 50K a year with probability %f\n", personsABMakeOver50KProbabilities[0]);
System.out.printf("Person A %s\n", answer(personsABMakeOver50K[0]));
System.out.printf("Person B make over 50K a year with probability %f\n", personsABMakeOver50KProbabilities[1]);
System.out.printf("Person B %s\n", answer(personsABMakeOver50K[1]));

// TODO(yazevnul): add examples with hashed categorical features when `CatBoostModel.hashCategoricalFeatures`
// becomes public.
Expand Down