Skip to content

Commit

Permalink
Merge pull request #330 from smart-on-fhir/mikix/more-tests
Browse files Browse the repository at this point in the history
tests: add more coverage
  • Loading branch information
mikix authored Jul 11, 2024
2 parents d9ecb3e + a1c96b9 commit d398ca3
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/data/hftest/codebook.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"version": 1, "id_salt": "4688a4853dafc6a3d6934f0dd02205be0700d2ca64b636127a4436494dcaf88e"}
2 changes: 2 additions & 0 deletions tests/data/hftest/input/DocumentReference.ndjson
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"id":"43","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDE="}}],"context":{"encounter":[{"reference":"Encounter\/23"}],"period":{"end":"2021-06-24","start":"2021-06-23"}},"status":"current","subject":{"reference":"Patient\/334567"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"}
{"id":"44","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDI="}}],"context":{"encounter":[{"reference":"Encounter\/25"}],"period":{"end":"2021-06-25","start":"2021-06-24"}},"status":"current","subject":{"reference":"Patient\/323456"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"table_name": "hftest__summary", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "docref_id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."}
{"id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "docref_id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."}
68 changes: 68 additions & 0 deletions tests/deid/test_deid_mstool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for the mstool module"""

import asyncio
import filecmp
import os
import shutil
import tempfile
from unittest import mock

import pytest

Expand Down Expand Up @@ -74,3 +76,69 @@ async def test_bad_fhir(self):
common.write_json(os.path.join(input_dir, "Condition.ndjson"), {})
with self.assertRaises(SystemExit):
await run_mstool(input_dir, output_dir)


# Separate class here from the above, because this doesn't need the MS tool installed
class TestMicrosoftToolWrapper(AsyncTestCase):
    """Exercises the MS tool wrapper code against a mocked subprocess"""

    def setUp(self):
        super().setUp()

        # A returncode of None marks the fake process as still running
        self.process = mock.MagicMock(returncode=None)
        self.patch("asyncio.create_subprocess_exec").return_value = self.process

    async def test_progress(self):
        """Confirms that progress is polled and reported while the tool runs"""
        progress_bar = mock.MagicMock()
        bar_context = mock.MagicMock()
        bar_context.__enter__.return_value = progress_bar
        self.patch("cumulus_etl.cli_utils.make_progress_bar", return_value=bar_context)

        # Three checkpoints are staged:
        # - a couple bytes written
        # - first file in place, a couple bytes of second
        # - both files in place, finished
        wait_outcomes = [asyncio.TimeoutError, asyncio.TimeoutError, ("Out", "Err")]
        self.patch("asyncio.wait_for", side_effect=wait_outcomes)

        known_sizes = {
            "first.ndjson": 10,
            "second.ndjson": 10,
            "tmp1.ndjson": 3,
            "tmp2.ndjson": 3,
        }

        def fake_getsize(path: str) -> int:
            if path == "ghost.ndjson":
                # Test that we gracefully handle files deleting underneath us
                raise FileNotFoundError
            if path == "tmp2.ndjson":
                self.process.returncode = 0  # mark the process as done
            return known_sizes.get(path)

        glob_results = [
            ["first.ndjson", "second.ndjson"],
            ["tmp1.ndjson", "ghost.ndjson"],
            ["first.ndjson", "tmp2.ndjson"],
        ]
        self.patch("glob.glob", side_effect=glob_results)
        self.patch("os.path.getsize", side_effect=fake_getsize)

        await run_mstool("/in", "/out")

        self.assertEqual(progress_bar.update.call_count, 3)
        reported = [call.kwargs for call in progress_bar.update.call_args_list]
        self.assertEqual(
            reported,
            [
                {"completed": 3 / 20},
                {"completed": 13 / 20},
                {"completed": 1},
            ],
        )
152 changes: 152 additions & 0 deletions tests/hftest/test_hftask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Tests for etl/studies/hftest/"""

import os

import respx

from cumulus_etl import common, errors
from cumulus_etl.etl.studies import hftest

from tests import i2b2_mock_data
from tests.etl import BaseEtlSimple, TaskTestCase


def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://localhost:8086/") -> respx.Route:
    """Registers a POST route at `url` that matches the summary prompt for `text` and replies with a canned summary."""
    system_prompt = "You will be given a clinical note, and you should reply with a short summary of that note."
    full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n{text} [/INST]"

    # The route only matches requests whose JSON body equals this payload
    expected_request = {
        "inputs": full_prompt,
        "options": {"wait_for_model": True},
        "parameters": {"max_new_tokens": 1000},
    }
    # The canned reply echoes the prompt back, followed by the summary
    canned_reply = [{"generated_text": f"{full_prompt} Patient has a fever."}]

    route = respx_mock.post(url, json=expected_request)
    return route.respond(json=canned_reply)


def mock_info(
    respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict | None = None
) -> respx.Route:
    """
    Registers a GET route at `url` that replies with server info.

    :param respx_mock: the respx router to register the route on
    :param url: the info endpoint to mock
    :param override: optional keys to replace (or add) in the default info response
    :returns: the registered route
    """
    response = {
        "model_id": "meta-llama/Llama-2-13b-chat-hf",
        "model_sha": "0ba94ac9b9e1d5a0037780667e8b219adde1908c",
        "sha": "09eca6422788b1710c54ee0d05dd6746f16bb681",
    }
    if override:
        response.update(override)
    return respx_mock.get(url).respond(json=response)


class TestHuggingFaceTestTask(TaskTestCase):
    """Test case for HuggingFaceTestTask"""

    @respx.mock(assert_all_called=True)
    async def test_happy_path(self, respx_mock):
        """Verify we summarize a basic note properly"""
        docref0 = i2b2_mock_data.documentreference()
        self.make_json("DocumentReference", "0", **docref0)
        mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)

        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()

        self.assertEqual(1, self.format.write_records.call_count)
        batch = self.format.write_records.call_args[0][0]
        self.assertEqual(1, len(batch.rows))
        expected_id = self.codebook.db.resource_hash("0")
        self.assertEqual(
            {
                "id": expected_id,
                "docref_id": expected_id,
                "summary": "Patient has a fever.",
                "generated_on": "2021-09-14T21:23:45+00:00",
                "task_version": hftest.HuggingFaceTestTask.task_version,
            },
            batch.rows[0],
        )

    @respx.mock(assert_all_called=True)
    async def test_env_url_override(self, respx_mock):
        """Verify we can override the hugging face default URL."""
        docref0 = i2b2_mock_data.documentreference()
        self.make_json("DocumentReference", "0", **docref0)

        # Only the overridden URL is mocked, and assert_all_called requires that it gets hit
        self.patch_dict(os.environ, {"CUMULUS_HUGGING_FACE_URL": "https://blarg/"})
        mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT, url="https://blarg/")

        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
        self.assertEqual(1, self.format.write_records.call_count)

    @respx.mock(assert_all_called=True)
    async def test_caching(self, respx_mock):
        """Verify we cache results"""
        docref0 = i2b2_mock_data.documentreference()
        self.make_json("DocumentReference", "0", **docref0)
        route = mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)

        # First run: nothing is cached yet, so the endpoint gets called and the reply is stored
        self.assertFalse(os.path.exists(f"{self.phi_dir}/ctakes-cache"))
        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()

        self.assertEqual(1, route.call_count)
        # No trailing slash on the folder, so the joined file path has no doubled separator
        cache_dir = f"{self.phi_dir}/ctakes-cache/hftest__summary_v0/06ee"
        cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.json"
        self.assertEqual("Patient has a fever.", common.read_text(cache_file))

        # Second run: served from the cache, so the call count does not move
        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
        self.assertEqual(1, route.call_count)

        # Confirm that if we remove the cache file, we call the endpoint again
        os.remove(cache_file)
        await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
        self.assertEqual(2, route.call_count)

    @respx.mock(assert_all_called=True)
    async def test_init_check_unreachable(self, respx_mock):
        """Verify we bail if the server isn't reachable"""
        respx_mock.get("http://localhost:8086/info").respond(status_code=500)
        with self.assertRaises(SystemExit) as cm:
            await hftest.HuggingFaceTestTask.init_check()
        self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)

    @respx.mock(assert_all_called=True)
    async def test_init_check_config(self, respx_mock):
        """Verify we check the server properties"""
        # Happy path
        mock_info(respx_mock)
        await hftest.HuggingFaceTestTask.init_check()

        # Each override changes exactly one field of the otherwise-valid info response
        bad_overrides = [
            {"model_id": "bogus/Llama-2-13b-chat-hf"},  # bad model ID
            {"model_sha": "bogus"},  # bad model SHA
            {"sha": "bogus"},  # bad SHA
        ]
        for override in bad_overrides:
            with self.subTest(override=override):
                mock_info(respx_mock, override=override)
                with self.assertRaises(SystemExit) as cm:
                    await hftest.HuggingFaceTestTask.init_check()
                self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)


