diff --git a/src/metakb/cli.py b/src/metakb/cli.py index f31fa6ba..da12eb3a 100644 --- a/src/metakb/cli.py +++ b/src/metakb/cli.py @@ -330,12 +330,15 @@ async def transform_file( ) -def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, None]: +def _get_driver( + db_url: str, db_creds: str | None, add_constraints: bool +) -> Generator[Driver, None, None]: """Acquire Neo4j graph driver. :param db_url: URL endpoint for the application Neo4j database. :param db_creds: DB username and password, separated by a colon, e.g. ``"username:password"``. + :param add_constraints: Whether or not to create Neo4j database constraints. :return: Graph driver instance """ if not db_creds: @@ -348,7 +351,9 @@ def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, No _help_msg( f"Argument to --db_credentials appears invalid. Got '{db_creds}'. Should follow pattern 'username:password'." ) - driver = get_driver(uri=db_url, credentials=credentials) + driver = get_driver( + uri=db_url, credentials=credentials, add_constraints=add_constraints + ) yield driver driver.close() @@ -382,13 +387,19 @@ def clear_graph( ``"username:password"``. :param keep_constraints: if True, don't clear graph constraints """ # noqa: D301 - driver = next(_get_driver(db_url, db_credentials)) + driver = next(_get_driver(db_url, db_credentials, add_constraints=False)) clear_metakb_graph(driver, keep_constraints) @cli.command() @click.option("--db_url", "-u", default="", help=_neo4j_db_url_description) @click.option("--db_credentials", "-c", help=_neo4j_creds_description) +@click.option( + "--add_constraints", + is_flag=True, + default=False, + help="if true, create neo4j database constraints", +) @click.option( "--from_s3", "-s", @@ -402,7 +413,11 @@ def clear_graph( nargs=-1, ) def load_cdm( - db_url: str, db_credentials: str | None, from_s3: bool, cdm_files: tuple[Path, ...] + db_url: str, + db_credentials: str | None, + add_constraints: bool, + from_s3: bool, + cdm_files: tuple[Path, ...], ) -> None: """Load one or more CDM_FILEs into Neo4j graph. @@ -430,6 +445,7 @@ def load_cdm( :param db_url: URL endpoint for the application Neo4j database. :param db_credentials: DB username and password, separated by a colon, e.g. ``"username:password"``. + :param add_constraints: Whether or not to create Neo4j database constraints. :param from_s3: Skip data harvest/transform and load latest existing CDM files from VICC S3 bucket. Exclusive with ``cdm_file`` arguments. :param cdm_files: tuple of specific file(s) to load from. If empty, just get latest @@ -441,7 +457,7 @@ def load_cdm( start = timer() _echo_info("Loading Neo4j database...") - driver = next(_get_driver(db_url, db_credentials)) + driver = next(_get_driver(db_url, db_credentials, add_constraints)) if cdm_files: for file in cdm_files: @@ -468,6 +484,12 @@ def load_cdm( @cli.command() @click.option("--db_url", "-u", default="", help=_neo4j_db_url_description) @click.option("--db_credentials", "-c", help=_neo4j_creds_description) +@click.option( + "--add_constraints", + is_flag=True, + default=False, + help="if true, create neo4j database constraints", +) @click.option("--normalizer_db_url", "-n", help=_normalizer_db_url_description) @click.option( "--refresh_source_caches", @@ -487,6 +509,7 @@ def load_cdm( async def update( db_url: str, db_credentials: str | None, + add_constraints: bool, normalizer_db_url: str | None, refresh_source_caches: bool, sources: tuple[SourceName, ...], @@ -515,6 +538,7 @@ async def update( :param db_url: URL endpoint for the application Neo4j database. :param db_credentials: DB username and password, separated by a colon, e.g. ``"username:password"``. + :param add_constraints: Whether or not to create Neo4j database constraints. :param normalizer_db_url: URL endpoint of normalizers DynamoDB database. If not given, defaults to the configuration rules of the individual normalizers. :param refresh_source_caches: ``True`` if source caches, i.e. CIViCPy, should be @@ -528,7 +552,7 @@ async def update( start = timer() _echo_info("Loading Neo4j database...") - driver = next(_get_driver(db_url, db_credentials)) + driver = next(_get_driver(db_url, db_credentials, add_constraints)) if not sources: sources = tuple(SourceName)