Skip to content

Commit

Permalink
Merge pull request #339 from RasaHQ/7078/intent_from_fallback_helper
Browse files Browse the repository at this point in the history
7078/intent_from_fallback_helper
  • Loading branch information
joejuzl authored Nov 16, 2020
2 parents c655e0a + b1092ac commit a4eac21
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
2 changes: 2 additions & 0 deletions changelog/7078.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Adds the method `get_intent_of_latest_message` to the `Tracker` allowing easier
access to the user's latest intent in case of an `nlu_fallback`.
18 changes: 18 additions & 0 deletions docs/docs/sdk-tracker.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,21 @@ Return a list of events after the most recent restart.
* **Return type**

`Optional[Any]`

### Tracker.get_intent_of_latest_message

Retrieves the user's latest intent.

* **Parameters**

* `skip_fallback_intent` (default: `True`) – Optionally skip the `nlu_fallback` intent and return the next highest ranked.


* **Returns**

The intent of the latest message if available.


* **Return type**

`Optional[Text]`
33 changes: 33 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
logger = logging.getLogger(__name__)

ACTION_LISTEN_NAME = "action_listen"
NLU_FALLBACK_INTENT_NAME = "nlu_fallback"


class Tracker:
Expand Down Expand Up @@ -271,6 +272,38 @@ def add_slots(self, slots: List[EventType]) -> None:
self.slots[event["name"]] = event["value"]
self.events.append(event)

def get_intent_of_latest_message(
self, skip_fallback_intent: bool = True
) -> Optional[Text]:
"""Retrieves the intent the last user message.
Args:
skip_fallback_intent: Optionally skip the nlu_fallback intent
and return the next.
Returns:
Intent of latest message if available.
"""
latest_message = self.latest_message
if not latest_message:
return None

intent_ranking = latest_message.get("intent_ranking")
if not intent_ranking:
return None

highest_ranking_intent = intent_ranking[0]
if (
highest_ranking_intent["name"] == NLU_FALLBACK_INTENT_NAME
and skip_fallback_intent
):
if len(intent_ranking) >= 2:
return intent_ranking[1]["name"]
else:
return None
else:
return highest_ranking_intent["name"]


class Action:
"""Next action to be taken in response to a dialogue state."""
Expand Down
45 changes: 43 additions & 2 deletions tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from rasa_sdk.events import (
ActionExecuted,
UserUttered,
ActionReverted,
SlotSet,
Restarted,
)
from rasa_sdk.interfaces import ACTION_LISTEN_NAME
from rasa_sdk.interfaces import ACTION_LISTEN_NAME, NLU_FALLBACK_INTENT_NAME
from typing import List, Dict, Text, Any


Expand Down Expand Up @@ -133,3 +132,45 @@ def test_get_extracted_slots_with_no_active_loop():
"some_other": "some_value2",
"my_slot": "some_value",
}


def test_get_intent_of_latest_message_with_missing_data():
tracker = get_tracker([])

tracker.latest_message = None
assert not tracker.get_intent_of_latest_message()

tracker.latest_message = {
"intent": {"name": NLU_FALLBACK_INTENT_NAME, "confidence": 0.9},
}
assert not tracker.get_intent_of_latest_message()


def test_get_intent_of_latest_message_with_only_fallback():
tracker = get_tracker([])
tracker.latest_message = {
"intent": {"name": NLU_FALLBACK_INTENT_NAME, "confidence": 0.9},
"intent_ranking": [{"name": NLU_FALLBACK_INTENT_NAME, "confidence": 0.9}],
}
assert not tracker.get_intent_of_latest_message()
assert (
tracker.get_intent_of_latest_message(skip_fallback_intent=False)
== NLU_FALLBACK_INTENT_NAME
)


def test_get_intent_of_latest_message_with_user_intent():
tracker = get_tracker([])
tracker.latest_message = {
"intent": {"name": NLU_FALLBACK_INTENT_NAME, "confidence": 0.9},
"intent_ranking": [
{"name": NLU_FALLBACK_INTENT_NAME, "confidence": 0.9},
{"name": "hello", "confidence": 0.8},
{"name": "goodbye", "confidence": 0.7},
],
}
assert tracker.get_intent_of_latest_message() == "hello"
assert (
tracker.get_intent_of_latest_message(skip_fallback_intent=False)
== NLU_FALLBACK_INTENT_NAME
)

0 comments on commit a4eac21

Please sign in to comment.