diff --git a/tests/data/hftest/codebook.json b/tests/data/hftest/codebook.json new file mode 100644 index 00000000..92b91a01 --- /dev/null +++ b/tests/data/hftest/codebook.json @@ -0,0 +1 @@ +{"version": 1, "id_salt": "4688a4853dafc6a3d6934f0dd02205be0700d2ca64b636127a4436494dcaf88e"} \ No newline at end of file diff --git a/tests/data/hftest/input/DocumentReference.ndjson b/tests/data/hftest/input/DocumentReference.ndjson new file mode 100644 index 00000000..6cb777d0 --- /dev/null +++ b/tests/data/hftest/input/DocumentReference.ndjson @@ -0,0 +1,2 @@ +{"id":"43","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDE="}}],"context":{"encounter":[{"reference":"Encounter\/23"}],"period":{"end":"2021-06-24","start":"2021-06-23"}},"status":"current","subject":{"reference":"Patient\/334567"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"} +{"id":"44","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDI="}}],"context":{"encounter":[{"reference":"Encounter\/25"}],"period":{"end":"2021-06-25","start":"2021-06-24"}},"status":"current","subject":{"reference":"Patient\/323456"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"} diff --git a/tests/data/hftest/output/etl__completion/etl__completion.000.ndjson b/tests/data/hftest/output/etl__completion/etl__completion.000.ndjson new file mode 100644 index 00000000..c7dedb12 --- /dev/null +++ b/tests/data/hftest/output/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "hftest__summary", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/hftest/output/hftest__summary/hftest__summary.000.ndjson b/tests/data/hftest/output/hftest__summary/hftest__summary.000.ndjson new file mode 100644 index 00000000..3312de2e --- /dev/null +++ b/tests/data/hftest/output/hftest__summary/hftest__summary.000.ndjson @@ -0,0 +1,2 @@ +{"id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "docref_id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."} +{"id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "docref_id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."} diff --git a/tests/deid/test_deid_mstool.py b/tests/deid/test_deid_mstool.py index 933ba750..217572fb 100644 --- a/tests/deid/test_deid_mstool.py +++ b/tests/deid/test_deid_mstool.py @@ -1,9 +1,11 @@ """Tests for the mstool module""" +import asyncio import filecmp import os import shutil import tempfile +from unittest import mock import pytest @@ -74,3 +76,69 @@ async def test_bad_fhir(self): common.write_json(os.path.join(input_dir, "Condition.ndjson"), {}) with self.assertRaises(SystemExit): await run_mstool(input_dir, output_dir) + + +# Separate class here from the above, because this doesn't need the MS tool installed +class TestMicrosoftToolWrapper(AsyncTestCase): + """Test case for the MS tool wrapper code""" + + def setUp(self): + super().setUp() + + self.process = mock.MagicMock() + self.process.returncode = None # process not yet finished + + mock_exec = self.patch("asyncio.create_subprocess_exec") + 
mock_exec.return_value = self.process
+
+    async def test_progress(self):
+        """Confirms that we poll for progress as we go"""
+        mock_progress = mock.MagicMock()
+        mock_wrapper = mock.MagicMock()
+        mock_wrapper.__enter__.return_value = mock_progress
+        self.patch("cumulus_etl.cli_utils.make_progress_bar", return_value=mock_wrapper)
+
+        # We are going to stage 3 different checkpoints:
+        # - a couple bytes written
+        # - first file in place, a couple bytes of second
+        # - both files in place, finished
+        self.patch(
+            "asyncio.wait_for",
+            side_effect=[
+                asyncio.TimeoutError,
+                asyncio.TimeoutError,
+                ("Out", "Err"),
+            ],
+        )
+
+        def fake_getsize(path: str) -> int:
+            match path:
+                case "first.ndjson":
+                    return 10
+                case "second.ndjson":
+                    return 10
+                case "tmp1.ndjson":
+                    return 3
+                case "tmp2.ndjson":
+                    self.process.returncode = 0  # mark the process as done
+                    return 3
+                case "ghost.ndjson":
+                    # Test that we gracefully handle files being deleted out from under us
+                    raise FileNotFoundError
+
+        self.patch(
+            "glob.glob",
+            side_effect=[
+                ["first.ndjson", "second.ndjson"],
+                ["tmp1.ndjson", "ghost.ndjson"],
+                ["first.ndjson", "tmp2.ndjson"],
+            ],
+        )
+        self.patch("os.path.getsize", side_effect=fake_getsize)
+
+        await run_mstool("/in", "/out")
+
+        self.assertEqual(mock_progress.update.call_count, 3)
+        self.assertEqual(mock_progress.update.call_args_list[0].kwargs, {"completed": 3 / 20})
+        self.assertEqual(mock_progress.update.call_args_list[1].kwargs, {"completed": 13 / 20})
+        self.assertEqual(mock_progress.update.call_args_list[2].kwargs, {"completed": 1})
diff --git a/tests/hftest/test_hftask.py b/tests/hftest/test_hftask.py
new file mode 100644
index 00000000..2240fe18
--- /dev/null
+++ b/tests/hftest/test_hftask.py
@@ -0,0 +1,152 @@
+"""Tests for etl/studies/hftest/"""
+
+import os
+
+import respx
+
+from cumulus_etl import common, errors
+from cumulus_etl.etl.studies import hftest
+
+from tests import i2b2_mock_data
+from tests.etl import BaseEtlSimple, TaskTestCase
+
+
+def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://localhost:8086/") -> respx.Route:
+    full_prompt = f"""[INST] <<SYS>>
+You will be given a clinical note, and you should reply with a short summary of that note.
+<</SYS>>
+
+{text} [/INST]"""
+    return respx_mock.post(
+        url,
+        json={
+            "inputs": full_prompt,
+            "options": {
+                "wait_for_model": True,
+            },
+            "parameters": {
+                "max_new_tokens": 1000,
+            },
+        },
+    ).respond(json=[{"generated_text": full_prompt + " Patient has a fever."}])
+
+
+def mock_info(
+    respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict | None = None
+) -> respx.Route:
+    response = {
+        "model_id": "meta-llama/Llama-2-13b-chat-hf",
+        "model_sha": "0ba94ac9b9e1d5a0037780667e8b219adde1908c",
+        "sha": "09eca6422788b1710c54ee0d05dd6746f16bb681",
+    }
+    response.update(override or {})
+    return respx_mock.get(url).respond(json=response)
+
+
+class TestHuggingFaceTestTask(TaskTestCase):
+    """Test case for HuggingFaceTestTask"""
+
+    @respx.mock(assert_all_called=True)
+    async def test_happy_path(self, respx_mock):
+        """Verify we summarize a basic note properly"""
+        docref0 = i2b2_mock_data.documentreference()
+        self.make_json("DocumentReference", "0", **docref0)
+        mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)
+
+        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
+
+        self.assertEqual(1, self.format.write_records.call_count)
+        batch = self.format.write_records.call_args[0][0]
+        self.assertEqual(1, len(batch.rows))
+        expected_id = self.codebook.db.resource_hash("0")
+        self.assertEqual(
+            {
+                "id": expected_id,
+                "docref_id": expected_id,
+                "summary": "Patient has a fever.",
+                "generated_on": "2021-09-14T21:23:45+00:00",
+                "task_version": hftest.HuggingFaceTestTask.task_version,
+            },
+            batch.rows[0],
+        )
+
+    @respx.mock(assert_all_called=True)
+    async def test_env_url_override(self, respx_mock):
+        """Verify we can override the hugging face default URL."""
+        docref0 = i2b2_mock_data.documentreference()
+        self.make_json("DocumentReference", "0", **docref0)
+
+        self.patch_dict(os.environ, {"CUMULUS_HUGGING_FACE_URL": "https://blarg/"})
+        mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT, url="https://blarg/")
+
+        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
+        self.assertEqual(1, self.format.write_records.call_count)
+
+    @respx.mock(assert_all_called=True)
+    async def test_caching(self, respx_mock):
+        """Verify we cache results"""
+        docref0 = i2b2_mock_data.documentreference()
+        self.make_json("DocumentReference", "0", **docref0)
+        route = mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)
+
+        self.assertFalse(os.path.exists(f"{self.phi_dir}/ctakes-cache"))
+        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
+
+        self.assertEqual(1, route.call_count)
+        cache_dir = f"{self.phi_dir}/ctakes-cache/hftest__summary_v0/06ee/"
+        cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.json"
+        self.assertEqual("Patient has a fever.", common.read_text(cache_file))

