Skip to content

Commit

Permalink
test: improved command testing, fixed prompt parsing when passed comm…
Browse files Browse the repository at this point in the history
…and with path
  • Loading branch information
ErikBjare committed Nov 5, 2023
1 parent 50c028d commit ef6b472
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 52 deletions.
80 changes: 47 additions & 33 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@
from rich.console import Console

from .commands import CMDFIX, action_descriptions, execute_cmd
from .constants import HISTORY_FILE, LOGSDIR, PROMPT_USER
from .constants import (
HISTORY_FILE,
LOGSDIR,
MULTIPROMPT_SEPARATOR,
PROMPT_USER,
)
from .llm import init_llm, reply
from .logmanager import LogManager, _conversations
from .message import Message
Expand Down Expand Up @@ -190,42 +195,13 @@ def main(
log.print()
print("--- ^^^ past messages ^^^ ---")

def parse_prompt(prompt: str) -> str:
try:
# check if prompt is a path, if so, replace it with the contents of that file
f = Path(prompt).expanduser()
if f.exists() and f.is_file():
return f"```{prompt}\n{Path(prompt).expanduser().read_text()}\n```"
except OSError as oserr:
# some prompts are too long to be a path, so we can't read them
if oserr.errno != errno.ENAMETOOLONG:
pass
except UnicodeDecodeError:
# some files are not text files (images, audio, PDFs, binaries, etc), so we can't read them
# TODO: but can we handle them better than just printing the path? maybe with metadata from `file`?
pass

words = prompt.split()
if len(words) > 1:
# check if substring of prompt is a path, if so, append the contents of that file
paths = []
for word in words:
f = Path(word).expanduser()
if f.exists() and f.is_file():
paths.append(word)
if paths:
prompt += "\n\n"
for path in paths:
prompt += parse_prompt(path)

return prompt

# check if any prompt is a full path, if so, replace it with the contents of that file
# TODO: add support for directories
# TODO: maybe do this for all prompts, not just those passed on cli
prompts = [parse_prompt(p) for p in prompts]
prompts = [_parse_prompt(p) for p in prompts]
# join prompts, grouped by `-` if present, since that's the separator for multiple-round prompts
prompts = [p.strip() for p in "\n\n".join(prompts).split("\n\n-") if p]
sep = "\n\n" + MULTIPROMPT_SEPARATOR
prompts = [p.strip() for p in "\n\n".join(prompts).split(sep) if p]

# main loop
while True:
Expand Down Expand Up @@ -454,5 +430,43 @@ def _read_stdin() -> str:
return all_data


def _parse_prompt(prompt: str) -> str:
# if prompt is a command, exit early (as commands might take paths as arguments)
if any(
prompt.startswith(command)
for command in [f"{CMDFIX}{cmd}" for cmd in action_descriptions.keys()]
):
return prompt

try:
# check if prompt is a path, if so, replace it with the contents of that file
f = Path(prompt).expanduser()
if f.exists() and f.is_file():
return f"```{prompt}\n{Path(prompt).expanduser().read_text()}\n```"
except OSError as oserr:
# some prompts are too long to be a path, so we can't read them
if oserr.errno != errno.ENAMETOOLONG:
pass
except UnicodeDecodeError:
# some files are not text files (images, audio, PDFs, binaries, etc), so we can't read them
# TODO: but can we handle them better than just printing the path? maybe with metadata from `file`?
pass

words = prompt.split()
if len(words) > 1:
# check if substring of prompt is a path, if so, append the contents of that file
paths = []
for word in words:
f = Path(word).expanduser()
if f.exists() and f.is_file():
paths.append(word)
if paths:
prompt += "\n\n"
for path in paths:
prompt += _parse_prompt(path)

return prompt


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def handle_cmd(
# undo the '/help' command itself
log.undo(1, quiet=True)
log.write()
help()


def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover
Expand All @@ -166,7 +167,7 @@ def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover

def save(log: LogManager, filename: str):
# save the most recent code block to a file
code = log.get_last_code_block()
code = log.get_last_code_block(content=True)
if not code:
print("No code block found")
return
Expand Down
9 changes: 8 additions & 1 deletion gptme/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from pathlib import Path

CMDFIX = "/" # prefix for commands, e.g. /help
# prefix for commands, e.g. /help
CMDFIX = "/"

# separator for multiple rounds of prompts on the command line
# demarcates the end of the user's prompt, and start of the assistant's response
# e.g. /gptme "generate a poem" "-" "save it to poem.txt"
# where the assistant will generate a poem, and then save it to poem.txt
MULTIPROMPT_SEPARATOR = "-"

# Prompts
ROLE_COLOR = {
Expand Down
54 changes: 37 additions & 17 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import gptme.cli
import pytest
from click.testing import CliRunner

CMDFIX = gptme.cli.CMDFIX
from gptme.constants import CMDFIX, MULTIPROMPT_SEPARATOR


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -41,7 +40,7 @@ def test_help(runner: CliRunner):

def test_command_exit(args: list[str], runner: CliRunner):
# tests the /exit command
args.append(f"{CMDFIX}help")
args.append(f"{CMDFIX}exit")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)

Expand All @@ -61,6 +60,7 @@ def test_command_help(args: list[str], runner: CliRunner):
assert result.exit_code == 0


@pytest.mark.slow
def test_command_summarize(args: list[str], runner: CliRunner):
# tests the /summarize command
args.append(f"{CMDFIX}summarize")
Expand All @@ -69,29 +69,49 @@ def test_command_summarize(args: list[str], runner: CliRunner):
assert result.exit_code == 0


def test_shell(name: str, runner: CliRunner):
result = runner.invoke(
gptme.cli.main, ["--name", name, f'{CMDFIX}shell echo "yes"']
)
def test_command_save(args: list[str], runner: CliRunner):
# tests the /save command
args.append(f"{CMDFIX}impersonate ```python\nprint('hello')\n```")
args.append(MULTIPROMPT_SEPARATOR)
args.append(f"{CMDFIX}save output.txt")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)
assert result.exit_code == 0

# read the file
with open("output.txt", "r") as f:
content = f.read()
assert content == "hello"


def test_command_fork(args: list[str], runner: CliRunner, name: str):
# tests the /fork command
name += "-fork"
args.append(f"{CMDFIX}fork {name}")
print(f"running: gptme {' '.join(args)}")
result = runner.invoke(gptme.cli.main, args)
assert result.exit_code == 0


def test_shell(args: list[str], runner: CliRunner):
args.append(f"{CMDFIX}shell echo 'yes'")
result = runner.invoke(gptme.cli.main, args)
output = result.output.split("System")[-1]
# check for two 'yes' in output (both command and stdout)
assert output.count("yes") == 2, result.output
assert result.exit_code == 0


def test_python(name: str, runner: CliRunner):
result = runner.invoke(
gptme.cli.main, ["--name", name, f'{CMDFIX}python print("yes")']
)
def test_python(args: list[str], runner: CliRunner):
args.append(f"{CMDFIX}python print('yes')")
result = runner.invoke(gptme.cli.main, args)
assert "yes\n" in result.output
assert result.exit_code == 0


def test_python_error(name: str, runner: CliRunner):
result = runner.invoke(
gptme.cli.main,
["--name", name, f'{CMDFIX}python raise Exception("yes")'],
)
def test_python_error(args: list[str], runner: CliRunner):
args.append(f"{CMDFIX}python raise Exception('yes')")
result = runner.invoke(gptme.cli.main, args)
assert "Exception: yes" in result.output
assert result.exit_code == 0

Expand Down Expand Up @@ -133,7 +153,7 @@ def test_block(args: list[str], lang: str, runner: CliRunner):

# TODO: these could be fast if we had a cache
@pytest.mark.slow
def test_generate_primes(name: str, runner: CliRunner):
def test_generate_primes(args: list[str], runner: CliRunner):
args.append("print the first 10 prime numbers")
result = runner.invoke(gptme.cli.main, args)
# check that the 9th and 10th prime is present
Expand Down

0 comments on commit ef6b472

Please sign in to comment.