From 7ec04886774579757fba1a4ca6eb658fccbcb752 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 12 Apr 2024 18:40:10 +0000 Subject: [PATCH] add --skip and --only to test driver --- controllers/pyctrl/driver.py | 41 +++++++++++++++++++++++++++--- controllers/pyctrl/samples/test.py | 7 ++--- scripts/test-jsctrl.sh | 2 +- scripts/test-pyctrl.sh | 2 +- 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/controllers/pyctrl/driver.py b/controllers/pyctrl/driver.py index c43ca945..c105ebc7 100644 --- a/controllers/pyctrl/driver.py +++ b/controllers/pyctrl/driver.py @@ -1,6 +1,7 @@ import sys import os import re +import argparse import pyaici.rest import pyaici.util @@ -51,10 +52,38 @@ def main(): js_mode = False cmt = "#" - files = sys.argv[1:] - if not files: - print("need some python files as input") - return + parser = argparse.ArgumentParser( + description="Run pyctrl or jsctrl tests", + prog="ctrldriver", + ) + + parser.add_argument( + "--skip", + "-s", + type=str, + default=[], + action="append", + help="skip tests matching string", + ) + + parser.add_argument( + "--only", + "-k", + type=str, + default=[], + action="append", + help="only run tests matching string", + ) + + parser.add_argument( + "test_file", + nargs="+", + help="files to test", + ) + + args = parser.parse_args() + + files = args.test_file if files[0].endswith(".js"): js_mode = True @@ -77,6 +106,10 @@ def main(): else: tests = re.findall(r"^async def (test_\w+)\(.*", arg, flags=re.MULTILINE) for t in tests: + if any([s in t for s in args.skip]): + continue + if args.only and not any([s in t for s in args.only]): + continue if js_mode: arg_t = f"{arg}\ntest({t});\n" else: diff --git a/controllers/pyctrl/samples/test.py b/controllers/pyctrl/samples/test.py index e2744134..c508a0a3 100644 --- a/controllers/pyctrl/samples/test.py +++ b/controllers/pyctrl/samples/test.py @@ -49,10 +49,11 @@ async def test_hello(): prompt = await aici.GetPrompt() print("prompt", prompt) await aici.gen_tokens(regex=r"[A-Z].*", max_tokens=5) + await aici.FixedTokens("\n2 +") l = aici.Label() - await aici.FixedTokens("\n2 + 2 = ") + await aici.FixedTokens(" 2 = ") await aici.gen_tokens(regex=r"\d+", max_tokens=1) - await aici.FixedTokens("\n3 + 3 = ", following=l) + await aici.FixedTokens(" 3 = ", following=l) await aici.gen_tokens(regex=r"\d+", max_tokens=1) @@ -187,4 +188,4 @@ async def test_joke(): await aici.gen_text(max_tokens=15) -aici.test(test_hello()) +aici.test(test_drugs()) diff --git a/scripts/test-jsctrl.sh b/scripts/test-jsctrl.sh index 8fe2ab79..c0516a85 100755 --- a/scripts/test-jsctrl.sh +++ b/scripts/test-jsctrl.sh @@ -7,4 +7,4 @@ cd $HERE/../controllers/jsctrl tsc --version || npm install -g typescript tsc -p samples PYTHONPATH=$HERE/../py \ -python3 ../pyctrl/driver.py samples/dist/test.js +python3 ../pyctrl/driver.py samples/dist/test.js "$@" diff --git a/scripts/test-pyctrl.sh b/scripts/test-pyctrl.sh index 2ed4843d..dac92c3f 100755 --- a/scripts/test-pyctrl.sh +++ b/scripts/test-pyctrl.sh @@ -5,4 +5,4 @@ cd `dirname $0` HERE=`pwd` cd $HERE/../controllers/pyctrl PYTHONPATH=$HERE/../py \ -python3 driver.py samples/test*.py +python3 driver.py samples/test*.py "$@"