+        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
+        self.assertEqual(1, route.call_count)
+
+        # Confirm that if we remove the cache file, we call the endpoint again
+        os.remove(cache_file)
+        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
+        self.assertEqual(2, route.call_count)
+
+    @respx.mock(assert_all_called=True)
+    async def test_init_check_unreachable(self, respx_mock):
+        """Verify we bail if the server isn't reachable"""
+        respx_mock.get("http://localhost:8086/info").respond(status_code=500)
+        with self.assertRaises(SystemExit) as cm:
+            await hftest.HuggingFaceTestTask.init_check()
+        self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)
+
+    
@respx.mock(assert_all_called=True) + async def test_init_check_config(self, respx_mock): + """Verify we check the server properties""" + # Happy path + mock_info(respx_mock) + await hftest.HuggingFaceTestTask.init_check() + + # Bad model ID + mock_info(respx_mock, override={"model_id": "bogus/Llama-2-13b-chat-hf"}) + with self.assertRaises(SystemExit) as cm: + await hftest.HuggingFaceTestTask.init_check() + self.assertEqual(errors.SERVICE_MISSING, cm.exception.code) + + # Bad model SHA + mock_info(respx_mock, override={"model_sha": "bogus"}) + with self.assertRaises(SystemExit) as cm: + await hftest.HuggingFaceTestTask.init_check() + self.assertEqual(errors.SERVICE_MISSING, cm.exception.code) + + # Bad SHA + mock_info(respx_mock, override={"sha": "bogus"}) + with self.assertRaises(SystemExit) as cm: + await hftest.HuggingFaceTestTask.init_check() + self.assertEqual(errors.SERVICE_MISSING, cm.exception.code) + + +class TestHuggingFaceETL(BaseEtlSimple): + """Tests the end-to-end ETL of the hftest tasks.""" + + DATA_ROOT = "hftest" + + @respx.mock(assert_all_called=True) + async def test_basic_etl(self, respx_mock): + mock_prompt(respx_mock, text="Test note 1") + mock_prompt(respx_mock, text="Test note 2") + await self.run_etl(tasks=["hftest__summary"]) + self.assert_output_equal() diff --git a/tests/i2b2_mock_data.py b/tests/i2b2_mock_data.py index ff65eb25..0d38c1d4 100644 --- a/tests/i2b2_mock_data.py +++ b/tests/i2b2_mock_data.py @@ -2,6 +2,8 @@ from cumulus_etl.loaders.i2b2 import transform +DOCREF_TEXT = "Chief complaint: fever and chills. Denies cough." + def patient_dim() -> transform.PatientDimension: return transform.PatientDimension( @@ -63,7 +65,7 @@ def documentreference_dim() -> transform.ObservationFact: "ENCOUNTER_NUM": 67890, "CONCEPT_CD": "NOTE:149798455", # emergency room type "START_DATE": "2016-01-01", - "OBSERVATION_BLOB": "Chief complaint: fever and chills. 
Denies cough.", + "OBSERVATION_BLOB": DOCREF_TEXT, "TVAL_CHAR": "Emergency note", } ) diff --git a/tests/nlp/test_watcher.py b/tests/nlp/test_watcher.py new file mode 100644 index 00000000..2b2da7f2 --- /dev/null +++ b/tests/nlp/test_watcher.py @@ -0,0 +1,66 @@ +"""Tests for nlp/watcher.py""" + +import os +import tempfile +from unittest import mock + +import ddt +import respx + +from cumulus_etl import common, errors, nlp + +from tests.ctakesmock import CtakesMixin +from tests.utils import AsyncTestCase + + +class TestNLPWatcher(AsyncTestCase): + """Generic test case for service watching code""" + + @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False) + def test_ctakes_down(self): + """Verify we report cTAKES being down correctly""" + with self.assertRaises(SystemExit) as cm: + nlp.check_ctakes() + self.assertEqual(errors.CTAKES_MISSING, cm.exception.code) + + @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False) + def test_negation_cnlpt_down(self): + """Verify we report negation being down correctly""" + with self.assertRaises(SystemExit) as cm: + nlp.check_negation_cnlpt() + self.assertEqual(errors.CNLPT_MISSING, cm.exception.code) + + @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False) + def test_term_exists_cnlpt_down(self): + """Verify we report term exists being down correctly""" + with self.assertRaises(SystemExit) as cm: + nlp.check_term_exists_cnlpt() + self.assertEqual(errors.CNLPT_MISSING, cm.exception.code) + + def test_restart_ctakes_no_folder(self): + self.assertFalse(nlp.restart_ctakes_with_bsv("", "")) + + def test_restart_ctakes_nonexistent_folder(self): + with tempfile.TemporaryDirectory() as tmpdir: + self.assertFalse(nlp.restart_ctakes_with_bsv(f"{tmpdir}/nope", "")) + + def test_restart_ctakes_file_not_folder(self): + with tempfile.NamedTemporaryFile() as file: + self.assertFalse(nlp.restart_ctakes_with_bsv(file.name, "")) + + +class TestCTakesWatcher(CtakesMixin, AsyncTestCase): + """Test case for cTAKES watching code that needs a real server""" + + @mock.patch("select.poll") + @mock.patch("time.sleep", new=lambda x: None) # don't sleep during restart + def test_restart_timeout(self, mock_poll): + mock_poller = mock.MagicMock() + mock_poller.poll.return_value = False + mock_poll.return_value = mock_poller + + with tempfile.NamedTemporaryFile() as file: + common.write_text(file.name, "C0028081|T184|night sweats|Sweats") + with self.assertRaises(SystemExit) as cm: + nlp.restart_ctakes_with_bsv(self.ctakes_overrides.name, file.name) + self.assertEqual(errors.CTAKES_RESTART_FAILED, cm.exception.code)