Skip to content

Commit

Permalink
Merge pull request #7098 from RasaHQ/support-retrieval-intents-e2e
Browse files Browse the repository at this point in the history
make rasa test support retrieval intents for test stories
  • Loading branch information
rasabot authored Nov 17, 2020
2 parents d5710d9 + 4bfc4ae commit 2e7f49f
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 26 deletions.
2 changes: 2 additions & 0 deletions changelog/7002.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Test stories can now contain both normal intents and retrieval intents. The `failed_test_stories.yml` file generated by `rasa test` now also specifies the full retrieval intent.
Previously `rasa test` would fail on test stories that specified retrieval intents.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
stories:
- story: chitchat name
steps:
- intent: chitchat/ask_name
- action: utter_chitchat
7 changes: 7 additions & 0 deletions data/test_yaml_stories/test_base_retrieval_intent_story.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
stories:
- story: chitchat name
steps:
- user: |
What's the weather like today?
intent: chitchat
- action: utter_chitchat
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
stories:
- story: chitchat name
steps:
- user: |
What is your name?
intent: affirm
- action: utter_chitchat
7 changes: 7 additions & 0 deletions data/test_yaml_stories/test_full_retrieval_intent_story.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
stories:
- story: chitchat name
steps:
- user: |
What is your name?
intent: chitchat/ask_name
- action: utter_chitchat
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
stories:
- story: chitchat name
steps:
- user: |
What is your name?
intent: chitchat/ask_weather
- action: utter_chitchat
3 changes: 2 additions & 1 deletion rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from rasa.shared.exceptions import RasaException
from rasa.shared.nlu.constants import TEXT
from rasa.shared.nlu.constants import TEXT, INTENT
from tqdm import tqdm
from typing import Tuple, List, Optional, Dict, Text, Union
import numpy as np
Expand All @@ -17,6 +17,7 @@
from rasa.shared.core.constants import USER
import rasa.shared.utils.io
from rasa.shared.nlu.training_data.features import Features
from rasa.shared.constants import INTENT_MESSAGE_PREFIX

FEATURIZER_FILE = "featurizer.json"

