diff --git a/rasa/shared/core/training_data/story_writer/yaml_story_writer.py b/rasa/shared/core/training_data/story_writer/yaml_story_writer.py index faf26a51479c..1e0253a36a08 100644 --- a/rasa/shared/core/training_data/story_writer/yaml_story_writer.py +++ b/rasa/shared/core/training_data/story_writer/yaml_story_writer.py @@ -89,15 +89,15 @@ def dump( is_test_story: Identifies if the stories should be exported in test stories format. """ - self._is_test_story = is_test_story - - result = self.stories_to_yaml(story_steps) + result = self.stories_to_yaml(story_steps, is_test_story) if is_appendable and KEY_STORIES in result: result = result[KEY_STORIES] rasa.shared.utils.io.write_yaml(result, target, True) - def stories_to_yaml(self, story_steps: List[StoryStep]) -> Dict[Text, Any]: + def stories_to_yaml( + self, story_steps: List[StoryStep], is_test_story: bool = False + ) -> Dict[Text, Any]: """Converts a sequence of story steps into yaml format. Args: @@ -105,6 +105,8 @@ def stories_to_yaml(self, story_steps: List[StoryStep]) -> Dict[Text, Any]: """ from rasa.shared.utils.validation import KEY_TRAINING_DATA_FORMAT_VERSION + self._is_test_story = is_test_story + stories = [] rules = [] for story_step in story_steps: diff --git a/tests/shared/core/training_data/story_writer/test_yaml_story_writer.py b/tests/shared/core/training_data/story_writer/test_yaml_story_writer.py index 5ff537e43bf8..cfcc5ea63a57 100644 --- a/tests/shared/core/training_data/story_writer/test_yaml_story_writer.py +++ b/tests/shared/core/training_data/story_writer/test_yaml_story_writer.py @@ -167,3 +167,18 @@ async def test_action_start_action_listen_are_not_dumped(): assert ACTION_SESSION_START_NAME not in dump assert ACTION_LISTEN_NAME not in dump + + +def test_yaml_writer_stories_to_yaml(default_domain: Domain): + from collections import OrderedDict + + reader = YAMLStoryReader(default_domain, None, False) + writer = YAMLStoryWriter() + steps = reader.read_from_file( + "data/test_yaml_stories/simple_story_with_only_end.yml" + ) + + result = writer.stories_to_yaml(steps) + assert isinstance(result, OrderedDict) + assert "stories" in result + assert len(result["stories"]) == 1