Skip to content

Commit

Permalink
test: make test/run.py DRYer, make dev the default suite, and optimiz…
Browse files Browse the repository at this point in the history
…e a bit

While we do not need to really optimize test suites, test_web_api.py load time
was annoying me while testing run.py, so I tucked its expensive imports away.
  • Loading branch information
joanise committed Jan 31, 2024
1 parent 9132b26 commit a78ccc5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 43 deletions.
26 changes: 15 additions & 11 deletions test/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,23 @@ def describe_suite(suite: TestSuite):
)


SUITES = ["all", "dev", "e2e", "prod", "api", "other"]


def run_tests(suite: str, describe: bool = False) -> bool:
"""Run the specified test suite.
Args:
suite: one of "all", "dev", etc specifying which suite to run
suite: one of SUITES, "dev" if the empty string
describe: if True, list all the test cases instead of running them.
Returns: True iff success
"""

if not suite:
LOGGER.info("No test suite specified, defaulting to dev.")
suite = "dev"

if suite == "e2e":
test_suite = TestSuite(e2e_tests)
elif suite == "api":
Expand All @@ -116,8 +123,7 @@ def run_tests(suite: str, describe: bool = False) -> bool:
test_suite = TestSuite(other_tests)
else:
LOGGER.error(
"Sorry, you need to select a Test Suite to run, one of: "
"dev, all (or prod), e2e, other"
"Sorry, you need to select a Test Suite to run, one of: " + " ".join(SUITES)
)
return False

Expand All @@ -126,19 +132,17 @@ def run_tests(suite: str, describe: bool = False) -> bool:
return True
else:
runner = TextTestRunner(verbosity=3)
return runner.run(test_suite).wasSuccessful()
success = runner.run(test_suite).wasSuccessful()
if not success:
LOGGER.error("Some tests failed. Please see log above.")
return success


if __name__ == "__main__":
describe = "--describe" in sys.argv
if describe:
sys.argv.remove("--describe")

try:
result = run_tests(sys.argv[1], describe)
if not result:
LOGGER.error("Some tests failed. Please see log above.")
sys.exit(1)
except IndexError:
LOGGER.error("Please specify a test suite to run: i.e. 'dev' or 'all'")
result = run_tests("" if len(sys.argv) <= 1 else sys.argv[1], describe)
if not result:
sys.exit(1)
86 changes: 54 additions & 32 deletions test/test_web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@
from unittest import main

from basic_test_case import BasicTestCase
from fastapi.testclient import TestClient
from lxml import etree

from readalongs.log import LOGGER
from readalongs.text.add_ids_to_xml import add_ids
from readalongs.text.convert_xml import convert_xml
from readalongs.text.tokenize_xml import tokenize_xml
from readalongs.util import get_langs
from readalongs.web_api import OutputFormat, create_grammar, web_api_app

API_CLIENT = TestClient(web_api_app)


class TestWebApi(BasicTestCase):
_API_CLIENT = None

@property
def API_CLIENT(self):
from fastapi.testclient import TestClient

from readalongs.web_api import web_api_app

if TestWebApi._API_CLIENT is None:
TestWebApi._API_CLIENT = TestClient(web_api_app)
return TestWebApi._API_CLIENT

def slurp_data_file(self, filename: str) -> str:
"""Convenience function to slurp a whole file in self.data_dir"""
with open(os.path.join(self.data_dir, filename), encoding="utf8") as f:
Expand All @@ -32,17 +40,17 @@ def test_assemble_from_plain_text(self):
"type": "text/plain",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 200)

def test_bad_path(self):
# Test a request to a path that doesn't exist
response = API_CLIENT.get("/pathdoesntexist")
response = self.API_CLIENT.get("/pathdoesntexist")
self.assertEqual(response.status_code, 404)

def test_bad_method(self):
# Test a request to a valid path with a bad method
response = API_CLIENT.get("/api/v1/assemble")
response = self.API_CLIENT.get("/api/v1/assemble")
self.assertEqual(response.status_code, 405)

def test_assemble_from_xml(self):
Expand All @@ -53,7 +61,7 @@ def test_assemble_from_xml(self):
"type": "application/readalong+xml",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 200)

def test_illformed_xml(self):
Expand All @@ -63,7 +71,7 @@ def test_illformed_xml(self):
"type": "application/readalong+xml",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 422)

def test_invalid_ras(self):
Expand All @@ -73,7 +81,7 @@ def test_invalid_ras(self):
"type": "application/readalong+xml",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 422)

def test_create_grammar(self):
Expand All @@ -84,6 +92,8 @@ def test_create_grammar(self):
tokenized = tokenize_xml(parsed)
ids_added = add_ids(tokenized)
g2ped, valid = convert_xml(ids_added)
from readalongs.web_api import create_grammar