Expand Down
70 changes: 65 additions & 5 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@
from rasa.shared.exceptions import RasaException
import rasa.shared.utils.io
from rasa.core.channels import UserMessage
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
YAMLStoryWriter,
)
from rasa.shared.core.domain import Domain
from rasa.nlu.constants import ENTITY_ATTRIBUTE_TEXT
from rasa.nlu.constants import (
ENTITY_ATTRIBUTE_TEXT,
RESPONSE_SELECTOR_DEFAULT_INTENT,
RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
)
from rasa.shared.nlu.constants import (
INTENT,
ENTITIES,
Expand All @@ -23,6 +28,11 @@
ENTITY_ATTRIBUTE_END,
EXTRACTOR,
ENTITY_ATTRIBUTE_TYPE,
INTENT_RESPONSE_KEY,
INTENT_NAME_KEY,
RESPONSE,
RESPONSE_SELECTOR,
FULL_RETRIEVAL_INTENT_NAME_KEY,
)
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
from rasa.shared.core.events import ActionExecuted, UserUttered
Expand Down Expand Up @@ -352,6 +362,44 @@ def _clean_entity_results(
return cleaned_entities


def _get_full_retrieval_intent(parsed: Dict[Text, Any]) -> Text:
    """Pick the most specific intent name available in a parse result.

    For a retrieval intent this is the full retrieval intent reported by the
    response selector; for every other intent it is the plain intent name.

    Args:
        parsed: Predicted parsed data.

    Returns:
        The full retrieval intent if one can be read from the response
        selector output, otherwise the predicted base intent.
    """
    predicted_intent = parsed.get(INTENT, {}).get(INTENT_NAME_KEY)
    selector_output = parsed.get(RESPONSE_SELECTOR, {})

    # Anything the response selector does not list as a retrieval intent is a
    # normal intent and is returned unchanged.
    retrieval_intents = selector_output.get(RESPONSE_SELECTOR_RETRIEVAL_INTENTS, {})
    if predicted_intent not in retrieval_intents:
        return predicted_intent

    # If the response selector parameter was not specified in the config, its
    # output is stored under the "default" key; if it was specified, the output
    # is stored under the base intent itself.
    if RESPONSE_SELECTOR_DEFAULT_INTENT in selector_output:
        selector_key = RESPONSE_SELECTOR_DEFAULT_INTENT
    else:
        selector_key = predicted_intent

    selected_response = selector_output.get(selector_key, {}).get(RESPONSE, {})

    # Fall back to the base intent when no full retrieval intent is present.
    return selected_response.get(INTENT_RESPONSE_KEY) or predicted_intent


def _collect_user_uttered_predictions(
event: UserUttered,
predicted: Dict[Text, Any],
Expand All @@ -360,11 +408,22 @@ def _collect_user_uttered_predictions(
) -> EvaluationStore:
user_uttered_eval_store = EvaluationStore()

intent_gold = event.intent.get("name")
predicted_intent = predicted.get(INTENT, {}).get("name")
# intent from the test story, may either be base intent or full retrieval intent
base_intent = event.intent.get(INTENT_NAME_KEY)
full_retrieval_intent = event.intent.get(FULL_RETRIEVAL_INTENT_NAME_KEY)
intent_gold = full_retrieval_intent if full_retrieval_intent else base_intent

# predicted intent: note that this is only the base intent at this point
predicted_base_intent = predicted.get(INTENT, {}).get(INTENT_NAME_KEY)

# if the test story only provides the base intent AND the prediction was correct,
# we are not interested in full retrieval intents and skip this section.
# In any other case we are interested in the full retrieval intent (e.g. for report)
if intent_gold != predicted_base_intent:
predicted_base_intent = _get_full_retrieval_intent(predicted)

user_uttered_eval_store.add_to_store(
intent_predictions=[predicted_intent], intent_targets=[intent_gold]
intent_predictions=[predicted_base_intent], intent_targets=[intent_gold]
)

entity_gold = event.entities
Expand Down Expand Up @@ -553,7 +612,8 @@ async def _predict_tracker_actions(
# Indirectly that means that the test story was in YAML format.
if not event.text:
predicted = event.parse_data
# Indirectly that means that the test story was in Markdown format.
# Indirectly that means that the test story was either:
# in YAML format containing a user message, or in Markdown format.
# Leaving that as it is because Markdown is in legacy mode.
else:
predicted = await processor.parse_message(UserMessage(event.text))
Expand Down
45 changes: 39 additions & 6 deletions rasa/shared/core/training_data/story_reader/yaml_story_reader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import functools
import logging
from pathlib import Path
from typing import Dict, Text, List, Any, Optional, Union
from typing import Dict, Text, List, Any, Optional, Union, Tuple

import rasa.shared.data
from rasa.shared.core.slots import TextSlot, ListSlot
from rasa.shared.exceptions import YamlException
import rasa.shared.utils.io
from rasa.shared.core.constants import LOOP_NAME
from rasa.shared.nlu.constants import ENTITIES, INTENT_NAME_KEY
from rasa.shared.nlu.constants import (
ENTITIES,
INTENT_NAME_KEY,
PREDICTED_CONFIDENCE_KEY,
FULL_RETRIEVAL_INTENT_NAME_KEY,
)
from rasa.shared.nlu.training_data import entities_parser
import rasa.shared.utils.validation

Expand All @@ -24,6 +29,7 @@
from rasa.shared.core.events import UserUttered, SlotSet, ActiveLoop
from rasa.shared.core.training_data.story_reader.story_reader import StoryReader
from rasa.shared.core.training_data.structures import StoryStep
from rasa.shared.nlu.training_data.message import Message

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -307,8 +313,13 @@ def _parse_user_utterance(self, step: Dict[Text, Any]) -> None:
self.current_step_builder.add_user_messages([utterance])

def _validate_that_utterance_is_in_domain(self, utterance: UserUttered) -> None:

intent_name = utterance.intent.get(INTENT_NAME_KEY)

# check if this is a retrieval intent
# in this case check only for the base intent in domain
intent_name = Message.separate_intent_response_key(intent_name)[0]

if not self.domain:
logger.debug(
"Skipped validating if intent is in domain as domain " "is `None`."
Expand Down Expand Up @@ -345,7 +356,9 @@ def _parse_or_statement(self, step: Dict[Text, Any]) -> None:
utterances, self._is_used_for_training
)

def _user_intent_from_step(self, step: Dict[Text, Any]) -> Text:
def _user_intent_from_step(
self, step: Dict[Text, Any]
) -> Tuple[Text, Optional[Text]]:
user_intent = step.get(KEY_USER_INTENT, "").strip()

if not user_intent:
Expand All @@ -366,13 +379,33 @@ def _user_intent_from_step(self, step: Dict[Text, Any]) -> Text:
)
# Remove leading slash
user_intent = user_intent[1:]
return user_intent

# StoryStep should never contain a full retrieval intent, only the base intent.
# However, users can specify full retrieval intents in their test stories file
# for the NLU testing purposes.
base_intent, response_key = Message.separate_intent_response_key(user_intent)
if response_key and not self.is_test_stories_file(self.source_name):
rasa.shared.utils.io.raise_warning(
f"Issue found in '{self.source_name}' while parsing story "
f"{self._get_item_title()}:\n"
f"User intent '{user_intent}' is a full retrieval intent. "
f"Stories shouldn't contain full retrieval intents. "
f"Rasa Open Source will only use base intent '{base_intent}' "
f"for training.",
docs=self._get_docs_link(),
)

return (base_intent, user_intent) if response_key else (base_intent, None)

def _parse_raw_user_utterance(self, step: Dict[Text, Any]) -> Optional[UserUttered]:
from rasa.shared.nlu.interpreter import RegexInterpreter

intent_name = self._user_intent_from_step(step)
intent = {"name": intent_name, "confidence": 1.0}
intent_name, full_retrieval_intent = self._user_intent_from_step(step)
intent = {
INTENT_NAME_KEY: intent_name,
FULL_RETRIEVAL_INTENT_NAME_KEY: full_retrieval_intent,
PREDICTED_CONFIDENCE_KEY: 1.0,
}

if KEY_USER_MESSAGE in step:
user_message = step[KEY_USER_MESSAGE].strip()
Expand Down
2 changes: 2 additions & 0 deletions rasa/shared/nlu/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
TEXT = "text"
INTENT = "intent"
RESPONSE = "response"
RESPONSE_SELECTOR = "response_selector"
INTENT_RESPONSE_KEY = "intent_response_key"
ACTION_TEXT = "action_text"
ACTION_NAME = "action_name"
INTENT_NAME_KEY = "name"
FULL_RETRIEVAL_INTENT_NAME_KEY = "full_retrieval_intent_name"
METADATA = "metadata"
METADATA_INTENT = "intent"
METADATA_EXAMPLE = "example"
Expand Down
17 changes: 16 additions & 1 deletion tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
from datetime import datetime

from typing import Text, Generator
from typing import Text, Generator, Callable

import pytest

Expand Down Expand Up @@ -219,3 +219,18 @@ async def form_bot_agent(trained_async) -> Agent:
)

return Agent.load_local_model(zipped_model)


@pytest.fixture(scope="session")
async def response_selector_agent(trained_async: Callable) -> Agent:
    """Train the response selector example bot and load it as an `Agent`.

    The fixture is session scoped, so the model is trained once per test
    session.
    """
    domain_path = "examples/responseselectorbot/domain.yml"
    config_path = "examples/responseselectorbot/config.yml"
    data_paths = [
        "examples/responseselectorbot/data/rules.yml",
        "examples/responseselectorbot/data/stories.yml",
        "examples/responseselectorbot/data/nlu.yml",
    ]

    model_archive = await trained_async(
        domain=domain_path, config=config_path, training_files=data_paths
    )

    return Agent.load_local_model(model_archive)
46 changes: 46 additions & 0 deletions tests/core/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,49 @@ def test_event_has_proper_implementation(
actual_entities = _clean_entity_results(text, [entity])

assert actual_entities[0] == expected_entity


@pytest.mark.parametrize(
    "test_file",
    [
        "data/test_yaml_stories/test_full_retrieval_intent_story.yml",
        "data/test_yaml_stories/test_base_retrieval_intent_story.yml",
    ],
)
async def test_retrieval_intent(response_selector_agent: Agent, test_file: Text):
    """A test story may name either the base or the full retrieval intent.

    Both story files must evaluate without any prediction/target mismatch.
    """
    data_generator = await _create_data_generator(
        test_file, response_selector_agent, use_e2e=True
    )
    story_trackers = data_generator.generate_story_trackers()

    # The second element of the returned pair (the story count) is not needed.
    evaluation, _ = await _collect_story_predictions(
        story_trackers, response_selector_agent, use_e2e=True
    )

    assert not evaluation.evaluation_store.has_prediction_target_mismatch()


@pytest.mark.parametrize(
    "test_file",
    [
        "data/test_yaml_stories/test_full_retrieval_intent_wrong_prediction.yml",
        "data/test_yaml_stories/test_base_retrieval_intent_wrong_prediction.yml",
    ],
)
async def test_retrieval_intent_wrong_prediction(
    tmpdir: Path, response_selector_agent: Agent, test_file: Text
):
    """Failed stories must report the predicted full retrieval intent.

    Each story file is evaluated end-to-end with `tmpdir` as the output
    directory, and the failed stories file written there is then inspected.
    """
    output_dir = str(tmpdir)

    await evaluate_stories(
        stories=test_file,
        agent=response_selector_agent,
        out_directory=output_dir,
        max_stories=None,
        e2e=True,
    )

    failed_stories_report = rasa.shared.utils.io.read_file(
        str(tmpdir / FAILED_STORIES_FILE)
    )

    # The "predicted" entry has to carry the full retrieval intent, not only
    # the base intent.
    assert "# predicted: chitchat/ask_name" in failed_stories_report
Loading

0 comments on commit 2e7f49f

Please sign in to comment.