forked from openai/chatgpt-retrieval-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_json.py
147 lines (127 loc) · 5.24 KB
/
process_json.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import uuid
import json
import argparse
import asyncio
from loguru import logger
from models.models import Document, DocumentMetadata
from datastore.datastore import DataStore
from datastore.factory import get_datastore
from services.extract_metadata import extract_metadata_from_document
from services.pii_detection import screen_text_for_pii
# Number of documents handed to a single datastore.upsert() call by
# process_json_dump; it is the step of the slicing loop over the document list.
DOCUMENT_UPSERT_BATCH_SIZE = 50
async def process_json_dump(
    filepath: str,
    datastore: DataStore,
    custom_metadata: dict,
    screen_for_pii: bool,
    extract_metadata: bool,
):
    """Load a JSON dump, turn its items into documents and upsert them.

    The file is parsed with ``json.load`` and iterated item by item. For each
    item the keys ``id``, ``text``, ``source``, ``source_id``, ``url``,
    ``created_at`` and ``author`` are read with ``.get`` (missing keys become
    ``None``). Items without text are ignored. Items that raise while being
    processed, or in which PII is detected, are collected and logged at the
    end. The resulting documents are upserted in slices of
    ``DOCUMENT_UPSERT_BATCH_SIZE``.

    Args:
        filepath: Path of the JSON file to read.
        datastore: Object whose awaitable ``upsert`` receives each batch.
        custom_metadata: Key/value pairs written onto the metadata object for
            every key that already exists on it as an attribute.
        screen_for_pii: When true, ``screen_text_for_pii`` is called on the
            text and the item is skipped if it reports PII.
        extract_metadata: When true, ``extract_metadata_from_document`` is
            called and its result is used to build the document metadata.

    Returns:
        None. Results are written to the datastore and to the log.
    """
    # JSON interchange is UTF-8; do not depend on the platform default encoding.
    with open(filepath, encoding="utf-8") as json_file:
        data = json.load(json_file)

    documents = []
    skipped_items = []

    # iterate over the data and create document objects
    for index, item in enumerate(data):
        # Report progress per item seen. Keying this on len(documents) would
        # repeat the same message for every skipped item while the count
        # happens to sit on a multiple of 20.
        if index % 20 == 0:
            logger.info(f"Processed {len(documents)} documents")

        # Best-effort per item: any failure is logged and the item is recorded
        # as skipped so one bad entry does not abort the whole import.
        try:
            # missing keys default to None
            doc_id = item.get("id", None)
            text = item.get("text", None)
            source = item.get("source", None)
            source_id = item.get("source_id", None)
            url = item.get("url", None)
            created_at = item.get("created_at", None)
            author = item.get("author", None)

            if not text:
                logger.info("No document text, skipping...")
                continue

            metadata = DocumentMetadata(
                source=source,
                source_id=source_id,
                url=url,
                created_at=created_at,
                author=author,
            )
            # Interpolate the value into the message itself; a trailing
            # positional argument is only a format parameter for loguru and is
            # dropped when the message has no "{}" placeholder.
            logger.info(f"metadata: {metadata!s}")

            # update metadata with custom values (only for existing attributes)
            for key, value in custom_metadata.items():
                if hasattr(metadata, key):
                    setattr(metadata, key, value)

            # screen for pii if requested
            if screen_for_pii:
                pii_detected = screen_text_for_pii(text)
                # if pii detected, log it and skip the document
                if pii_detected:
                    logger.info("PII detected in document, skipping")
                    skipped_items.append(item)
                    continue

            # extract metadata if requested
            if extract_metadata:
                extracted_metadata = extract_metadata_from_document(
                    f"Text: {text}; Metadata: {str(metadata)}"
                )
                # NOTE(review): this replaces the metadata built above,
                # including any custom_metadata overrides, with whatever the
                # extractor returns - confirm that is intended.
                metadata = DocumentMetadata(**extracted_metadata)

            # use the item's id when present, otherwise a random UUID
            document = Document(
                id=doc_id or str(uuid.uuid4()),
                text=text,
                metadata=metadata,
            )
            documents.append(document)
        except Exception as e:
            # log the error and continue with the next item
            logger.error(f"Error processing {item}: {e}")
            skipped_items.append(item)

    # do this in batches, the upsert method already batches documents but this allows
    # us to add more descriptive logging
    for i in range(0, len(documents), DOCUMENT_UPSERT_BATCH_SIZE):
        batch_documents = documents[i : i + DOCUMENT_UPSERT_BATCH_SIZE]
        logger.info(f"Upserting batch of {len(batch_documents)} documents, batch {i}")
        # Log the batch being sent (not the full list on every iteration), and
        # embed it in the message so it is actually emitted.
        logger.info(f"documents: {batch_documents}")
        await datastore.upsert(batch_documents)

    # log the skipped items
    logger.info(f"Skipped {len(skipped_items)} items due to errors or PII detection")
    for item in skipped_items:
        logger.info(item)
def _parse_bool_flag(value):
    """Convert a command-line value such as "true" or "False" into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value - including "False" - becomes ``True``. This parser maps
    the usual spellings explicitly instead.

    Args:
        value: The raw argument string (a ``bool`` is passed through as-is).

    Returns:
        The parsed boolean. The empty string is treated as ``False``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


async def main():
    """Parse the command line and run the JSON dump import.

    Reads ``--filepath``, ``--custom_metadata`` (a JSON object given as a
    string), ``--screen_for_pii`` and ``--extract_metadata``, obtains a
    datastore from ``get_datastore`` and awaits ``process_json_dump``.

    The two boolean options accept an explicit value (``--screen_for_pii
    False``) or may be given bare (``--screen_for_pii``), which means true.
    Invalid option values make the parser print an error and exit.
    """
    # parse the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--filepath", required=True, help="The path to the json dump")
    parser.add_argument(
        "--custom_metadata",
        default="{}",
        help="A JSON string of key-value pairs to update the metadata of the documents",
    )
    parser.add_argument(
        "--screen_for_pii",
        default=False,
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        help="A boolean flag to indicate whether to try the PII detection function (using a language model)",
    )
    parser.add_argument(
        "--extract_metadata",
        default=False,
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        help="A boolean flag to indicate whether to try to extract metadata from the document (using a language model)",
    )
    args = parser.parse_args()

    # get the arguments
    filepath = args.filepath
    try:
        custom_metadata = json.loads(args.custom_metadata)
    except json.JSONDecodeError as err:
        parser.error(f"--custom_metadata is not valid JSON: {err}")
    # process_json_dump iterates custom_metadata.items(); reject anything that
    # is not a JSON object here rather than failing on every single item later.
    if not isinstance(custom_metadata, dict):
        parser.error("--custom_metadata must be a JSON object of key-value pairs")
    screen_for_pii = args.screen_for_pii
    extract_metadata = args.extract_metadata

    # create the datastore once and reuse it for the whole import
    datastore = await get_datastore()

    # process the json dump
    await process_json_dump(
        filepath, datastore, custom_metadata, screen_for_pii, extract_metadata
    )
if __name__ == "__main__":
    # Script entry point: build the top-level coroutine and drive it to
    # completion on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)