From 5110d82b22199e048cdb0c9c8193ddc6fe98a063 Mon Sep 17 00:00:00 2001 From: Noah Petherbridge Date: Tue, 5 Sep 2017 11:48:30 -0700 Subject: [PATCH] Fix the deparse method and add unit tests --- rivescript/rivescript.py | 152 +++++++++++------------------------ tests/test_deparse.py | 167 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 104 deletions(-) create mode 100644 tests/test_deparse.py diff --git a/rivescript/rivescript.py b/rivescript/rivescript.py index 2c8398f..d6d07cf 100644 --- a/rivescript/rivescript.py +++ b/rivescript/rivescript.py @@ -7,6 +7,7 @@ from __future__ import unicode_literals from six import text_type +import copy import sys import os import re @@ -109,7 +110,7 @@ def __init__(self, debug=False, strict=True, depth=50, log=None, "sub": {}, "person": {}, } - + # Initialize the session manager. if session_manager is None: session_manager = MemorySessionStorage(warn=self._warn) @@ -298,13 +299,9 @@ def deparse(self): "sub": {}, "person": {}, "array": {}, - "triggers": {}, - "that": {}, + "triggers": [], }, - "topic": {}, - "that": {}, - "inherit": {}, - "include": {}, + "topics": {}, } # Populate the config fields. @@ -322,48 +319,29 @@ def deparse(self): # Topic Triggers. for topic in self._topics: - dest = {} # Where to place the topic info + dest = None # Where to place the topic info if topic == "__begin__": # Begin block. - dest = result["begin"]["triggers"] + dest = result["begin"] else: # Normal topic. - if topic not in result["topic"]: - result["topic"][topic] = {} - dest = result["topic"][topic] + if topic not in result["topics"]: + result["topics"][topic] = { + "triggers": [], + "includes": {}, + "inherits": {}, + } + dest = result["topics"][topic] # Copy the triggers. - for trig, data in self._topics[topic].iteritems(): - dest[trig] = self._copy_trigger(trig, data) - - # %Previous's. - for topic in self._thats: - dest = {} # Where to place the topic info + for trig in self._topics[topic]: + dest["triggers"].append(copy.deepcopy(trig)) - if topic == "__begin__": - # Begin block. - dest = result["begin"]["that"] - else: - # Normal topic. - if topic not in result["that"]: - result["that"][topic] = {} - dest = result["that"][topic] - - # The "that" structure is backwards: bot reply, then trigger, then info. - for previous, pdata in self._thats[topic].iteritems(): - for trig, data in pdata.iteritems(): - dest[trig] = self._copy_trigger(trig, data, previous) - - # Inherits/Includes. - for topic, data in self._lineage.iteritems(): - result["inherit"][topic] = [] - for inherit in data: - result["inherit"][topic].append(inherit) - for topic, data in self._includes.iteritems(): - result["include"][topic] = [] - for include in data: - result["include"][topic].append(include) + # Inherits/Includes. + for label, mapping in {"inherits": self._lineage, "includes": self._includes}.items(): + if topic in mapping and len(mapping[topic]): + dest[label] = mapping[topic].copy() return result @@ -380,11 +358,11 @@ def write(self, fh, deparsed=None): by a user interface for editing RiveScript without writing the code directly). - :param fh: Either a file name ``str`` or a file handle object of a file - opened in write mode. - :param optional dict deparsed: A data structure in the same format as - what ``deparse()`` returns. If not passed, this value will come from - the current in-memory data from ``deparse()``. + Parameters: + fh (str or file): a string or a file-like object. + deparsed (dict): a data structure in the same format as what + ``deparse()`` returns. If not passed, this value will come from + the current in-memory data from ``deparse()``. """ # Passed a string instead of a file handle? @@ -432,101 +410,67 @@ def write(self, fh, deparsed=None): fh.write("\n") # Begin block. - if len(deparsed["begin"]["triggers"].keys()): + if len(deparsed["begin"]["triggers"]): fh.write("> begin\n\n") self._write_triggers(fh, deparsed["begin"]["triggers"], indent="\t") fh.write("< begin\n\n") # The topics. Random first! topics = ["random"] - topics.extend(sorted(deparsed["topic"].keys())) + topics.extend(sorted(deparsed["topics"].keys())) done_random = False for topic in topics: - if topic not in deparsed["topic"]: continue + if topic not in deparsed["topics"]: continue if topic == "random" and done_random: continue if topic == "random": done_random = True tagged = False # Used > topic tag - if topic != "random" or topic in deparsed["include"] or topic in deparsed["inherit"]: + data = deparsed["topics"][topic] + + if topic != "random" or len(data["includes"]) or len(data["inherits"]): tagged = True fh.write("> topic " + topic) - if topic in deparsed["inherit"]: - fh.write(" inherits " + " ".join(deparsed["inherit"][topic])) - if topic in deparsed["include"]: - fh.write(" includes " + " ".join(deparsed["include"][topic])) + if data["inherits"]: + fh.write(" inherits " + " ".join(sorted(data["inherits"].keys()))) + if data["includes"]: + fh.write(" includes " + " ".join(sorted(data["includes"].keys()))) fh.write("\n\n") indent = "\t" if tagged else "" - self._write_triggers(fh, deparsed["topic"][topic], indent=indent) - - # Any %Previous's? - if topic in deparsed["that"]: - self._write_triggers(fh, deparsed["that"][topic], indent=indent) + self._write_triggers(fh, data["triggers"], indent=indent) if tagged: fh.write("< topic\n\n") return True - def _copy_trigger(self, trig, data, previous=None): - """Make copies of all data below a trigger. - - :param str trig: The trigger key. - :param dict data: The data under that trigger. - :param previous: The ``%Previous`` for the trigger. - """ - # Copied data. - dest = {} - - if previous: - dest["previous"] = previous - - if "redirect" in data and data["redirect"]: - # @Redirect - dest["redirect"] = data["redirect"] - - if "condition" in data and len(data["condition"].keys()): - # *Condition - dest["condition"] = [] - for i in sorted(data["condition"].keys()): - dest["condition"].append(data["condition"][i]) - - if "reply" in data and len(data["reply"].keys()): - # -Reply - dest["reply"] = [] - for i in sorted(data["reply"].keys()): - dest["reply"].append(data["reply"][i]) - - return dest - def _write_triggers(self, fh, triggers, indent=""): """Write triggers to a file handle. - :param fh: The file handle. - :param dict triggers: The triggers to write to the file. - :param str indent: The indentation (spaces) to prefix each line with. + Parameters: + fh (file): file object. + triggers (list): list of triggers to write. + indent (str): indentation for each line. """ - for trig in sorted(triggers.keys()): - fh.write(indent + "+ " + self._write_wrapped(trig, indent=indent) + "\n") - d = triggers[trig] + for trig in triggers: + fh.write(indent + "+ " + self._write_wrapped(trig["trigger"], indent=indent) + "\n") + d = trig - if "previous" in d: + if d.get("previous"): fh.write(indent + "% " + self._write_wrapped(d["previous"], indent=indent) + "\n") - if "condition" in d: - for cond in d["condition"]: - fh.write(indent + "* " + self._write_wrapped(cond, indent=indent) + "\n") + for cond in d["condition"]: + fh.write(indent + "* " + self._write_wrapped(cond, indent=indent) + "\n") - if "redirect" in d: + if d.get("redirect"): fh.write(indent + "@ " + self._write_wrapped(d["redirect"], indent=indent) + "\n") - if "reply" in d: - for reply in d["reply"]: - fh.write(indent + "- " + self._write_wrapped(reply, indent=indent) + "\n") + for reply in d["reply"]: + fh.write(indent + "- " + self._write_wrapped(reply, indent=indent) + "\n") fh.write("\n") diff --git a/tests/test_deparse.py b/tests/test_deparse.py new file mode 100644 index 0000000..97a2209 --- /dev/null +++ b/tests/test_deparse.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import +from six.moves import cStringIO as StringIO + +from .config import RiveScriptTestCase + +class DeparseTests(RiveScriptTestCase): + """Test deparse and write functions.""" + maxDiff = 8000 + + def test_deparse(self): + # The original source that should match the re-written version. + source = """// Written by rivescript.deparse() + ! version = 2.0 + + ! var age = 5 + ! var name = Aiden + + > begin + + + request + - {ok} + + < begin + + + what is your name + - My name is . + + + my name is * + - >Nice to meet you. + + + you too + % nice to meet you + - :) + + + who am i + * != undefined => Aren't you ? + - I don't know. + - We've never met. + + > topic a includes b + + + a + - A. + + < topic + + > topic b + + + b + - B. + + < topic + + > topic c inherits b includes a + + + c + - C. + + < topic + """ + + # Expected deparsed data structure. + expected = { + "begin": { + "global": {}, + "var": { + "name": "Aiden", + "age": "5", + }, + "sub": {}, + "person": {}, + "array": {}, + "triggers": [{ + "trigger": "request", + "reply": ["{ok}"], + "condition": [], + "redirect": None, + "previous": None, + }], + }, + "topics": { + "random": { + "includes": {}, + "inherits": {}, + "triggers": [ + { + "trigger": "what is your name", + "previous": None, + "redirect": None, + "condition": [], + "reply": ["My name is ."] + }, + { + "trigger": "my name is *", + "previous": None, + "redirect": None, + "condition": [], + "reply": [">Nice to meet you."], + }, + { + "trigger": "you too", + "previous": "nice to meet you", + "redirect": None, + "condition": [], + "reply": [":)"], + }, + { + "trigger": "who am i", + "previous": None, + "redirect": None, + "condition": [ + " != undefined => Aren't you ?", + ], + "reply": ["I don't know.", "We've never met."], + }, + ] + }, + "a": { + "includes": { "b": 1 }, + "inherits": {}, + "triggers": [{ + "trigger": "a", + "previous": None, + "redirect": None, + "condition": [], + "reply": ["A."], + }], + }, + "b": { + "includes": {}, + "inherits": {}, + "triggers": [{ + "trigger": "b", + "previous": None, + "redirect": None, + "condition": [], + "reply": ["B."], + }], + }, + "c": { + "includes": {"a": 1}, + "inherits": {"b": 1}, + "triggers": [{ + "trigger": "c", + "previous": None, + "redirect": None, + "condition": [], + "reply": ["C."], + }], + } + } + } + + # Verify the deparsed tree matches expectations. + self.new(source) + dep = self.rs.deparse() + self.assertEqual(dep, expected) + + # See if the re-written RiveScript source matches the original. + buf = StringIO() + self.rs.write(buf) + written = buf.getvalue().split("\n") + for i, line in enumerate(source.split("\n")): + assert line.strip() == written[i].strip()