Skip to content

Commit

Permalink
Merge branch 'master' into support-retrieval-intents-e2e
Browse files Browse the repository at this point in the history
  • Loading branch information
rasabot authored Nov 17, 2020
2 parents fb4f151 + d5710d9 commit 4bfc4ae
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ jobs:
# Fetch entire history for current branch so that `make lint-docstrings`
# can calculate the proper diff between the branches
git pull --ff-only --unshallow origin ${{ github.head_ref }}
git pull --ff-only --unshallow origin "${{ github.head_ref }}"
- name: Lint Code 🎎
if: needs.changes.outputs.backend == 'true'
Expand Down
1 change: 1 addition & 0 deletions changelog/5974.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`ActionRestart` will now trigger `ActionSessionStart` as a followup action.
4 changes: 2 additions & 2 deletions rasa/shared/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
LOOP_INTERRUPTED,
ENTITY_LABEL_SEPARATOR,
ACTION_SESSION_START_NAME,
ACTION_LISTEN_NAME,
)
from rasa.shared.nlu.constants import (
ENTITY_ATTRIBUTE_TYPE,
Expand Down Expand Up @@ -692,8 +691,9 @@ def as_story_string(self) -> Text:
return self.type_name

def apply_to(self, tracker: "DialogueStateTracker") -> None:
"""Resets the tracker and triggers a followup `ActionSessionStart`."""
tracker._reset()
tracker.trigger_followup_action(ACTION_LISTEN_NAME)
tracker.trigger_followup_action(ACTION_SESSION_START_NAME)


# noinspection PyProtectedMember
Expand Down
62 changes: 61 additions & 1 deletion tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional, Text, List, Callable, Type, Any, Tuple
from unittest.mock import patch, Mock

from rasa.core.policies.rule_policy import RulePolicy
from rasa.core.actions.action import (
ActionUtterTemplate,
ActionListen,
Expand Down Expand Up @@ -42,7 +43,7 @@
LoopInterrupted,
)
from rasa.core.interpreter import RasaNLUHttpInterpreter
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter, RegexInterpreter
from rasa.core.policies import SimplePolicyEnsemble, PolicyEnsemble
from rasa.core.policies.ted_policy import TEDPolicy
from rasa.core.processor import MessageProcessor
Expand All @@ -52,6 +53,7 @@
from rasa.shared.nlu.constants import INTENT_NAME_KEY
from rasa.utils.endpoints import EndpointConfig
from rasa.shared.core.constants import (
ACTION_RESTART_NAME,
DEFAULT_INTENTS,
ACTION_LISTEN_NAME,
ACTION_SESSION_START_NAME,
Expand Down Expand Up @@ -854,6 +856,64 @@ def test_get_next_action_probabilities_pass_policy_predictions_without_interpret
)


async def test_restart_triggers_session_start(
default_channel: CollectingOutputChannel,
default_processor: MessageProcessor,
monkeypatch: MonkeyPatch,
):
# The rule policy is trained and used so as to allow the default action ActionRestart to be predicted
rule_policy = RulePolicy()
rule_policy.train([], default_processor.domain, RegexInterpreter())
monkeypatch.setattr(
default_processor.policy_ensemble,
"policies",
[rule_policy, *default_processor.policy_ensemble.policies],
)

sender_id = uuid.uuid4().hex

entity = "name"
slot_1 = {entity: "name1"}
await default_processor.handle_message(
UserMessage(f"/greet{json.dumps(slot_1)}", default_channel, sender_id)
)

assert default_channel.latest_output() == {
"recipient_id": sender_id,
"text": "hey there name1!",
}

# This restarts the chat
await default_processor.handle_message(
UserMessage("/restart", default_channel, sender_id)
)

tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)

expected = [
ActionExecuted(ACTION_SESSION_START_NAME),
SessionStarted(),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
f"/greet{json.dumps(slot_1)}",
{INTENT_NAME_KEY: "greet", "confidence": 1.0},
[{"entity": entity, "start": 6, "end": 23, "value": "name1"}],
),
SlotSet(entity, slot_1[entity]),
ActionExecuted("utter_greet"),
BotUttered("hey there name1!", metadata={"template_name": "utter_greet"}),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered("/restart", {INTENT_NAME_KEY: "restart", "confidence": 1.0}),
ActionExecuted(ACTION_RESTART_NAME),
Restarted(),
ActionExecuted(ACTION_SESSION_START_NAME),
SessionStarted(),
# No previous slot is set due to restart.
ActionExecuted(ACTION_LISTEN_NAME),
]
assert list(tracker.events) == expected


async def test_handle_message_if_action_manually_rejects(
default_processor: MessageProcessor, monkeypatch: MonkeyPatch
):
Expand Down
9 changes: 6 additions & 3 deletions tests/shared/core/test_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ def test_restart_event(default_domain: Domain):
tracker.update(Restarted())

assert len(tracker.events) == 5
assert tracker.followup_action is not None
assert tracker.followup_action == ACTION_SESSION_START_NAME

tracker.update(SessionStarted())

assert tracker.followup_action == ACTION_LISTEN_NAME
assert tracker.latest_message.text is None
assert len(list(tracker.generate_all_prior_trackers())) == 1
Expand All @@ -370,7 +373,7 @@ def test_restart_event(default_domain: Domain):
recovered.recreate_from_dialogue(dialogue)

assert recovered.current_state() == tracker.current_state()
assert len(recovered.events) == 5
assert len(recovered.events) == 6
assert recovered.latest_message.text is None
assert len(list(recovered.generate_all_prior_trackers())) == 1

Expand Down Expand Up @@ -1218,7 +1221,7 @@ def test_set_form_validation_deprecation_warning(validate: bool):
),
(
# this conversation does not contain a session
[UserUttered("hi", {"name": "greet"}), ActionExecuted("utter_greet"),],
[UserUttered("hi", {"name": "greet"}), ActionExecuted("utter_greet")],
1,
),
],
Expand Down

0 comments on commit 4bfc4ae

Please sign in to comment.