class TestHuggingFaceETL(BaseEtlSimple):
    """End-to-end ETL coverage for the hftest tasks."""

    DATA_ROOT = "hftest"

    @respx.mock(assert_all_called=True)
    async def test_basic_etl(self, respx_mock):
        """Runs the hftest__summary task and compares the result via assert_output_equal()"""
        # One mocked prompt per note; assert_all_called requires both to be requested
        for note_text in ("Test note 1", "Test note 2"):
            mock_prompt(respx_mock, text=note_text)

        await self.run_etl(tasks=["hftest__summary"])
        self.assert_output_equal()
4 changes: 3 additions & 1 deletion tests/i2b2_mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from cumulus_etl.loaders.i2b2 import transform

DOCREF_TEXT = "Chief complaint: fever and chills. Denies cough."


def patient_dim() -> transform.PatientDimension:
return transform.PatientDimension(
Expand Down Expand Up @@ -63,7 +65,7 @@ def documentreference_dim() -> transform.ObservationFact:
"ENCOUNTER_NUM": 67890,
"CONCEPT_CD": "NOTE:149798455", # emergency room type
"START_DATE": "2016-01-01",
"OBSERVATION_BLOB": "Chief complaint: fever and chills. Denies cough.",
"OBSERVATION_BLOB": DOCREF_TEXT,
"TVAL_CHAR": "Emergency note",
}
)
Expand Down
66 changes: 66 additions & 0 deletions tests/nlp/test_watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Tests for nlp/watcher.py"""

import os
import tempfile
from unittest import mock

import ddt
import respx

from cumulus_etl import common, errors, nlp

from tests.ctakesmock import CtakesMixin
from tests.utils import AsyncTestCase


class TestNLPWatcher(AsyncTestCase):
    """Generic test case for the service watching code"""

    def _assert_exits_with(self, check, expected_code):
        """Calls `check` and asserts it raises SystemExit carrying `expected_code`"""
        with self.assertRaises(SystemExit) as cm:
            check()
        self.assertEqual(expected_code, cm.exception.code)

    @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
    def test_ctakes_down(self):
        """Verify we report cTAKES being down correctly"""
        self._assert_exits_with(nlp.check_ctakes, errors.CTAKES_MISSING)

    @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
    def test_negation_cnlpt_down(self):
        """Verify we report negation being down correctly"""
        self._assert_exits_with(nlp.check_negation_cnlpt, errors.CNLPT_MISSING)

    @mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
    def test_term_exists_cnlpt_down(self):
        """Verify we report term exists being down correctly"""
        self._assert_exits_with(nlp.check_term_exists_cnlpt, errors.CNLPT_MISSING)

    def test_restart_ctakes_no_folder(self):
        """An empty overrides folder path makes the restart return a falsy value"""
        restarted = nlp.restart_ctakes_with_bsv("", "")
        self.assertFalse(restarted)

    def test_restart_ctakes_nonexistent_folder(self):
        """An overrides folder path that does not exist makes the restart return a falsy value"""
        with tempfile.TemporaryDirectory() as tmpdir:
            restarted = nlp.restart_ctakes_with_bsv(f"{tmpdir}/nope", "")
            self.assertFalse(restarted)

    def test_restart_ctakes_file_not_folder(self):
        """An overrides path that is a file (not a folder) makes the restart return a falsy value"""
        with tempfile.NamedTemporaryFile() as file:
            restarted = nlp.restart_ctakes_with_bsv(file.name, "")
            self.assertFalse(restarted)


class TestCTakesWatcher(CtakesMixin, AsyncTestCase):
    """Test case for cTAKES watching code that needs a real server"""

    @mock.patch("select.poll")
    @mock.patch("time.sleep", new=lambda x: None)  # don't sleep during restart
    def test_restart_timeout(self, mock_poll):
        """Verify we exit with CTAKES_RESTART_FAILED when polling never reports anything"""
        # Every poll() call on the object built by select.poll() reports nothing
        mock_poll.return_value.poll.return_value = False

        with tempfile.NamedTemporaryFile() as bsv_file:
            common.write_text(bsv_file.name, "C0028081|T184|night sweats|Sweats")
            with self.assertRaises(SystemExit) as cm:
                nlp.restart_ctakes_with_bsv(self.ctakes_overrides.name, bsv_file.name)

        self.assertEqual(errors.CTAKES_RESTART_FAILED, cm.exception.code)

0 comments on commit d398ca3

Please sign in to comment.