From 50c028dbb8124f6be1853dd4819dcfc741493e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Sun, 5 Nov 2023 10:51:18 +0100 Subject: [PATCH] test: refactored commands and improved testing --- gptme/commands.py | 115 +++++++++++++++++++++++++--------------------- tests/test_cli.py | 61 +++++++++++++++++------- 2 files changed, 108 insertions(+), 68 deletions(-) diff --git a/gptme/commands.py b/gptme/commands.py index 43bdb8d3..43bd889b 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -56,7 +56,7 @@ COMMANDS = list(action_descriptions.keys()) -def execute_cmd(msg, log): +def execute_cmd(msg: Message, log: LogManager) -> bool: """Executes any user-command, returns True if command was executed.""" assert msg.role == "user" @@ -88,20 +88,9 @@ def handle_cmd( log.undo(1, quiet=True) log.write() # rename the conversation - print("Renaming conversation (enter 'auto' to generate a name)") + print("Renaming conversation (enter empty name to auto-generate)") new_name = args[0] if args else input("New name: ") - if new_name == "auto": - new_name = llm.generate_name(log.prepare_messages()) - assert " " not in new_name - print(f"Generated name: {new_name}") - confirm = input("Confirm? [y/N] ") - if confirm.lower() not in ["y", "yes"]: - print("Aborting") - return - log.rename(new_name, keep_date=True) - else: - log.rename(new_name, keep_date=False) - print(f"Renamed conversation to {log.logfile.parent}") + rename(log, new_name) case "fork": # fork the conversation new_name = args[0] if args else input("New name: ") @@ -115,26 +104,7 @@ def handle_cmd( # edit previous messages # first undo the '/edit' command itself log.undo(1, quiet=True) - - # generate editable toml of all messages - t = msgs_to_toml(reversed(log.log)) # type: ignore - res = None - while not res: - t = edit_text_with_editor(t, "toml") - try: - res = toml_to_msgs(t) - except Exception as e: - print(f"\nFailed to parse TOML: {e}") - try: - sleep(1) - except KeyboardInterrupt: - yield Message("system", "Interrupted") - return - log.log = list(reversed(res)) - log.write() - # now we need to redraw the log so the user isn't seeing stale messages in their buffer - # log.print() - print("Applied edited messages, write /log to see the result") + edit(log) case "context": # print context msg yield gen_context_msg() @@ -147,20 +117,8 @@ def handle_cmd( case "save": # undo log.undo(1, quiet=True) - - # save the most recent code block to a file - code = log.get_last_code_block() - if not code: - print("No code block found") - return filename = args[0] if args else input("Filename: ") - if Path(filename).exists(): - ans = input("File already exists, overwrite? [y/N] ") - if ans.lower() != "y": - return - with open(filename, "w") as f: - f.write(code) - print(f"Saved code block to {filename}") + save(log, filename) case "exit": sys.exit(0) case "replay": @@ -183,8 +141,61 @@ def handle_cmd( log.undo(1, quiet=True) log.write() - longest_cmd = max(len(cmd) for cmd in COMMANDS) - print("Available commands:") - for cmd, desc in action_descriptions.items(): - cmd = cmd.ljust(longest_cmd) - print(f" /{cmd} {desc}") + +def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover + # generate editable toml of all messages + t = msgs_to_toml(reversed(log.log)) # type: ignore + res = None + while not res: + t = edit_text_with_editor(t, "toml") + try: + res = toml_to_msgs(t) + except Exception as e: + print(f"\nFailed to parse TOML: {e}") + try: + sleep(1) + except KeyboardInterrupt: + yield Message("system", "Interrupted") + return + log.log = list(reversed(res)) + log.write() + # now we need to redraw the log so the user isn't seeing stale messages in their buffer + # log.print() + print("Applied edited messages, write /log to see the result") + + +def save(log: LogManager, filename: str): + # save the most recent code block to a file + code = log.get_last_code_block() + if not code: + print("No code block found") + return + if Path(filename).exists(): + ans = input("File already exists, overwrite? [y/N] ") + if ans.lower() != "y": + return + with open(filename, "w") as f: + f.write(code) + print(f"Saved code block to {filename}") + + +def rename(log: LogManager, new_name: str): + if new_name in ["", "auto"]: + new_name = llm.generate_name(log.prepare_messages()) + assert " " not in new_name + print(f"Generated name: {new_name}") + confirm = input("Confirm? [y/N] ") + if confirm.lower() not in ["y", "yes"]: + print("Aborting") + return + log.rename(new_name, keep_date=True) + else: + log.rename(new_name, keep_date=False) + print(f"Renamed conversation to {log.logfile.parent}") + + +def help(): + longest_cmd = max(len(cmd) for cmd in COMMANDS) + print("Available commands:") + for cmd, desc in action_descriptions.items(): + print(f" /{cmd.ljust(longest_cmd)} {desc}") diff --git a/tests/test_cli.py b/tests/test_cli.py index 6b2205f3..f86ddd76 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -17,6 +17,16 @@ def name(runid, request): return f"test-{runid}-{request.node.name}" +@pytest.fixture +def args(name: str) -> list[str]: + return [ + "--name", + name, + "--model", + "gpt-3.5-turbo", + ] + + @pytest.fixture def runner(): runner = CliRunner() @@ -29,6 +39,36 @@ def test_help(runner: CliRunner): assert result.exit_code == 0 +def test_command_exit(args: list[str], runner: CliRunner): + # tests the /exit command + args.append(f"{CMDFIX}help") + print(f"running: gptme {' '.join(args)}") + result = runner.invoke(gptme.cli.main, args) + + # check that the /exit command is present + assert "/exit" in result.output + assert result.exit_code == 0 + + +def test_command_help(args: list[str], runner: CliRunner): + # tests the /exit command + args.append(f"{CMDFIX}help") + print(f"running: gptme {' '.join(args)}") + result = runner.invoke(gptme.cli.main, args) + + # check that the /exit command is present + assert "/help" in result.output + assert result.exit_code == 0 + + +def test_command_summarize(args: list[str], runner: CliRunner): + # tests the /summarize command + args.append(f"{CMDFIX}summarize") + print(f"running: gptme {' '.join(args)}") + result = runner.invoke(gptme.cli.main, args) + assert result.exit_code == 0 + + def test_shell(name: str, runner: CliRunner): result = runner.invoke( gptme.cli.main, ["--name", name, f'{CMDFIX}shell echo "yes"'] @@ -70,18 +110,15 @@ def test_python_error(name: str, runner: CliRunner): @pytest.mark.parametrize("lang", ["python", "sh"]) -def test_block(name: str, lang: str, runner: CliRunner): +def test_block(args: list[str], lang: str, runner: CliRunner): # tests that shell codeblocks are formatted correctly such that whitespace and newlines are preserved code = blocks[lang] code = f"""```{lang} {code.strip()} ```""" assert "'" not in code - args = [ - "--name", - name, - f"{CMDFIX}impersonate {code}", - ] + + args.append(f"{CMDFIX}impersonate {code}") print(f"running: gptme {' '.join(args)}") result = runner.invoke(gptme.cli.main, args) output = result.output @@ -97,16 +134,8 @@ def test_block(name: str, lang: str, runner: CliRunner): # TODO: these could be fast if we had a cache @pytest.mark.slow def test_generate_primes(name: str, runner: CliRunner): - result = runner.invoke( - gptme.cli.main, - [ - "--name", - name, - "print the first 10 prime numbers", - "--model", - "gpt-3.5-turbo", - ], - ) + args.append("print the first 10 prime numbers") + result = runner.invoke(gptme.cli.main, args) # check that the 9th and 10th prime is present assert "23" in result.output assert "29" in result.output