Skip to content

Commit

Permalink
feat: add option to cli to create neo4j db constraints (#412)
Browse files (browse the repository at this point in the history)
close #410

* Added the `--add_constraints` option to the `load-cdm` and `update` commands
  • Loading branch information
korikuzma authored Dec 18, 2024
1 parent ff00567 commit 9a655c3
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions src/metakb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,15 @@ async def transform_file(
)


def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, None]:
def _get_driver(
db_url: str, db_creds: str | None, add_constraints: bool
) -> Generator[Driver, None, None]:
"""Acquire Neo4j graph driver.
:param db_url: URL endpoint for the application Neo4j database.
:param db_creds: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:return: Graph driver instance
"""
if not db_creds:
Expand All @@ -348,7 +351,9 @@ def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, No
_help_msg(
f"Argument to --db_credentials appears invalid. Got '{db_creds}'. Should follow pattern 'username:password'."
)
driver = get_driver(uri=db_url, credentials=credentials)
driver = get_driver(
uri=db_url, credentials=credentials, add_constraints=add_constraints
)
yield driver
driver.close()

Expand Down Expand Up @@ -382,13 +387,19 @@ def clear_graph(
``"username:password"``.
:param keep_constraints: if True, don't clear graph constraints
""" # noqa: D301
driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints=False))
clear_metakb_graph(driver, keep_constraints)


@cli.command()
@click.option("--db_url", "-u", default="", help=_neo4j_db_url_description)
@click.option("--db_credentials", "-c", help=_neo4j_creds_description)
@click.option(
"--add_constraints",
is_flag=True,
default=False,
help="if true, create neo4j database constraints",
)
@click.option(
"--from_s3",
"-s",
Expand All @@ -402,7 +413,11 @@ def clear_graph(
nargs=-1,
)
def load_cdm(
db_url: str, db_credentials: str | None, from_s3: bool, cdm_files: tuple[Path, ...]
db_url: str,
db_credentials: str | None,
add_constraints: bool,
from_s3: bool,
cdm_files: tuple[Path, ...],
) -> None:
"""Load one or more CDM_FILEs into Neo4j graph.
Expand Down Expand Up @@ -430,6 +445,7 @@ def load_cdm(
:param db_url: URL endpoint for the application Neo4j database.
:param db_credentials: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:param from_s3: Skip data harvest/transform and load latest existing CDM files from
VICC S3 bucket. Exclusive with ``cdm_file`` arguments.
:param cdm_files: tuple of specific file(s) to load from. If empty, just get latest
Expand All @@ -441,7 +457,7 @@ def load_cdm(
start = timer()
_echo_info("Loading Neo4j database...")

driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints))

if cdm_files:
for file in cdm_files:
Expand All @@ -468,6 +484,12 @@ def load_cdm(
@cli.command()
@click.option("--db_url", "-u", default="", help=_neo4j_db_url_description)
@click.option("--db_credentials", "-c", help=_neo4j_creds_description)
@click.option(
"--add_constraints",
is_flag=True,
default=False,
help="if true, create neo4j database constraints",
)
@click.option("--normalizer_db_url", "-n", help=_normalizer_db_url_description)
@click.option(
"--refresh_source_caches",
Expand All @@ -487,6 +509,7 @@ def load_cdm(
async def update(
db_url: str,
db_credentials: str | None,
add_constraints: bool,
normalizer_db_url: str | None,
refresh_source_caches: bool,
sources: tuple[SourceName, ...],
Expand Down Expand Up @@ -515,6 +538,7 @@ async def update(
:param db_url: URL endpoint for the application Neo4j database.
:param db_credentials: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:param normalizer_db_url: URL endpoint of normalizers DynamoDB database. If not
given, defaults to the configuration rules of the individual normalizers.
:param refresh_source_caches: ``True`` if source caches, i.e. CIViCPy, should be
Expand All @@ -528,7 +552,7 @@ async def update(
start = timer()
_echo_info("Loading Neo4j database...")

driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints))

if not sources:
sources = tuple(SourceName)
Expand Down

0 comments on commit 9a655c3

Please sign in to comment.