-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathneo4jwriter.py
35 lines (27 loc) · 1.59 KB
/
neo4jwriter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import List, Any, Dict, Iterator, Callable, Union
import os
from neo4j import GraphDatabase, Transaction
class Neo4jWriter:
def __init__(self, neo4j_url: str = os.environ.get("NEO4J_URI"),
neo4j_user: str = os.environ.get("NEO4J_USER"),
neo4j_password: str = os.environ.get("NEO4J_PASSWORD"),
database: str = os.environ.get("NEO4J_DATABASE")):
self.driver = GraphDatabase.driver(neo4j_url, auth=(neo4j_user, neo4j_password))
self.database = database
def batch_write(self, cypher_query: str, params: List[Dict[str, Any]], batch_size: int = 10000):
with self.driver.session(database=self.database) as session:
for batch in Neo4jWriter._batch_parameters(params, batch_size):
packaged_params = {'params': batch}
tx_function = lambda tx: self.neo4j_tx_function(tx=tx, cypher_query=cypher_query,
params=packaged_params)
session.execute_write(tx_function)
def neo4j_tx_function(self, tx: Transaction, params: List[Dict[str, Any]], cypher_query: str) -> None:
tx.run(cypher_query, parameters=params)
def build_indexes(self, index_list=List[str]):
for index in index_list:
tx_function = lambda tx: self.neo4j_tx_function(tx, [], index)
self.session.execute_write(tx_function)
@staticmethod
def _batch_parameters(lst: List[Any], batch_size: int) -> Iterator[List[Any]]:
for i in range(0, len(lst), batch_size):
yield lst[i:i + batch_size]