Skip to content
This repository was archived by the owner on May 6, 2022. It is now read-only.

Feature: API for setting NLU confidence threshold #72

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public final class TensorflowNLU implements NLUService {
private TensorflowModel nluModel = null;
private TFNLUOutput outputParser = null;
private int maxTokens;
private float confidenceThreshold;
private String fallbackIntent;

private volatile boolean ready = false;

Expand All @@ -92,6 +94,8 @@ private TensorflowNLU(Builder builder) {
this.loadThread.start();
this.padTokenId = this.textEncoder.encodeSingle("[PAD]");
this.sepTokenId = this.textEncoder.encodeSingle("[SEP]");
this.confidenceThreshold = builder.confidenceThreshold;
this.fallbackIntent = builder.fallbackIntent;
}

private void initParsers(Map<String, String> parserClasses) {
Expand Down Expand Up @@ -218,6 +222,15 @@ private NLUResult tfClassify(String utterance, NLUContext nluContext) {
// interpret model outputs
Tuple<Metadata.Intent, Float> prediction = outputParser.getIntent(
this.nluModel.outputs(0));
float confidence = prediction.second();

if (confidence < this.confidenceThreshold) {
return new NLUResult.Builder(utterance)
.withIntent(this.fallbackIntent)
.withConfidence(confidence)
.build();
}

Metadata.Intent intent = prediction.first();
nluContext.traceDebug("Intent: %s", intent.getName());

Expand All @@ -230,7 +243,7 @@ private NLUResult tfClassify(String utterance, NLUContext nluContext) {

return new NLUResult.Builder(utterance)
.withIntent(intent.getName())
.withConfidence(prediction.second())
.withConfidence(confidence)
.withSlots(parsedSlots)
.build();
}
Expand Down Expand Up @@ -286,6 +299,8 @@ public static class Builder {
private TensorflowModel.Loader modelLoader;
private ThreadFactory threadFactory;
private TextEncoder textEncoder;
private float confidenceThreshold;
private String fallbackIntent;

/**
* Creates a new builder instance.
Expand Down Expand Up @@ -344,6 +359,24 @@ public Builder setTextEncoder(TextEncoder encoder) {
return this;
}

/**
* Sets a confidence threshold for classification, below which the
* specified fallback intent will be returned.
*
* @param confidence the lowest confidence value that will be accepted
* as a valid classification.
* @param fallback the name of the intent that will be returned if the
* model's confidence is below {@code confidence}.
* @return this
*/
public Builder setConfidenceThreshold(float confidence,
String fallback) {

this.confidenceThreshold = confidence;
this.fallbackIntent = fallback;
return this;
}

/**
* Sets a configuration value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void classify() throws Exception {

StringBuilder tooManyTokens = new StringBuilder();
for (int i = 0; i <= env.nlu.getMaxTokens(); i++) {
tooManyTokens.append("a ");
tooManyTokens.append("a ");
}
utterance = tooManyTokens.toString();
result = env.classify(utterance).get();
Expand All @@ -86,8 +86,9 @@ public void classify() throws Exception {
assertTrue(result.getSlots().isEmpty());

utterance = "this code is for test 1";
float conf = 0.75f;
float[] intentResult =
buildIntentResult(2, env.metadata.getIntents().length);
buildIntentResult(2, env.metadata.getIntents().length, conf);
float[] tagResult =
new float[utterance.split(" ").length
* env.metadata.getTags().length];
Expand All @@ -104,7 +105,7 @@ public void classify() throws Exception {

assertNull(result.getError());
assertEquals("describe_test", result.getIntent());
assertEquals(10.0, result.getConfidence());
assertEquals(conf, result.getConfidence());
for (String slotName : slots.keySet()) {
assertEquals(slots.get(slotName), result.getSlots().get(slotName));
}
Expand All @@ -117,9 +118,10 @@ public void classify() throws Exception {
// (which is incorrect, but we're just testing the slot extraction
// logic here)
utterance = "this bad code is for test 1";
intentResult = buildIntentResult(2, env.metadata.getIntents().length);
intentResult =
buildIntentResult(2, env.metadata.getIntents().length, conf);
tagResult = new float[utterance.split(" ").length
* env.metadata.getTags().length];
* env.metadata.getTags().length];
setTag(tagResult, env.metadata.getTags().length, 0, 1);
setTag(tagResult, env.metadata.getTags().length, 2, 1);
setTag(tagResult, env.metadata.getTags().length, 6, 3);
Expand All @@ -133,7 +135,7 @@ public void classify() throws Exception {

assertNull(result.getError());
assertEquals("describe_test", result.getIntent());
assertEquals(10.0, result.getConfidence());
assertEquals(conf, result.getConfidence());
for (String slotName : slots.keySet()) {
assertEquals(slots.get(slotName), result.getSlots().get(slotName));
}
Expand All @@ -142,9 +144,37 @@ public void classify() throws Exception {
assertTrue(result.getContext().isEmpty());
}

private float[] buildIntentResult(int index, int numIntents) {
@Test
public void testConfidenceThreshold() throws Exception {
TestEnv env = new TestEnv(testConfig());
env.nluBuilder.setConfidenceThreshold(0.5f, "fallback");

String utterance = "how far is it to the moon?";
float conf = 0.3f;
float[] intentResult =
buildIntentResult(1, env.metadata.getIntents().length, conf);

// include some tags in the result to make sure they're ignored
float[] tagResult =
new float[utterance.split(" ").length
* env.metadata.getTags().length];
setTag(tagResult, env.metadata.getTags().length, 0, 1);
setTag(tagResult, env.metadata.getTags().length, 1, 2);
env.testModel.setOutputs(intentResult, tagResult);
NLUResult result = env.classify(utterance).get();

assertNull(result.getError());
assertEquals("fallback", result.getIntent());
assertEquals(conf, result.getConfidence());
assertTrue(result.getSlots().isEmpty());
assertEquals(utterance, result.getUtterance());
assertTrue(result.getContext().isEmpty());
}

private float[] buildIntentResult(int index, int numIntents,
float confidence) {
float[] result = new float[numIntents];
result[index] = 10;
result[index] = confidence;
return result;
}

Expand Down