diff --git a/smoke-release-testing/text2cypher/.gitignore b/smoke-release-testing/text2cypher/.gitignore
new file mode 100644
index 000000000..522729095
--- /dev/null
+++ b/smoke-release-testing/text2cypher/.gitignore
@@ -0,0 +1,2 @@
+*.log
+*.cypher
diff --git a/smoke-release-testing/text2cypher/generate_random_graph.py b/smoke-release-testing/text2cypher/generate_random_graph.py
new file mode 100644
index 000000000..46b1eef45
--- /dev/null
+++ b/smoke-release-testing/text2cypher/generate_random_graph.py
@@ -0,0 +1,191 @@
+"""
+Random graph generator for the Country / Filing / Entity schema.
+
+Requires:
+    pip install faker python-dateutil
+Outputs:
+    Cypher statements on stdout – redirect to a file if you like:
+        python generate_random_graph.py > load.cypher
+"""
+
+import random
+from datetime import datetime, timedelta
+from dateutil.tz import tzutc
+from faker import Faker
+
+# ---------------------------------------------------------------------------
+# tunables – change these to make a bigger or smaller graph
+NUM_COUNTRIES = 20
+NUM_ENTITIES = 150
+NUM_FILINGS = 400
+SEED = 42  # remove or change for different results
+# ---------------------------------------------------------------------------
+
+fake = Faker()
+random.seed(SEED)
+
+ISO_CODES = set()  # keep ISO-3 codes unique
+
+
+def iso3() -> str:
+    """Return a unique random three-letter ISO code."""
+    while True:
+        code = "".join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=3))
+        if code not in ISO_CODES:
+            ISO_CODES.add(code)
+            return code
+
+
+def random_point() -> str:
+    """Return a Cypher point literal string."""
+    lat = random.uniform(-90, 90)
+    lng = random.uniform(-180, 180)
+    return f"point({{latitude:{lat:.6f}, longitude:{lng:.6f}, srid:4326}})"
+
+
+def dt_between(start: str, end: str) -> datetime:
+    """Random UTC datetime between two ISO-8601 strings."""
+    start_dt = datetime.fromisoformat(start.replace("Z", "+00:00"))
+    end_dt = datetime.fromisoformat(end.replace("Z", "+00:00"))
+    delta = end_dt - start_dt
+    return start_dt + timedelta(seconds=random.randint(0, int(delta.total_seconds())))
+
+
+def iso_z(dt: datetime) -> str:
+    """Return datetime in ISO-8601 + 'Z'."""
+    return dt.astimezone(tzutc()).isoformat(timespec="seconds").replace("+00:00", "Z")
+
+
+# ---------------------------------------------------------------------------
+# generate Countries
+countries = []
+for i in range(NUM_COUNTRIES):
+    code = iso3()
+    name = fake.country()
+    tld = code[:2].upper()
+    countries.append(
+        {
+            "id": f"c{i}",
+            "code": code,
+            "name": name,
+            "tld": tld,
+            "location": random_point(),
+        }
+    )
+
+# generate Entities
+entities = []
+for i in range(NUM_ENTITIES):
+    c = random.choice(countries)
+    eid = fake.slug()
+    entities.append(
+        {
+            "id": eid,
+            "name": fake.company().rstrip(","),
+            "country": c["code"],
+            "location": random_point(),
+        }
+    )
+
+# generate Filings
+filings = []
+for i in range(NUM_FILINGS):
+    begin_dt = dt_between("2000-02-08T00:00:00Z", "2017-09-05T00:00:00Z")
+    end_dt = dt_between(iso_z(begin_dt), "2017-11-03T00:00:00Z")
+
+    origin_bank = fake.company().rstrip(",")
+    bene_bank = fake.company().rstrip(",")
+    origin = random.choice(countries)
+    bene = random.choice(countries)
+
+    filings.append(
+        {
+            "id": str(fake.unique.random_int(min=1, max=9_999_999)),
+            "begin": iso_z(begin_dt),
+            "end": iso_z(end_dt),
+            "begin_date_format": begin_dt.strftime("%Y-%m-%dT%H:%M:%SZ"),
+            "end_date_format": end_dt.strftime("%Y-%m-%dT%H:%M:%SZ"),
+            "begin_date": begin_dt.strftime("%b %d, %Y"),
+            "end_date": end_dt.strftime("%b %d, %Y"),
+            "originator_bank": origin_bank,
+            "originator_bank_id": fake.slug(),
+            "originator_bank_country": origin["name"],
+            "originator_iso": origin["code"],
+            "origin_lat": f"{random.uniform(-90, 90):.4f}",
+            "origin_lng": f"{random.uniform(-180, 180):.4f}",
+            "beneficiary_bank": bene_bank,
+            "beneficiary_bank_id": fake.slug(),
+            "beneficiary_bank_country": bene["name"],
+            "beneficiary_iso": bene["code"],
+            "beneficiary_lat": f"{random.uniform(-90, 90):.4f}",
+            "beneficiary_lng": f"{random.uniform(-180, 180):.4f}",
+            "amount": round(random.uniform(1.18, 2_721_000_000), 2),
+            "number": random.randint(1, 174),
+            # choose the entity relationships later
+        }
+    )
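+
+# For illustration only -- a single generated filing dict looks roughly like
+# this (the values below are made up; the real ones depend on SEED):
+#   {"id": "4821337", "begin": "2003-07-14T09:21:55Z",
+#    "end": "2010-01-02T18:03:11Z", "begin_date": "Jul 14, 2003",
+#    "originator_bank": "Doe Ltd", "originator_iso": "QXZ", ...,
+#    "amount": 18250421.17, "number": 42}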
+
+# ---------------------------------------------------------------------------
+# utility – Cypher value rendering
+def cypher_str(s: str) -> str:
+    return '"' + s.replace('"', '\\"') + '"'
+
+
+def props(data: dict, exclude=()) -> str:
+    fields = []
+    for k, v in data.items():
+        if k in exclude:
+            continue
+        if isinstance(v, str) and v.startswith("point("):
+            fields.append(f"{k}:{v}")
+        elif isinstance(v, (int, float)):
+            fields.append(f"{k}:{v}")
+        else:
+            fields.append(f"{k}:{cypher_str(v)}")
+    return "{" + ", ".join(fields) + "}"
+
+
+# ---------------------------------------------------------------------------
+# emit Cypher
+print("// ------------------ Countries ------------------")
+for c in countries:
+    print(f"CREATE (:Country {props(c, exclude=['id'])});")
+
+print("\n// ------------------ Entities -------------------")
+for e in entities:
+    print(f"CREATE (:Entity {props(e)});")
+
+print("\n// ------------------ Filings --------------------")
+for f in filings:
+    print(f"CREATE (:Filing {props(f)});")
+
+print("\n// ------------------ Relationships --------------")
+for f in filings:
+    # pick entities for the three roles
+    filing_id = f["id"]
+    originator = random.choice(entities)["id"]
+    beneficiary = random.choice(entities)["id"]
+    concern = random.choice(entities)["id"]
+
+    # FILED
+    print(f"MATCH (e:Entity {{id:{cypher_str(originator)}}}), (f:Filing {{id:{cypher_str(filing_id)}}}) "
+          f"CREATE (e)-[:FILED]->(f);")
+    # ORIGINATOR
+    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(originator)}}}) "
+          f"CREATE (f)-[:ORIGINATOR]->(e);")
+    # BENEFITS
+    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(beneficiary)}}}) "
+          f"CREATE (f)-[:BENEFITS]->(e);")
+    # CONCERNS
+    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(concern)}}}) "
+          f"CREATE (f)-[:CONCERNS]->(e);")
+
+# Entity-COUNTRY relations
+print("\n// --------------- Entity → Country --------------")
+for e in entities:
+    print(
+        f"MATCH (e:Entity {{id:{cypher_str(e['id'])}}}), (c:Country {{code:{cypher_str(e['country'])}}}) "
+        f"CREATE (e)-[:COUNTRY]->(c);"
+    )
+
+print("// ------------------ Done -----------------------")
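+
+# For reference, the emitted statements look roughly like the following
+# (hypothetical values -- actual output depends on SEED):
+#   CREATE (:Country {code:"QXZ", name:"Iceland", tld:"QX",
+#     location:point({latitude:12.345678, longitude:-98.765432, srid:4326})});
+#   MATCH (e:Entity {id:"big-blue-bank"}), (f:Filing {id:"4821337"}) CREATE (e)-[:FILED]->(f);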
diff --git a/smoke-release-testing/text2cypher/main.py b/smoke-release-testing/text2cypher/main.py
new file mode 100644
index 000000000..3a2f3befd
--- /dev/null
+++ b/smoke-release-testing/text2cypher/main.py
@@ -0,0 +1,170 @@
+import itertools
+import time
+import logging
+from datetime import datetime
+
+from datasets import load_dataset
+import docker
+import mgclient
+
+# Configure logging
+timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+log_file = f"text2cypher_test_{timestamp}.log"
+
+# Create a formatter
+formatter = logging.Formatter(
+    "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+
+# Configure file handler
+file_handler = logging.FileHandler(log_file)
+file_handler.setFormatter(formatter)
+file_handler.setLevel(logging.DEBUG)
+
+# Configure console handler
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(formatter)
+console_handler.setLevel(logging.INFO)
+
+# Configure root logger
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+logger.addHandler(file_handler)
+logger.addHandler(console_handler)
+
+# On macOS: Docker Desktop -> Settings -> Advanced -> Allow the default Docker
+# socket to be used (requires password).
+DOCKER_CLIENT = docker.from_env()
+
+
+def container_exists(name):
+    try:
+        DOCKER_CLIENT.containers.get(name)
+        logger.debug(f"Container {name} already exists")
+        return True
+    except docker.errors.NotFound:
+        logger.debug(f"Container {name} does not exist")
+        return False
+    except Exception as e:
+        logger.debug(f"An error occurred while checking container {name}: {e}")
+        return False
+
+
+def start_memgraph():
+    if container_exists("test_text2cypher_queries"):
+        return
+    logger.info("Starting Memgraph container...")
+    try:
+        DOCKER_CLIENT.containers.run(
+            "memgraph/memgraph:3.3.0",
+            detach=True,
+            auto_remove=True,
+            name="test_text2cypher_queries",
+            ports={"7687": 7687},
+            command=["--telemetry-enabled=False"],
+        )
+        logger.info("Memgraph container started successfully")
+    except Exception as e:
+        logger.error(f"Failed to start Memgraph container: {e}")
+        raise
+
+
+# TODO(gitbuda): Add a timeout to the query execution.
+def execute_query(query):
+    try:
+        conn = mgclient.connect(host="127.0.0.1", port=7687)
+        conn.autocommit = True
+        cursor = conn.cursor()
+        cursor.execute(query)
+        cursor.fetchall()
+        cursor.close()
+        conn.close()
+        return True
+    except Exception as e:
+        logger.error(f"Failed to execute query: {e}")
+        return False
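+
+# NOTE: A possible way to address the timeout TODO above (just a sketch, not
+# wired in): run the query in a worker process and kill it when over budget.
+# Propagating the bool result back would need a multiprocessing.Queue:
+#   p = multiprocessing.Process(target=execute_query, args=(query,))
+#   p.start()
+#   p.join(timeout=30)
+#   if p.is_alive():
+#       p.terminate()  # treat as a failed/hanging query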
+
+
+def is_memgraph_alive(reason=None):
+    # NOTE: This is tricky because sometimes the client error is "Can't assign
+    # requested address" while memgraph is still alive.
+    if reason:
+        logger.debug(f"Checking if Memgraph is alive... ({reason})")
+    else:
+        logger.debug("Checking if Memgraph is alive...")
+    if not container_exists("test_text2cypher_queries"):
+        return False
+    probe_query = "RETURN 1;"
+    count = 20
+    while True:
+        if execute_query(probe_query):
+            return True
+        time.sleep(0.1)
+        count -= 1
+        if count == 0:
+            logger.error(
+                "Memgraph did not respond to the probe query after repeated attempts, giving up..."
+            )
+            return False
+
+
+def wait_memgraph():
+    logger.info("Waiting for Memgraph to become available...")
+    while True:
+        if is_memgraph_alive():
+            logger.info("Memgraph is now available")
+            break
+        time.sleep(0.1)
+
+
+def start_and_wait_memgraph():
+    start_memgraph()
+    # NOTE: There is some weird edge case in deleting/starting...
+    time.sleep(1)
+    start_memgraph()
+    wait_memgraph()
+
+
+logger.info("Loading dataset...")
+dataset_path = "neo4j/text2cypher-2025v1"
+dataset = load_dataset(dataset_path)
+all_items_iter = itertools.chain(dataset["train"], dataset["test"])
+logger.info("Dataset loaded")
+
+start_and_wait_memgraph()
+tried_queries = 0
+passed_queries = 0
+failed_queries = 0
+number_of_restarts = 0
+queries_crashing_memgraph_file = "queries_crashing_memgraph_file.cypher"
+
+with open(queries_crashing_memgraph_file, "w") as f:
+    for item in all_items_iter:
+        if not is_memgraph_alive("did previous query crash memgraph?"):
+            logger.debug("The previous query crashed memgraph, restarting...")
+            number_of_restarts += 1
+            start_and_wait_memgraph()
+        else:
+            logger.debug("All good, Memgraph is still alive")
+        query = item["cypher"].replace("\\n", " ")
+        tried_queries += 1
+        logger.info(f"Executing: {query}")
+        # TODO(gitbuda): Some queries have params -> pass relevant params somehow.
+        if execute_query(query):
+            passed_queries += 1
+        else:
+            failed_queries += 1
+            if not is_memgraph_alive():
+                f.write(f"{query};\n")
+
+logger.info("Test Results Summary:")
+logger.info(f"The number of tried queries: {tried_queries}")
+logger.info(f"The number of passed queries: {passed_queries}")
+logger.info(f"The number of failed queries: {failed_queries}")
+logger.info(f"The number of memgraph restarts: {number_of_restarts}")
+logger.info(f"Queries that crashed Memgraph have been written to: {queries_crashing_memgraph_file}")
+logger.info(f"Full log has been written to: {log_file}")
+
+# TODO(gitbuda): Implement a generator based on the schema and execute queries
+# on top of the real data.
diff --git a/smoke-release-testing/text2cypher/requirements.txt b/smoke-release-testing/text2cypher/requirements.txt
new file mode 100644
index 000000000..ff0f8145b
--- /dev/null
+++ b/smoke-release-testing/text2cypher/requirements.txt
@@ -0,0 +1,3 @@
+datasets
+pymgclient==1.4.0
+docker
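+# generate_random_graph.py additionally needs these (see its module docstring):
+faker
+python-dateutil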