Skip to content

Commit

Permalink
test: refactored commands and improved testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Nov 5, 2023
1 parent d0a2245 commit 50c028d
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 68 deletions.
115 changes: 63 additions & 52 deletions gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
COMMANDS = list(action_descriptions.keys())


def execute_cmd(msg, log):
def execute_cmd(msg: Message, log: LogManager) -> bool:
"""Executes any user-command, returns True if command was executed."""
assert msg.role == "user"

Expand Down Expand Up @@ -88,20 +88,9 @@ def handle_cmd(
log.undo(1, quiet=True)
log.write()
# rename the conversation
print("Renaming conversation (enter 'auto' to generate a name)")
print("Renaming conversation (enter empty name to auto-generate)")
new_name = args[0] if args else input("New name: ")
if new_name == "auto":
new_name = llm.generate_name(log.prepare_messages())
assert " " not in new_name
print(f"Generated name: {new_name}")
confirm = input("Confirm? [y/N] ")
if confirm.lower() not in ["y", "yes"]:
print("Aborting")
return
log.rename(new_name, keep_date=True)
else:
log.rename(new_name, keep_date=False)
print(f"Renamed conversation to {log.logfile.parent}")
rename(log, new_name)
case "fork":
# fork the conversation
new_name = args[0] if args else input("New name: ")
Expand All @@ -115,26 +104,7 @@ def handle_cmd(
# edit previous messages
# first undo the '/edit' command itself
log.undo(1, quiet=True)

# generate editable toml of all messages
t = msgs_to_toml(reversed(log.log)) # type: ignore
res = None
while not res:
t = edit_text_with_editor(t, "toml")
try:
res = toml_to_msgs(t)
except Exception as e:
print(f"\nFailed to parse TOML: {e}")
try:
sleep(1)
except KeyboardInterrupt:
yield Message("system", "Interrupted")
return
log.log = list(reversed(res))
log.write()
# now we need to redraw the log so the user isn't seeing stale messages in their buffer
# log.print()
print("Applied edited messages, write /log to see the result")
edit(log)
case "context":
# print context msg
yield gen_context_msg()
Expand All @@ -147,20 +117,8 @@ def handle_cmd(
case "save":
# undo
log.undo(1, quiet=True)

# save the most recent code block to a file
code = log.get_last_code_block()
if not code:
print("No code block found")
return
filename = args[0] if args else input("Filename: ")
if Path(filename).exists():
ans = input("File already exists, overwrite? [y/N] ")
if ans.lower() != "y":
return
with open(filename, "w") as f:
f.write(code)
print(f"Saved code block to {filename}")
save(log, filename)
case "exit":
sys.exit(0)
case "replay":
Expand All @@ -183,8 +141,61 @@ def handle_cmd(
log.undo(1, quiet=True)
log.write()

longest_cmd = max(len(cmd) for cmd in COMMANDS)
print("Available commands:")
for cmd, desc in action_descriptions.items():
cmd = cmd.ljust(longest_cmd)
print(f" /{cmd} {desc}")

def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover
# generate editable toml of all messages
t = msgs_to_toml(reversed(log.log)) # type: ignore
res = None
while not res:
t = edit_text_with_editor(t, "toml")
try:
res = toml_to_msgs(t)
except Exception as e:
print(f"\nFailed to parse TOML: {e}")
try:
sleep(1)
except KeyboardInterrupt:
yield Message("system", "Interrupted")
return
log.log = list(reversed(res))
log.write()
# now we need to redraw the log so the user isn't seeing stale messages in their buffer
# log.print()
print("Applied edited messages, write /log to see the result")


def save(log: LogManager, filename: str):
# save the most recent code block to a file
code = log.get_last_code_block()
if not code:
print("No code block found")
return
if Path(filename).exists():
ans = input("File already exists, overwrite? [y/N] ")
if ans.lower() != "y":
return
with open(filename, "w") as f:
f.write(code)
print(f"Saved code block to {filename}")


def rename(log: LogManager, new_name: str):
if new_name in ["", "auto"]:
new_name = llm.generate_name(log.prepare_messages())
assert " " not in new_name
print(f"Generated name: {new_name}")
confirm = input("Confirm? [y/N] ")
if confirm.lower() not in ["y", "yes"]:
print("Aborting")
return
log.rename(new_name, keep_date=True)
else:
log.rename(new_name, keep_date=False)
print(f"Renamed conversation to {log.logfile.parent}")


def help():
longest_cmd = max(len(cmd) for cmd in COMMANDS)
print("Available commands:")
for cmd, desc in action_descriptions.items():
print(f" /{cmd.ljust(longest_cmd)} {desc}")
61 changes: 45 additions & 16 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ def name(runid, request):
return f"test-{runid}-{request.node.name}"


@pytest.fixture
def args(name: str) -> list[str]:
return [
"--name",
name,
"--model",
"gpt-3.5-turbo",
]


@pytest.fixture
def runner():
runner = CliRunner()
Expand All @@ -29,6 +39,36 @@ def test_help(runner: CliRunner):
assert result.exit_code == 0


def test_command_exit(args: list[str], runner: CliRunner):
# tests the /exit command
args.append(f"{CMDFIX}help")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)

# check that the /exit command is present
assert "/exit" in result.output
assert result.exit_code == 0


def test_command_help(args: list[str], runner: CliRunner):
# tests the /exit command
args.append(f"{CMDFIX}help")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)

# check that the /exit command is present
assert "/help" in result.output
assert result.exit_code == 0


def test_command_summarize(args: list[str], runner: CliRunner):
# tests the /summarize command
args.append(f"{CMDFIX}summarize")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)
assert result.exit_code == 0


def test_shell(name: str, runner: CliRunner):
result = runner.invoke(
gptme.cli.main, ["--name", name, f'{CMDFIX}shell echo "yes"']
Expand Down Expand Up @@ -70,18 +110,15 @@ def test_python_error(name: str, runner: CliRunner):


@pytest.mark.parametrize("lang", ["python", "sh"])
def test_block(name: str, lang: str, runner: CliRunner):
def test_block(args: list[str], lang: str, runner: CliRunner):
# tests that shell codeblocks are formatted correctly such that whitespace and newlines are preserved
code = blocks[lang]
code = f"""```{lang}
{code.strip()}
```"""
assert "'" not in code
args = [
"--name",
name,
f"{CMDFIX}impersonate {code}",
]

args.append(f"{CMDFIX}impersonate {code}")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)
output = result.output
Expand All @@ -97,16 +134,8 @@ def test_block(name: str, lang: str, runner: CliRunner):
# TODO: these could be fast if we had a cache
@pytest.mark.slow
def test_generate_primes(name: str, runner: CliRunner):
result = runner.invoke(
gptme.cli.main,
[
"--name",
name,
"print the first 10 prime numbers",
"--model",
"gpt-3.5-turbo",
],
)
args.append("print the first 10 prime numbers")
result = runner.invoke(gptme.cli.main, args)
# check that the 9th and 10th prime is present
assert "23" in result.output
assert "29" in result.output
Expand Down

0 comments on commit 50c028d

Please sign in to comment.