word_dict, text = create_grammar(g2ped)
self.assertTrue(valid)
self.assertEqual(len(word_dict), len(text.split()))
Expand All @@ -96,7 +106,7 @@ def test_bad_g2p(self):
"type": "text/plain",
"text_languages": ["test"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertIn("No language called", response.json()["detail"])
self.assertEqual(response.status_code, 422)

Expand All @@ -107,7 +117,7 @@ def test_g2p_faiture(self):
"type": "text/plain",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 422)
content = response.json()
self.assertIn("No valid g2p conversion", content["detail"])
Expand All @@ -119,7 +129,7 @@ def test_no_words(self):
"type": "text/plain",
"text_languages": ["eng"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 422)
content = response.json()
self.assertIn("Could not find any words", content["detail"])
Expand All @@ -132,15 +142,15 @@ def test_empty_g2p(self):
"type": "text/plain",
"text_languages": ["eng", "und"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
self.assertEqual(response.status_code, 422)
content_log = response.json()["detail"]
for message_part in ["The output of the g2p process", "24", "23", "is empty"]:
self.assertIn(message_part, content_log)

def test_langs(self):
# Test the langs endpoint
response = API_CLIENT.get("/api/v1/langs")
response = self.API_CLIENT.get("/api/v1/langs")
codes = [x["code"] for x in response.json()]
self.assertEqual(set(codes), set(get_langs()[0]))
self.assertEqual(codes, list(sorted(codes)))
Expand All @@ -156,7 +166,7 @@ def test_logs(self):
"debug": True,
"text_languages": ["fra", "und"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
content = response.json()
# print("Content", content)
self.assertIn('Could not g2p "ña" as French', content["log"])
Expand All @@ -169,7 +179,7 @@ def test_debug(self):
"debug": True,
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
content = response.json()
self.assertEqual(content["input"], request)
self.assertGreater(len(content["tokenized"]), 10)
Expand All @@ -182,7 +192,7 @@ def test_debug(self):
"type": "text/plain",
"text_languages": ["fra"],
}
response = API_CLIENT.post("/api/v1/assemble", json=request)
response = self.API_CLIENT.post("/api/v1/assemble", json=request)
content = response.json()
self.assertIsNone(content["input"])
self.assertIsNone(content["tokenized"])
Expand Down Expand Up @@ -214,7 +224,9 @@ def test_convert_to_TextGrid(self):
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/textgrid", json=request)
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/textgrid", json=request
)
self.assertEqual(response.status_code, 200)
self.assertIn("aligned.TextGrid", response.headers["content-disposition"])
self.assertEqual(
Expand Down Expand Up @@ -276,7 +288,9 @@ class = "IntervalTier"
request = {
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/textgrid", json=request)
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/textgrid", json=request
)
self.assertEqual(response.status_code, 200)
self.assertIn("aligned.TextGrid", response.headers["content-disposition"])
self.assertNotIn("xmax = 83.100000", response.text)
Expand All @@ -286,7 +300,7 @@ def test_convert_to_eaf(self):
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/eaf", json=request)
response = self.API_CLIENT.post("/api/v1/convert_alignment/eaf", json=request)
self.assertEqual(response.status_code, 200)
self.assertIn("<ANNOTATION_DOCUMENT", response.text)
self.assertIn("aligned.eaf", response.headers["content-disposition"])
Expand All @@ -296,7 +310,7 @@ def test_convert_to_srt(self):
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/srt", json=request)
response = self.API_CLIENT.post("/api/v1/convert_alignment/srt", json=request)
self.assertEqual(response.status_code, 200)
self.assertIn("aligned_sentences.srt", response.headers["content-disposition"])
self.assertEqual(
Expand All @@ -311,7 +325,7 @@ def test_convert_to_srt(self):
),
)

response = API_CLIENT.post(
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/srt?tier=word", json=request
)
self.assertEqual(response.status_code, 200)
Expand All @@ -338,7 +352,7 @@ def test_convert_to_vtt(self):
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post(
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/vtt?tier=sentence", json=request
)
self.assertEqual(response.status_code, 200)
Expand All @@ -355,7 +369,7 @@ def test_convert_to_vtt(self):
),
)

response = API_CLIENT.post(
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/vtt?tier=word", json=request
)
self.assertEqual(response.status_code, 200)
Expand All @@ -380,14 +394,18 @@ def test_convert_to_TextGrid_errors(self):
"dur": 83.1,
"ras": "this is not XML",
}
response = API_CLIENT.post("/api/v1/convert_alignment/textgrid", json=request)
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/textgrid", json=request
)
self.assertEqual(response.status_code, 422, "Invalid XML should fail.")

request = {
"dur": -10.0,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/textgrid", json=request)
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/textgrid", json=request
)
self.assertEqual(response.status_code, 422, "Negative duration should fail.")

def test_cleanup_temp_dir(self):
Expand All @@ -397,7 +415,7 @@ def test_cleanup_temp_dir(self):
"ras": self.hej_verden_xml,
}
with self.assertLogs(LOGGER, "INFO") as log_cm:
response = API_CLIENT.post(
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/textgrid", json=request
)
self.assertEqual(response.status_code, 200)
Expand Down Expand Up @@ -437,9 +455,11 @@ def test_cleanup_even_if_error(self):
"dur": 83.1,
"ras": overlap_xml,
}
from readalongs.web_api import OutputFormat

for format_name in OutputFormat:
with self.assertLogs(LOGGER, "INFO") as log_cm:
response = API_CLIENT.post(
response = self.API_CLIENT.post(
f"/api/v1/convert_alignment/{format_name.value}", json=request
)
self.assertEqual(response.status_code, 422)
Expand All @@ -455,17 +475,19 @@ def test_convert_to_bad_format(self):
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment/badformat", json=request)
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/badformat", json=request
)
self.assertEqual(response.status_code, 422)

request = {
"dur": 83.1,
"ras": self.hej_verden_xml,
}
response = API_CLIENT.post("/api/v1/convert_alignment", json=request)
response = self.API_CLIENT.post("/api/v1/convert_alignment", json=request)
self.assertEqual(response.status_code, 404)

response = API_CLIENT.post(
response = self.API_CLIENT.post(
"/api/v1/convert_alignment/vtt?tier=badtier", json=request
)
self.assertEqual(response.status_code, 422)
Expand Down

0 comments on commit a78ccc5

Please sign in to comment.