diff --git a/pyproject.toml b/pyproject.toml index e71f02f7..77ddecc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,9 +170,12 @@ ignore = [ # ANN2 - missing-return-type # ANN102 - missing-type-cls # S101 - assert +# B011 - assert-false +# D104 - undocumented-public-package +# D100 - undocumented-public-module # INP001 - implicit-namespace-package # ARG001 - unused-function-argument # SLF001 - private-member-acces # N815 - mixed-case-variable-in-class-scope -"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "INP001", "SLF001", "ARG001"] +"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "B011", "D100", "D104", "INP001", "SLF001", "ARG001"] "src/metakb/schemas/*" = ["ANN102", "N815"] diff --git a/src/metakb/main.py b/src/metakb/main.py index 4fe43db0..964ae923 100644 --- a/src/metakb/main.py +++ b/src/metakb/main.py @@ -9,7 +9,14 @@ from metakb import __version__ from metakb.log_handle import configure_logs -from metakb.query import QueryHandler +from metakb.query import PaginationParamError, QueryHandler +from metakb.schemas.api import ( + BatchSearchStudiesQuery, + BatchSearchStudiesService, + SearchStudiesQuery, + SearchStudiesService, + ServiceMeta, +) query = QueryHandler() @@ -70,13 +77,13 @@ def custom_openapi() -> dict: t_description = "Therapy (object) to search" g_description = "Gene to search" s_description = "Study ID to search." -search_study_response_descr = "A response to a validly-formed query." +start_description = "The index of the first result to return. Use for pagination." +limit_description = "The maximum number of results to return. Use for pagination." 
@app.get( "/api/v2/search/studies", summary=search_studies_summary, - response_description=search_study_response_descr, description=search_studies_descr, ) async def get_studies( @@ -85,6 +92,8 @@ async def get_studies( therapy: Annotated[str | None, Query(description=t_description)] = None, gene: Annotated[str | None, Query(description=g_description)] = None, study_id: Annotated[str | None, Query(description=s_description)] = None, + start: Annotated[int, Query(description=start_description)] = 0, + limit: Annotated[int | None, Query(description=limit_description)] = None, ) -> dict: """Get nested studies from queried concepts that match all conditions provided. For example, if `variation` and `therapy` are provided, will return all studies @@ -95,35 +104,66 @@ async def get_studies( :param therapy: Therapy query :param gene: Gene query :param study_id: Study ID query. + :param start: The index of the first result to return. Use for pagination. + :param limit: The maximum number of results to return. Use for pagination. :return: SearchStudiesService response containing nested studies and service metadata """ - resp = await query.search_studies(variation, disease, therapy, gene, study_id) + try: + resp = await query.search_studies( + variation, disease, therapy, gene, study_id, start, limit + ) + except PaginationParamError: + resp = SearchStudiesService( + query=SearchStudiesQuery( + variation=variation, + disease=disease, + therapy=therapy, + gene=gene, + study_id=study_id, + ), + service_meta_=ServiceMeta(), + warnings=["`start` and `limit` params must both be nonnegative"], + ) return resp.model_dump(exclude_none=True) -_batch_search_studies_descr = { +_batch_descr = { "summary": "Get nested studies for all provided variations.", "description": "Return nested studies associated with any of the provided variations.", "arg_variations": "Variations (subject) to search. Can be free text or VRS variation ID.", + "arg_start": "The index of the first result to return. 
Use for pagination.", + "arg_limit": "The maximum number of results to return. Use for pagination.", } @app.get( "/api/v2/batch_search/studies", - summary=_batch_search_studies_descr["summary"], - description=_batch_search_studies_descr["description"], + summary=_batch_descr["summary"], + description=_batch_descr["description"], ) async def batch_get_studies( variations: Annotated[ list[str] | None, - Query(description=_batch_search_studies_descr["arg_variations"]), + Query(description=_batch_descr["arg_variations"]), ] = None, + start: Annotated[int, Query(description=_batch_descr["arg_start"])] = 0, + limit: Annotated[int | None, Query(description=_batch_descr["arg_limit"])] = None, ) -> dict: """Fetch all studies associated with `any` of the provided variations. :param variations: variations to match against + :param start: The index of the first result to return. Use for pagination. + :param limit: The maximum number of results to return. Use for pagination. :return: batch response object """ - response = await query.batch_search_studies(variations) + try: + response = await query.batch_search_studies(variations, start, limit) + except PaginationParamError: + response = BatchSearchStudiesService( + query=BatchSearchStudiesQuery(variations=[]), + service_meta_=ServiceMeta(), + warnings=["`start` and `limit` params must both be nonnegative"], + ) + return response.model_dump(exclude_none=True) diff --git a/src/metakb/query.py b/src/metakb/query.py index 34e59692..6d1fe3b3 100644 --- a/src/metakb/query.py +++ b/src/metakb/query.py @@ -38,6 +38,10 @@ logger = logging.getLogger(__name__) +class PaginationParamError(ValueError): + """Raise for invalid pagination parameters (subclass of ``ValueError``).""" + + class VariationRelation(str, Enum): """Constrain possible values for the relationship between variations and categorical variations. 
@@ -83,6 +87,7 @@ def __init__( self, driver: Driver | None = None, normalizers: ViccNormalizers | None = None, + default_page_limit: int | None = None, ) -> None: """Initialize neo4j driver and the VICC normalizers. @@ -101,8 +106,27 @@ def __init__( ... ViccNormalizers("http://localhost:8000") ... ) + ``default_page_limit`` sets the default max number of studies to include in + query responses: + + >>> limited_qh = QueryHandler(default_page_limit=10) + >>> response = await limited_qh.batch_search_studies(["BRAF V600E"]) + >>> print(len(response.study_ids)) + 10 + + This value is overruled by an explicit ``limit`` parameter: + + >>> response = await limited_qh.batch_search_studies( + ... ["BRAF V600E"], + ... limit=2 + ... ) + >>> print(len(response.study_ids)) + 2 + :param driver: driver instance for graph connection :param normalizers: normalizer collection instance + :param default_page_limit: default number of results per response page (leave + as ``None`` for no default limit) """ if driver is None: driver = get_driver() @@ -110,6 +134,7 @@ def __init__( normalizers = ViccNormalizers() self.driver = driver self.vicc_normalizers = normalizers + self._default_page_limit = default_page_limit async def search_studies( self, @@ -118,6 +143,8 @@ async def search_studies( therapy: str | None = None, gene: str | None = None, study_id: str | None = None, + start: int = 0, + limit: int | None = None, ) -> SearchStudiesService: """Get nested studies from queried concepts that match all conditions provided. For example, if ``variation`` and ``therapy`` are provided, will return all studies @@ -131,7 +158,6 @@ async def search_studies( >>> result.studies[0].isReportedIn[0].url 'https://www.accessdata.fda.gov/drugsatfda_docs/label/2020/202429s019lbl.pdf' - Variation, disease, therapy, and gene terms are resolved via their respective :ref:`concept normalization services`. @@ -146,8 +172,18 @@ async def search_studies( :param gene: Gene query. Common shorthand name, e.g. 
``"NTRK1"``, or compact URI, e.g. ``"ensembl:ENSG00000198400"``. :param study_id: Study ID query provided by source, e.g. ``"civic.eid:3017"``. + :param start: Index of first result to fetch. Must be nonnegative. + :param limit: Max number of results to fetch. Must be nonnegative. Revert to + default defined at class initialization if not given. :return: Service response object containing nested studies and service metadata. """ + if start < 0: + msg = "Can't start from an index of less than 0." + raise PaginationParamError(msg) + if isinstance(limit, int) and limit < 0: + msg = "Can't limit results to less than a negative number." + raise PaginationParamError(msg) + response: dict = { "query": { "variation": None, @@ -187,6 +223,8 @@ async def search_studies( normalized_therapy=normalized_therapy, normalized_disease=normalized_disease, normalized_gene=normalized_gene, + start=start, + limit=limit, ) response["study_ids"] = [s["id"] for s in study_nodes] @@ -358,6 +396,8 @@ def _get_study_by_id(self, study_id: str) -> Node | None: def _get_studies( self, + start: int, + limit: int | None, normalized_variation: str | None = None, normalized_therapy: str | None = None, normalized_disease: str | None = None, @@ -365,6 +405,10 @@ def _get_studies( ) -> list[Node]: """Get studies that match the intersection of provided concepts. + :param start: Index of first result to fetch. Calling context should've already + checked that it's nonnegative. + :param limit: Max number of results to fetch. Calling context should've already + checked that it's nonnegative. 
:param normalized_variation: VRS Variation ID :param normalized_therapy: normalized therapy concept ID :param normalized_disease: normalized disease concept ID @@ -372,7 +416,7 @@ def _get_studies( :return: List of Study nodes that match the intersection of the given parameters """ query = "MATCH (s:Study)" - params: dict[str, str] = {} + params: dict[str, str | int] = {} if normalized_variation: query += """ @@ -402,7 +446,18 @@ def _get_studies( """ params["t_id"] = normalized_therapy - query += "RETURN DISTINCT s" + query += """ + RETURN DISTINCT s + ORDER BY s.id + """ + + if start: + query += "\nSKIP $start" + params["start"] = start + limit_candidate = limit if limit is not None else self._default_page_limit + if limit_candidate is not None: + query += "\nLIMIT $limit" + params["limit"] = limit_candidate return [s[0] for s in self.driver.execute_query(query, params).records] @@ -748,6 +803,8 @@ def _get_therapeutic_agent(in_ta_params: dict) -> TherapeuticAgent: async def batch_search_studies( self, variations: list[str] | None = None, + start: int = 0, + limit: int | None = None, ) -> BatchSearchStudiesService: """Fetch all studies associated with any of the provided variation description strings. @@ -771,8 +828,19 @@ async def batch_search_studies( :param variations: a list of variation description strings, e.g. ``["BRAF V600E"]`` + :param start: Index of first result to fetch. Must be nonnegative. + :param limit: Max number of results to fetch. Must be nonnegative. Revert to + default defined at class initialization if not given. :return: response object including all matching studies + :raise PaginationParamError: if ``start`` or ``limit`` are negative """ + if start < 0: + msg = "Can't start from an index of less than 0." + raise PaginationParamError(msg) + if isinstance(limit, int) and limit < 0: + msg = "Can't limit results to less than a negative number." 
+ raise PaginationParamError(msg) + response = BatchSearchStudiesService( query=BatchSearchStudiesQuery(variations=[]), service_meta_=ServiceMeta(), @@ -794,16 +862,30 @@ async def batch_search_studies( if not variation_ids: return response - query = """ - MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation) - MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation) - WHERE v.id IN $v_ids - RETURN s - """ + if limit is not None or self._default_page_limit is not None: + query = """ + MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation) + MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation) + WHERE v.id IN $v_ids + RETURN DISTINCT s + ORDER BY s.id + SKIP $skip + LIMIT $limit + """ + limit = limit if limit is not None else self._default_page_limit + else: + query = """ + MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation) + MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation) + WHERE v.id IN $v_ids + RETURN DISTINCT s + ORDER BY s.id + SKIP $skip + """ with self.driver.session() as session: - result = session.run(query, v_ids=variation_ids) + result = session.run(query, v_ids=variation_ids, skip=start, limit=limit) study_nodes = [r[0] for r in result] - response.study_ids = list({n["id"] for n in study_nodes}) + response.study_ids = [n["id"] for n in study_nodes] studies = self._get_nested_studies(study_nodes) response.studies = [VariantTherapeuticResponseStudy(**s) for s in studies] return response diff --git a/tests/conftest.py b/tests/conftest.py index b7fa52c0..977e869b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from metakb.harvesters.base import Harvester from metakb.normalizers import ViccNormalizers +from metakb.query import QueryHandler TEST_DATA_DIR = Path(__file__).resolve().parents[0] / "data" TEST_HARVESTERS_DIR = TEST_DATA_DIR / "harvesters" @@ -2123,3 +2124,11 @@ def check_transformed_cdm(data, studies, transformed_file): def normalizers(): """Provide normalizers to 
querying/transformation 
tests.""" return ViccNormalizers() + + +@pytest.fixture(scope="module") +def query_handler(normalizers): + """Create query handler test fixture""" + qh = QueryHandler(normalizers=normalizers) + yield qh + qh.driver.close() diff --git a/tests/unit/search/__init__.py b/tests/unit/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/search/test_batch_search_studies.py b/tests/unit/search/test_batch_search_studies.py new file mode 100644 index 00000000..7af71357 --- /dev/null +++ b/tests/unit/search/test_batch_search_studies.py @@ -0,0 +1,133 @@ +"""Test batch search function.""" + +import pytest + +from metakb.query import QueryHandler +from metakb.schemas.api import NormalizedQuery + +from .utils import assert_no_match, find_and_check_study + + +@pytest.mark.asyncio(scope="module") +async def test_batch_search( + query_handler: QueryHandler, + assertion_checks, + civic_eid2997_study, + civic_eid816_study, +): + """Test batch search studies method.""" + resp = await query_handler.batch_search_studies([]) + assert resp.studies == resp.study_ids == [] + assert resp.warnings == [] + + assert_no_match(await query_handler.batch_search_studies(["gibberish variant"])) + + braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" + braf_response = await query_handler.batch_search_studies([braf_va_id]) + assert braf_response.query.variations == [ + NormalizedQuery( + term=braf_va_id, + normalized_id=braf_va_id, + ) + ] + find_and_check_study(braf_response, civic_eid816_study, assertion_checks) + + redundant_braf_response = await query_handler.batch_search_studies( + [braf_va_id, "NC_000007.13:g.140453136A>T"] + ) + assert len(redundant_braf_response.query.variations) == 2 + assert ( + NormalizedQuery( + term=braf_va_id, + normalized_id=braf_va_id, + ) + in redundant_braf_response.query.variations + ) + assert ( + NormalizedQuery( + term="NC_000007.13:g.140453136A>T", + normalized_id=braf_va_id, + ) + in 
redundant_braf_response.query.variations + ) + + find_and_check_study(redundant_braf_response, civic_eid816_study, assertion_checks) + assert len(braf_response.study_ids) == len(redundant_braf_response.study_ids) + + braf_egfr_response = await query_handler.batch_search_studies( + [braf_va_id, "EGFR L858R"] + ) + find_and_check_study(braf_egfr_response, civic_eid816_study, assertion_checks) + find_and_check_study(braf_egfr_response, civic_eid2997_study, assertion_checks) + assert len(braf_egfr_response.study_ids) > len(braf_response.study_ids) + + +@pytest.mark.asyncio(scope="module") +async def test_paginate(query_handler: QueryHandler, normalizers): + """Test pagination parameters.""" + braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" + full_response = await query_handler.batch_search_studies([braf_va_id]) + paged_response = await query_handler.batch_search_studies([braf_va_id], start=1) + # should be almost the same, just off by 1 + assert len(paged_response.study_ids) == len(full_response.study_ids) - 1 + assert paged_response.study_ids == full_response.study_ids[1:] + + # check that page limit > response doesn't affect response + huge_page_response = await query_handler.batch_search_studies( + [braf_va_id], limit=1000 + ) + assert len(huge_page_response.study_ids) == len(full_response.study_ids) + assert huge_page_response.study_ids == full_response.study_ids + + # get last item + last_response = await query_handler.batch_search_studies( + [braf_va_id], start=len(full_response.study_ids) - 1 + ) + assert len(last_response.study_ids) == 1 + assert last_response.study_ids[0] == full_response.study_ids[-1] + + # test limit + min_response = await query_handler.batch_search_studies([braf_va_id], limit=1) + assert min_response.study_ids[0] == full_response.study_ids[0] + + # test limit and start + other_min_response = await query_handler.batch_search_studies( + [braf_va_id], start=1, limit=1 + ) + assert other_min_response.study_ids[0] == 
full_response.study_ids[1] + + # test limit of 0 + empty_response = await query_handler.batch_search_studies([braf_va_id], limit=0) + assert len(empty_response.study_ids) == 0 + + # test raises exceptions + with pytest.raises(ValueError, match="Can't start from an index of less than 0."): + await query_handler.batch_search_studies([braf_va_id], start=-1) + with pytest.raises( + ValueError, match="Can't limit results to less than a negative number." + ): + await query_handler.batch_search_studies([braf_va_id], limit=-1) + + # test default limit + limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) + default_limited_response = await limited_query_handler.batch_search_studies( + [braf_va_id] + ) + assert len(default_limited_response.study_ids) == 1 + assert default_limited_response.study_ids[0] == full_response.study_ids[0] + + # test overrideable + less_limited_response = await limited_query_handler.batch_search_studies( + [braf_va_id], limit=2 + ) + assert len(less_limited_response.study_ids) == 2 + assert less_limited_response.study_ids == full_response.study_ids[:2] + + # test default limit and skip + skipped_limited_response = await limited_query_handler.batch_search_studies( + [braf_va_id], start=1 + ) + assert len(skipped_limited_response.study_ids) == 1 + assert skipped_limited_response.study_ids[0] == full_response.study_ids[1] + + limited_query_handler.driver.close() diff --git a/tests/unit/test_search_studies.py b/tests/unit/search/test_search_studies.py similarity index 75% rename from tests/unit/test_search_studies.py rename to tests/unit/search/test_search_studies.py index 1a4a4b71..4751cbef 100644 --- a/tests/unit/test_search_studies.py +++ b/tests/unit/search/test_search_studies.py @@ -4,19 +4,8 @@ from ga4gh.core._internal.models import Extension from metakb.query import QueryHandler -from metakb.schemas.api import ( - BatchSearchStudiesService, - NormalizedQuery, - SearchStudiesService, -) - 
-@pytest.fixture(scope="module") -def query_handler(normalizers): - """Create query handler test fixture""" - qh = QueryHandler(normalizers=normalizers) - yield qh - qh.driver.close() +from .utils import assert_no_match, find_and_check_study def _get_normalizer_id(extensions: list[Extension]) -> str | None: @@ -42,38 +31,6 @@ def assert_general_search_studies(response): assert len_study_id_matches == len_studies -def assert_no_match(response): - """No match assertions for queried concepts in search_studies.""" - assert response.studies == response.study_ids == [] - assert len(response.warnings) > 0 - - -def find_and_check_study( - resp: SearchStudiesService | BatchSearchStudiesService, - expected_study: dict, - assertion_checks: callable, - should_find_match: bool = True, -): - """Check that expected study is or is not in response""" - if should_find_match: - assert expected_study["id"] in resp.study_ids - else: - assert expected_study["id"] not in resp.study_ids - - actual_study = None - for study in resp.studies: - if study.id == expected_study["id"]: - actual_study = study - break - - if should_find_match: - assert actual_study, f"Did not find study ID {expected_study['id']} in studies" - resp_studies = [actual_study.model_dump(exclude_none=True)] - assertion_checks(resp_studies, [expected_study]) - else: - assert actual_study is None - - @pytest.mark.asyncio(scope="module") async def test_civic_eid2997(query_handler, civic_eid2997_study, assertion_checks): """Test that search_studies method works correctly for CIViC EID2997""" @@ -279,54 +236,71 @@ async def test_no_matches(query_handler): @pytest.mark.asyncio(scope="module") -async def test_batch_search( - query_handler: QueryHandler, - assertion_checks, - civic_eid2997_study, - civic_eid816_study, -): - """Test batch search studies method.""" - resp = await query_handler.batch_search_studies([]) - assert resp.studies == resp.study_ids == [] - assert resp.warnings == [] - - assert_no_match(await 
query_handler.batch_search_studies(["gibberish variant"])) - +async def test_paginate(query_handler: QueryHandler, normalizers): + """Test pagination parameters.""" braf_va_id = "ga4gh:VA.Otc5ovrw906Ack087o1fhegB4jDRqCAe" - braf_response = await query_handler.batch_search_studies([braf_va_id]) - assert braf_response.query.variations == [ - NormalizedQuery( - term=braf_va_id, - normalized_id=braf_va_id, - ) - ] - find_and_check_study(braf_response, civic_eid816_study, assertion_checks) + full_response = await query_handler.search_studies(variation=braf_va_id) + paged_response = await query_handler.search_studies(variation=braf_va_id, start=1) + # should be almost the same, just off by 1 + assert len(paged_response.study_ids) == len(full_response.study_ids) - 1 + assert paged_response.study_ids == full_response.study_ids[1:] + + # check that page limit > response doesn't affect response + huge_page_response = await query_handler.search_studies( + variation=braf_va_id, limit=1000 + ) + assert len(huge_page_response.study_ids) == len(full_response.study_ids) + assert huge_page_response.study_ids == full_response.study_ids - redundant_braf_response = await query_handler.batch_search_studies( - [braf_va_id, "NC_000007.13:g.140453136A>T"] + # get last item + last_response = await query_handler.search_studies( + variation=braf_va_id, start=len(full_response.study_ids) - 1 ) - assert len(redundant_braf_response.query.variations) == 2 - assert ( - NormalizedQuery( - term=braf_va_id, - normalized_id=braf_va_id, - ) - in redundant_braf_response.query.variations + assert len(last_response.study_ids) == 1 + assert last_response.study_ids[0] == full_response.study_ids[-1] + + # test limit + min_response = await query_handler.search_studies(variation=braf_va_id, limit=1) + assert min_response.study_ids[0] == full_response.study_ids[0] + + # test limit and start + other_min_response = await query_handler.search_studies( + variation=braf_va_id, start=1, limit=1 ) - assert ( - 
NormalizedQuery( - term="NC_000007.13:g.140453136A>T", - normalized_id=braf_va_id, - ) - in redundant_braf_response.query.variations + assert other_min_response.study_ids[0] == full_response.study_ids[1] + + # test limit of 0 + empty_response = await query_handler.search_studies(variation=braf_va_id, limit=0) + assert len(empty_response.study_ids) == 0 + + # test raises exceptions + with pytest.raises(ValueError, match="Can't start from an index of less than 0."): + await query_handler.search_studies(variation=braf_va_id, start=-1) + with pytest.raises( + ValueError, match="Can't limit results to less than a negative number." + ): + await query_handler.search_studies(variation=braf_va_id, limit=-1) + + # test default limit + limited_query_handler = QueryHandler(normalizers=normalizers, default_page_limit=1) + default_limited_response = await limited_query_handler.search_studies( + variation=braf_va_id ) + assert len(default_limited_response.study_ids) == 1 + assert default_limited_response.study_ids[0] == full_response.study_ids[0] - find_and_check_study(redundant_braf_response, civic_eid816_study, assertion_checks) - assert len(braf_response.study_ids) == len(redundant_braf_response.study_ids) + # test overrideable + less_limited_response = await limited_query_handler.search_studies( + variation=braf_va_id, limit=2 + ) + assert len(less_limited_response.study_ids) == 2 + assert less_limited_response.study_ids == full_response.study_ids[:2] - braf_egfr_response = await query_handler.batch_search_studies( - [braf_va_id, "EGFR L858R"] + # test default limit and skip + skipped_limited_response = await limited_query_handler.search_studies( + variation=braf_va_id, start=1 ) - find_and_check_study(braf_egfr_response, civic_eid816_study, assertion_checks) - find_and_check_study(braf_egfr_response, civic_eid2997_study, assertion_checks) - assert len(braf_egfr_response.study_ids) > len(braf_response.study_ids) + assert len(skipped_limited_response.study_ids) == 1 + assert 
skipped_limited_response.study_ids[0] == full_response.study_ids[1] + + limited_query_handler.driver.close() diff --git a/tests/unit/search/utils.py b/tests/unit/search/utils.py new file mode 100644 index 00000000..8eaf78c2 --- /dev/null +++ b/tests/unit/search/utils.py @@ -0,0 +1,33 @@ +from metakb.schemas.api import BatchSearchStudiesService, SearchStudiesService + + +def assert_no_match(response): + """No match assertions for queried concepts in search_studies.""" + assert response.studies == response.study_ids == [] + assert len(response.warnings) > 0 + + +def find_and_check_study( + resp: SearchStudiesService | BatchSearchStudiesService, + expected_study: dict, + assertion_checks: callable, + should_find_match: bool = True, +): + """Check that expected study is or is not in response""" + if should_find_match: + assert expected_study["id"] in resp.study_ids + else: + assert expected_study["id"] not in resp.study_ids + + actual_study = None + for study in resp.studies: + if study.id == expected_study["id"]: + actual_study = study + break + + if should_find_match: + assert actual_study, f"Did not find study ID {expected_study['id']} in studies" + resp_studies = [actual_study.model_dump(exclude_none=True)] + assertion_checks(resp_studies, [expected_study]) + else: + assert actual_study is None