Skip to content

Commit

Permalink
feat: enable pagination (#384)
Browse files Browse the repository at this point in the history
* use `start` and `limit` params for pagination
* set default `limit` in `QueryHandler` init
* responses now have order (tbh they might've had it previously in
practical terms but this makes it explicit)
  • Loading branch information
jsstevenson authored Jul 16, 2024
1 parent 5e30c8d commit 5f5bada
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 108 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,12 @@ ignore = [
# ANN2 - missing-return-type
# ANN102 - missing-type-cls
# S101 - assert
# B011 - assert-false
# D104 - undocumented-public-package
# D100 - undocumented-public-module
# INP001 - implicit-namespace-package
# ARG001 - unused-function-argument
# SLF001 - private-member-acces
# N815 - mixed-case-variable-in-class-scope
"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "INP001", "SLF001", "ARG001"]
"tests/*" = ["ANN001", "ANN2", "ANN102", "S101", "B011", "D100", "D104", "INP001", "SLF001", "ARG001"]
"src/metakb/schemas/*" = ["ANN102", "N815"]
58 changes: 49 additions & 9 deletions src/metakb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@

from metakb import __version__
from metakb.log_handle import configure_logs
from metakb.query import QueryHandler
from metakb.query import PaginationParamError, QueryHandler
from metakb.schemas.api import (
BatchSearchStudiesQuery,
BatchSearchStudiesService,
SearchStudiesQuery,
SearchStudiesService,
ServiceMeta,
)

query = QueryHandler()

Expand Down Expand Up @@ -70,13 +77,13 @@ def custom_openapi() -> dict:
t_description = "Therapy (object) to search"
g_description = "Gene to search"
s_description = "Study ID to search."
search_study_response_descr = "A response to a validly-formed query."
start_description = "The index of the first result to return. Use for pagination."
limit_description = "The maximum number of results to return. Use for pagination."


@app.get(
"/api/v2/search/studies",
summary=search_studies_summary,
response_description=search_study_response_descr,
description=search_studies_descr,
)
async def get_studies(
Expand All @@ -85,6 +92,8 @@ async def get_studies(
therapy: Annotated[str | None, Query(description=t_description)] = None,
gene: Annotated[str | None, Query(description=g_description)] = None,
study_id: Annotated[str | None, Query(description=s_description)] = None,
start: Annotated[int, Query(description=start_description)] = 0,
limit: Annotated[int | None, Query(description=limit_description)] = None,
) -> dict:
"""Get nested studies from queried concepts that match all conditions provided.
For example, if `variation` and `therapy` are provided, will return all studies
Expand All @@ -95,35 +104,66 @@ async def get_studies(
:param therapy: Therapy query
:param gene: Gene query
:param study_id: Study ID query.
:param start: The index of the first result to return. Use for pagination.
:param limit: The maximum number of results to return. Use for pagination.
:return: SearchStudiesService response containing nested studies and service
metadata
"""
resp = await query.search_studies(variation, disease, therapy, gene, study_id)
try:
resp = await query.search_studies(
variation, disease, therapy, gene, study_id, start, limit
)
except PaginationParamError:
resp = SearchStudiesService(
query=SearchStudiesQuery(
variation=variation,
disease=disease,
therapy=therapy,
gene=gene,
study_id=study_id,
),
service_meta_=ServiceMeta(),
warnings=["`start` and `limit` params must both be nonnegative"],
)
return resp.model_dump(exclude_none=True)


_batch_search_studies_descr = {
_batch_descr = {
"summary": "Get nested studies for all provided variations.",
"description": "Return nested studies associated with any of the provided variations.",
"arg_variations": "Variations (subject) to search. Can be free text or VRS variation ID.",
"arg_start": "The index of the first result to return. Use for pagination.",
"arg_limit": "The maximum number of results to return. Use for pagination.",
}


@app.get(
"/api/v2/batch_search/studies",
summary=_batch_search_studies_descr["summary"],
description=_batch_search_studies_descr["description"],
summary=_batch_descr["summary"],
description=_batch_descr["description"],
)
async def batch_get_studies(
variations: Annotated[
list[str] | None,
Query(description=_batch_search_studies_descr["arg_variations"]),
Query(description=_batch_descr["arg_variations"]),
] = None,
start: Annotated[int, Query(description=_batch_descr["arg_start"])] = 0,
limit: Annotated[int | None, Query(description=_batch_descr["arg_limit"])] = None,
) -> dict:
"""Fetch all studies associated with `any` of the provided variations.
:param variations: variations to match against
:param start: The index of the first result to return. Use for pagination.
:param limit: The maximum number of results to return. Use for pagination.
:return: batch response object
"""
response = await query.batch_search_studies(variations)
try:
response = await query.batch_search_studies(variations, start, limit)
except PaginationParamError:
response = BatchSearchStudiesService(
query=BatchSearchStudiesQuery(variations=[]),
service_meta_=ServiceMeta(),
warnings=["`start` and `limit` params must both be nonnegative"],
)

return response.model_dump(exclude_none=True)
104 changes: 93 additions & 11 deletions src/metakb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
logger = logging.getLogger(__name__)


class PaginationParamError(Exception):
"""Raise for invalid pagination parameters."""


class VariationRelation(str, Enum):
"""Constrain possible values for the relationship between variations and
categorical variations.
Expand Down Expand Up @@ -83,6 +87,7 @@ def __init__(
self,
driver: Driver | None = None,
normalizers: ViccNormalizers | None = None,
default_page_limit: int | None = None,
) -> None:
"""Initialize neo4j driver and the VICC normalizers.
Expand All @@ -101,15 +106,35 @@ def __init__(
... ViccNormalizers("http://localhost:8000")
... )
``default_page_limit`` sets the default max number of studies to include in
query responses:
>>> limited_qh = QueryHandler(default_page_limit=10)
>>> response = await limited_qh.batch_search_studies(["BRAF V600E"])
>>> print(len(response.study_ids))
10
This value is overruled by an explicit ``limit`` parameter:
>>> response = await limited_qh.batch_search_studies(
... ["BRAF V600E"],
... limit=2
... )
>>> print(len(response.study_ids))
2
:param driver: driver instance for graph connection
:param normalizers: normalizer collection instance
:param default_page_limit: default number of results per response page (leave
as ``None`` for no default limit)
"""
if driver is None:
driver = get_driver()
if normalizers is None:
normalizers = ViccNormalizers()
self.driver = driver
self.vicc_normalizers = normalizers
self._default_page_limit = default_page_limit

async def search_studies(
self,
Expand All @@ -118,6 +143,8 @@ async def search_studies(
therapy: str | None = None,
gene: str | None = None,
study_id: str | None = None,
start: int = 0,
limit: int | None = None,
) -> SearchStudiesService:
"""Get nested studies from queried concepts that match all conditions provided.
For example, if ``variation`` and ``therapy`` are provided, will return all studies
Expand All @@ -131,7 +158,6 @@ async def search_studies(
>>> result.studies[0].isReportedIn[0].url
'https://www.accessdata.fda.gov/drugsatfda_docs/label/2020/202429s019lbl.pdf'
Variation, disease, therapy, and gene terms are resolved via their respective
:ref:`concept normalization services<normalization>`.
Expand All @@ -146,8 +172,18 @@ async def search_studies(
:param gene: Gene query. Common shorthand name, e.g. ``"NTRK1"``, or compact URI,
e.g. ``"ensembl:ENSG00000198400"``.
:param study_id: Study ID query provided by source, e.g. ``"civic.eid:3017"``.
:param start: Index of first result to fetch. Must be nonnegative.
:param limit: Max number of results to fetch. Must be nonnegative. Revert to
default defined at class initialization if not given.
:return: Service response object containing nested studies and service metadata.
"""
if start < 0:
msg = "Can't start from an index of less than 0."
raise ValueError(msg)
if isinstance(limit, int) and limit < 0:
msg = "Can't limit results to less than a negative number."
raise ValueError(msg)

response: dict = {
"query": {
"variation": None,
Expand Down Expand Up @@ -187,6 +223,8 @@ async def search_studies(
normalized_therapy=normalized_therapy,
normalized_disease=normalized_disease,
normalized_gene=normalized_gene,
start=start,
limit=limit,
)
response["study_ids"] = [s["id"] for s in study_nodes]

Expand Down Expand Up @@ -358,21 +396,27 @@ def _get_study_by_id(self, study_id: str) -> Node | None:

def _get_studies(
self,
start: int,
limit: int | None,
normalized_variation: str | None = None,
normalized_therapy: str | None = None,
normalized_disease: str | None = None,
normalized_gene: str | None = None,
) -> list[Node]:
"""Get studies that match the intersection of provided concepts.
:param start: Index of first result to fetch. Calling context should've already
checked that it's nonnegative.
:param limit: Max number of results to fetch. Calling context should've already
checked that it's nonnegative.
:param normalized_variation: VRS Variation ID
:param normalized_therapy: normalized therapy concept ID
:param normalized_disease: normalized disease concept ID
:param normalized_gene: normalized gene concept ID
:return: List of Study nodes that match the intersection of the given parameters
"""
query = "MATCH (s:Study)"
params: dict[str, str] = {}
params: dict[str, str | int] = {}

if normalized_variation:
query += """
Expand Down Expand Up @@ -402,7 +446,18 @@ def _get_studies(
"""
params["t_id"] = normalized_therapy

query += "RETURN DISTINCT s"
query += """
RETURN DISTINCT s
ORDER BY s.id
"""

if start:
query += "\nSKIP $start"
params["start"] = start
limit_candidate = limit if limit is not None else self._default_page_limit
if limit_candidate is not None:
query += "\nLIMIT $limit"
params["limit"] = limit_candidate

return [s[0] for s in self.driver.execute_query(query, params).records]

Expand Down Expand Up @@ -748,6 +803,8 @@ def _get_therapeutic_agent(in_ta_params: dict) -> TherapeuticAgent:
async def batch_search_studies(
self,
variations: list[str] | None = None,
start: int = 0,
limit: int | None = None,
) -> BatchSearchStudiesService:
"""Fetch all studies associated with any of the provided variation description
strings.
Expand All @@ -771,8 +828,19 @@ async def batch_search_studies(
:param variations: a list of variation description strings, e.g.
``["BRAF V600E"]``
:param start: Index of first result to fetch. Must be nonnegative.
:param limit: Max number of results to fetch. Must be nonnegative. Revert to
default defined at class initialization if not given.
:return: response object including all matching studies
:raise ValueError: if ``start`` or ``limit`` are nonnegative
"""
if start < 0:
msg = "Can't start from an index of less than 0."
raise ValueError(msg)
if isinstance(limit, int) and limit < 0:
msg = "Can't limit results to less than a negative number."
raise ValueError(msg)

response = BatchSearchStudiesService(
query=BatchSearchStudiesQuery(variations=[]),
service_meta_=ServiceMeta(),
Expand All @@ -794,16 +862,30 @@ async def batch_search_studies(
if not variation_ids:
return response

query = """
MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation)
MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation)
WHERE v.id IN $v_ids
RETURN s
"""
if limit is not None or self._default_page_limit is not None:
query = """
MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation)
MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation)
WHERE v.id IN $v_ids
RETURN DISTINCT s
ORDER BY s.id
SKIP $skip
LIMIT $limit
"""
limit = limit if limit is not None else self._default_page_limit
else:
query = """
MATCH (s) -[:HAS_VARIANT] -> (cv:CategoricalVariation)
MATCH (cv) -[:HAS_DEFINING_CONTEXT|HAS_MEMBERS] -> (v:Variation)
WHERE v.id IN $v_ids
RETURN DISTINCT s
ORDER BY s.id
SKIP $skip
"""
with self.driver.session() as session:
result = session.run(query, v_ids=variation_ids)
result = session.run(query, v_ids=variation_ids, skip=start, limit=limit)
study_nodes = [r[0] for r in result]
response.study_ids = list({n["id"] for n in study_nodes})
response.study_ids = [n["id"] for n in study_nodes]
studies = self._get_nested_studies(study_nodes)
response.studies = [VariantTherapeuticResponseStudy(**s) for s in studies]
return response
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from metakb.harvesters.base import Harvester
from metakb.normalizers import ViccNormalizers
from metakb.query import QueryHandler

TEST_DATA_DIR = Path(__file__).resolve().parents[0] / "data"
TEST_HARVESTERS_DIR = TEST_DATA_DIR / "harvesters"
Expand Down Expand Up @@ -2123,3 +2124,11 @@ def check_transformed_cdm(data, studies, transformed_file):
def normalizers():
"""Provide normalizers to querying/transformation tests."""
return ViccNormalizers()


@pytest.fixture(scope="module")
def query_handler(normalizers):
"""Create query handler test fixture"""
qh = QueryHandler(normalizers=normalizers)
yield qh
qh.driver.close()
Empty file added tests/unit/search/__init__.py
Empty file.
Loading

0 comments on commit 5f5bada

Please sign in to comment.