Skip to content

Commit

Permalink
Merge pull request #6969 from RasaHQ/fix-yaml-story-writer
Browse files Browse the repository at this point in the history
Fix AttributeError when using YAMLStoryWriter.stories_to_yaml
  • Loading branch information
federicotdn authored Oct 8, 2020
2 parents c0b953e + d497757 commit 4c09656
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
10 changes: 6 additions & 4 deletions rasa/shared/core/training_data/story_writer/yaml_story_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,24 @@ def dump(
is_test_story: Identifies if the stories should be exported in test stories
format.
"""
self._is_test_story = is_test_story

result = self.stories_to_yaml(story_steps)
result = self.stories_to_yaml(story_steps, is_test_story)
if is_appendable and KEY_STORIES in result:
result = result[KEY_STORIES]

rasa.shared.utils.io.write_yaml(result, target, True)

def stories_to_yaml(self, story_steps: List[StoryStep]) -> Dict[Text, Any]:
def stories_to_yaml(
self, story_steps: List[StoryStep], is_test_story: bool = False
) -> Dict[Text, Any]:
"""Converts a sequence of story steps into yaml format.
Args:
story_steps: Original story steps to be converted to the YAML.
"""
from rasa.shared.utils.validation import KEY_TRAINING_DATA_FORMAT_VERSION

self._is_test_story = is_test_story

stories = []
rules = []
for story_step in story_steps:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,18 @@ async def test_action_start_action_listen_are_not_dumped():

assert ACTION_SESSION_START_NAME not in dump
assert ACTION_LISTEN_NAME not in dump


def test_yaml_writer_stories_to_yaml(default_domain: Domain):
from collections import OrderedDict

reader = YAMLStoryReader(default_domain, None, False)
writer = YAMLStoryWriter()
steps = reader.read_from_file(
"data/test_yaml_stories/simple_story_with_only_end.yml"
)

result = writer.stories_to_yaml(steps)
assert isinstance(result, OrderedDict)
assert "stories" in result
assert len(result["stories"]) == 1

0 comments on commit 4c09656

Please sign in to comment.