Skip to content

Execute text2cypher queries against Memgraph #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions smoke-release-testing/text2cypher/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.log
*.cypher
191 changes: 191 additions & 0 deletions smoke-release-testing/text2cypher/generate_random_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
Random graph generator for the Country / Filing / Entity schema.

Requires:
pip install faker python-dateutil
Outputs:
Cypher statements on stdout – redirect to a file if you like:
        python generate_random_graph.py > load.cypher
"""

import random
from datetime import datetime, timedelta
from dateutil.tz import tzutc
from faker import Faker

# ---------------------------------------------------------------------------
# tunables – change these to make a bigger or smaller graph
NUM_COUNTRIES = 20
NUM_ENTITIES = 150
NUM_FILINGS = 400
SEED = 42  # remove or change for different results
# ---------------------------------------------------------------------------

fake = Faker()
# Seed BOTH randomness sources; previously only `random` was seeded, so the
# Faker-generated names/slugs (and therefore the graph) were not reproducible
# despite SEED being set.
Faker.seed(SEED)
random.seed(SEED)

ISO_CODES = set()  # keep ISO-3 codes unique


def iso3() -> str:
    """Return a unique random three-letter ISO-style code.

    Draws candidate codes until one is found that is not yet recorded in
    the module-level ``ISO_CODES`` set, records it there, and returns it.
    """
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    code = "".join(random.choices(alphabet, k=3))
    while code in ISO_CODES:
        code = "".join(random.choices(alphabet, k=3))
    ISO_CODES.add(code)
    return code


def random_point() -> str:
    """Return a Cypher ``point(...)`` literal at a uniformly random location."""
    latitude = random.uniform(-90, 90)
    longitude = random.uniform(-180, 180)
    return (
        "point({latitude:" + format(latitude, ".6f")
        + ", longitude:" + format(longitude, ".6f")
        + ", srid:4326})"
    )


def dt_between(start: str, end: str) -> datetime:
    """Return a random UTC datetime between two ISO-8601 'Z' strings (inclusive)."""
    lo = datetime.fromisoformat(start.replace("Z", "+00:00"))
    hi = datetime.fromisoformat(end.replace("Z", "+00:00"))
    span_seconds = int((hi - lo).total_seconds())
    return lo + timedelta(seconds=random.randint(0, span_seconds))


def iso_z(dt: datetime) -> str:
    """Return `dt` as an ISO-8601 UTC string with a trailing 'Z'.

    Uses stdlib ``timezone.utc`` instead of the previous third-party
    ``dateutil.tz.tzutc()``; the two behave identically here, so this
    function no longer needs python-dateutil.
    """
    from datetime import timezone  # local import: module deps stay unchanged

    return dt.astimezone(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z")


# ---------------------------------------------------------------------------
# generate Countries
countries = []
for idx in range(NUM_COUNTRIES):
    code = iso3()
    countries.append({
        "id": f"c{idx}",
        "code": code,
        "name": fake.country(),
        "tld": code[:2].upper(),
        "location": random_point(),
    })

# generate Entities (each entity is assigned a random home country)
entities = []
for _ in range(NUM_ENTITIES):
    home = random.choice(countries)
    entities.append({
        "id": fake.slug(),
        "name": fake.company().rstrip(","),
        "country": home["code"],
        "location": random_point(),
    })

# generate Filings
# NOTE: dict key order below determines the property order emitted by props(),
# and the sequence of random/faker calls fixes the seeded output — keep both.
filings = []
for i in range(NUM_FILINGS):
    # `end` is drawn starting from `begin`, so every filing has begin <= end.
    begin_dt = dt_between("2000-02-08T00:00:00Z", "2017-09-05T00:00:00Z")
    end_dt = dt_between(iso_z(begin_dt), "2017-11-03T00:00:00Z")

    origin_bank = fake.company().rstrip(",")
    bene_bank = fake.company().rstrip(",")
    origin = random.choice(countries)
    bene = random.choice(countries)

    filings.append(
        {
            # fake.unique guarantees distinct filing ids across the run
            "id": str(fake.unique.random_int(min=1, max=9_999_999)),
            "begin": iso_z(begin_dt),
            "end": iso_z(end_dt),
            # the same instants in three different string formats
            "begin_date_format": begin_dt.strftime("%Y-%m-%dT%H:%M:%SZ"),
            "end_date_format": end_dt.strftime("%Y-%m-%dT%H:%M:%SZ"),
            "begin_date": begin_dt.strftime("%b %d, %Y"),
            "end_date": end_dt.strftime("%b %d, %Y"),
            "originator_bank": origin_bank,
            "originator_bank_id": fake.slug(),
            "originator_bank_country": origin["name"],
            "originator_iso": origin["code"],
            # lat/lng kept as strings — presumably mirroring the source
            # dataset's schema; confirm before switching to floats.
            "origin_lat": f"{random.uniform(-90, 90):.4f}",
            "origin_lng": f"{random.uniform(-180, 180):.4f}",
            "beneficiary_bank": bene_bank,
            "beneficiary_bank_id": fake.slug(),
            "beneficiary_bank_country": bene["name"],
            "beneficiary_iso": bene["code"],
            "beneficiary_lat": f"{random.uniform(-90, 90):.4f}",
            "beneficiary_lng": f"{random.uniform(-180, 180):.4f}",
            "amount": round(random.uniform(1.18, 2_721_000_000), 2),
            "number": random.randint(1, 174),
            # choose the entity relationships later
        }
    )

# ---------------------------------------------------------------------------
# utility – Cypher value rendering
def cypher_str(s: str) -> str:
    """Render `s` as a double-quoted Cypher string literal.

    Escapes backslashes first, then double quotes. Previously only `"`
    was escaped, so any backslash in the input produced a malformed (or
    wrongly parsed) Cypher literal.
    """
    return '"' + s.replace("\\", "\\\\").replace('"', '\\"') + '"'


def props(data: dict, exclude=()) -> str:
    """Render `data` as a Cypher property-map literal, skipping keys in `exclude`.

    String values beginning with "point(" are emitted verbatim (they are
    already Cypher point literals); numbers are emitted bare; every other
    value is quoted via cypher_str().
    """
    rendered = []
    for key, value in data.items():
        if key in exclude:
            continue
        is_point_literal = isinstance(value, str) and value.startswith("point(")
        if is_point_literal or isinstance(value, (int, float)):
            rendered.append(f"{key}:{value}")
        else:
            rendered.append(f"{key}:{cypher_str(value)}")
    return "{" + ", ".join(rendered) + "}"


# ---------------------------------------------------------------------------
# emit Cypher
# All nodes are CREATEd first; relationships are then wired up with
# MATCH ... CREATE against the ids emitted above.
print("// ------------------ Countries ------------------")
for c in countries:
    # 'id' is internal bookkeeping only, not part of the Country schema
    print(f"CREATE (:Country {props(c, exclude=['id'])});")

print("\n// ------------------ Entities -------------------")
for e in entities:
    print(f"CREATE (:Entity {props(e)});")

print("\n// ------------------ Filings --------------------")
for f in filings:
    print(f"CREATE (:Filing {props(f)});")

print("\n// ------------------ Relationships --------------")
# NOTE(review): entity ids come from fake.slug() and are not guaranteed
# unique, so a MATCH on id could fan out to several nodes — confirm.
for f in filings:
    # pick entities for the three roles
    filing_id = f["id"]
    originator = random.choice(entities)["id"]
    beneficiary = random.choice(entities)["id"]
    concern = random.choice(entities)["id"]

    # FILED — the filing is filed by its originator entity
    print(f"MATCH (e:Entity {{id:{cypher_str(originator)}}}), (f:Filing {{id:{cypher_str(filing_id)}}}) "
          f"CREATE (e)-[:FILED]->(f);")
    # ORIGINATOR
    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(originator)}}}) "
          f"CREATE (f)-[:ORIGINATOR]->(e);")
    # BENEFITS
    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(beneficiary)}}}) "
          f"CREATE (f)-[:BENEFITS]->(e);")
    # CONCERNS
    print(f"MATCH (f:Filing {{id:{cypher_str(filing_id)}}}), (e:Entity {{id:{cypher_str(concern)}}}) "
          f"CREATE (f)-[:CONCERNS]->(e);")

# Entity-COUNTRY relations
print("\n// --------------- Entity → Country --------------")
for e in entities:
    print(
        f"MATCH (e:Entity {{id:{cypher_str(e['id'])}}}), (c:Country {{code:{cypher_str(e['country'])}}}) "
        f"CREATE (e)-[:COUNTRY]->(c);"
    )

print("// ------------------ Done -----------------------")
170 changes: 170 additions & 0 deletions smoke-release-testing/text2cypher/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import itertools
from os import wait
import time
import logging
from datetime import datetime

from datasets import load_dataset
import docker
import mgclient

# Configure logging: everything goes to a timestamped file at DEBUG level,
# while the console only shows INFO and above.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = f"text2cypher_test_{timestamp}.log"

formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

file_handler = logging.FileHandler(log_file)
console_handler = logging.StreamHandler()
for handler, level in ((file_handler, logging.DEBUG), (console_handler, logging.INFO)):
    handler.setFormatter(formatter)
    handler.setLevel(level)

# Root logger fans out to both handlers.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
logger.addHandler(console_handler)

# Under Mac -> Settings > Advanced > Allow the default Docker socket to be used
# (requires password).
DOCKER_CLIENT = docker.from_env()


def container_exists(name):
    """Return True iff a Docker container called `name` exists (any state)."""
    try:
        DOCKER_CLIENT.containers.get(name)
    except docker.errors.NotFound:
        logger.debug(f"Container {name} does not exist")
        return False
    except Exception as e:
        # Best-effort probe: any other client error is treated as "not there".
        logger.debug(f"An error occurred while checking container {name}: {e}")
        return False
    logger.debug(f"Container {name} already exists")
    return True


def start_memgraph():
    """Launch the Memgraph test container; no-op if it already exists."""
    if container_exists("test_text2cypher_queries"):
        return
    logger.info("Starting Memgraph container...")
    try:
        DOCKER_CLIENT.containers.run(
            "memgraph/memgraph:3.3.0",
            detach=True,
            auto_remove=True,  # container disappears once it stops/crashes
            name="test_text2cypher_queries",
            ports={"7687": 7687},
            command=["--telemetry-enabled=False"],
        )
    except Exception as e:
        logger.error(f"Failed to start Memgraph container: {e}")
        raise
    logger.info("Memgraph container started successfully")


# TODO(gitbuda): Add a timeout to the query execution.
def execute_query(query):
    """Run `query` against the local Memgraph; return True on success.

    Failures are logged and reported as False. The connection and cursor
    are now closed in `finally` blocks — previously an exception raised
    by execute()/fetchall() leaked the open connection.
    """
    conn = None
    try:
        conn = mgclient.connect(host="127.0.0.1", port=7687)
        conn.autocommit = True
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            cursor.fetchall()
        finally:
            cursor.close()
        return True
    except Exception as e:
        logger.error(f"Failed to execute query: {e}")
        return False
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass  # closing a broken connection must not mask the result


def is_memgraph_alive(reason=None):
    """Probe Memgraph with `RETURN 1;`, retrying up to 20 times (~2s total)."""
    # NOTE: This is tricky because sometimes the client error is "Can't assign
    # requested address" while memgraph is still alive.
    suffix = f" ({reason})" if reason else ""
    logger.debug(f"Checking if Memgraph is alive...{suffix}")
    if not container_exists("test_text2cypher_queries"):
        return False
    for _ in range(20):
        if execute_query("RETURN 1;"):
            return True
        time.sleep(0.1)
    logger.error(
        "I couldn't get back from memgraph for a long time, exiting..."
    )
    return False


def wait_memgraph():
    """Block until Memgraph answers the liveness probe."""
    logger.info("Waiting for Memgraph to become available...")
    while not is_memgraph_alive():
        time.sleep(0.1)
    logger.info("Memgraph is now available")


def start_and_wait_memgraph():
    """Start the Memgraph container and block until it accepts queries."""
    start_memgraph()
    # NOTE: There is some weird edgecase in deleting/starting... hence the
    # second start attempt after a short pause (start_memgraph is a no-op
    # when the container already exists).
    time.sleep(1)
    start_memgraph()
    wait_memgraph()


# ---------------------------------------------------------------------------
# Main: replay every text2cypher dataset query against a local Memgraph and
# record which ones fail or crash the server.

logger.info("Loading dataset...")
dataset_path = "neo4j/text2cypher-2025v1"
dataset = load_dataset(dataset_path)
all_items_iter = itertools.chain(dataset["train"], dataset["test"])
logger.info("Dataset loaded")

start_and_wait_memgraph()
tried_queries = 0
passed_queries = 0
failed_queries = 0
number_of_restarts = 0
queries_crashing_memgraph_file = "queries_crashing_memgraph_file.cypher"

with open(queries_crashing_memgraph_file, "w") as f:
    for item in all_items_iter:
        # If the previous query took the server down, bring it back up
        # before continuing with the next one.
        if not is_memgraph_alive("did previous query crash memgraph?"):
            logger.debug("The previous query crashed memgraph, restarting...")
            number_of_restarts += 1
            start_and_wait_memgraph()
        else:
            logger.debug("All good, Memgraph is still alive")
        query = item["cypher"].replace("\\n", " ")
        tried_queries += 1
        logger.info(f"Executing: {query}")
        # TODO(gitbuda): Some queries have params -> pass relevant params somehow.
        if execute_query(query):
            passed_queries += 1
        else:
            failed_queries += 1
            # Only queries that actually crashed the server (not mere query
            # errors) are written to the crash file.
            if not is_memgraph_alive():
                f.write(f"{query};\n")

logger.info("Test Results Summary:")
logger.info(f"The number of tried queries: {tried_queries}")
logger.info(f"The number of passed queries: {passed_queries}")
logger.info(f"The number of failed queries: {failed_queries}")
logger.info(f"The number of memgraph restarts: {number_of_restarts}")
logger.info(f"Queries that crashed Memgraph have been written to: {queries_crashing_memgraph_file}")
logger.info(f"Full log has been written to: {log_file}")

# TODO(gitbuda): Implement a generator based on the schema and execute queries
# on top of the real data.
3 changes: 3 additions & 0 deletions smoke-release-testing/text2cypher/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
datasets
docker
faker
pymgclient==1.4.0
python-dateutil