-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathweaviate_indexer.py
134 lines (113 loc) · 4.39 KB
/
weaviate_indexer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import sys
import logging
import warnings
from typing import Any, Dict, List
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import PromptHelper, ServiceContext, GPTVectorStoreIndex
from llama_index.vector_stores.weaviate import WeaviateVectorStore
from llama_index.core import StorageContext
import weaviate
# Suppress ResourceWarning noise (e.g. unclosed client sockets) during indexing runs.
warnings.simplefilter("ignore", ResourceWarning)

# Log to stdout so progress is visible when run as a command-line tool.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
EMBED_MODEL_NAME = "text-embedding-3-large"  # OpenAI embedding model used for indexing (PEP 8: spaces around "=")
CONTEXT_WINDOW = 4096       # max tokens available to the model context
NUM_OUTPUT = 256            # tokens reserved for model output
CHUNK_OVERLAP_RATIO = 0.1   # fraction of overlap between adjacent text chunks
# Copied from weaviate_indexer to:
# 1) upgrade string->text for proper tokenization
# 2) set tokenization which defaults to whitespace for some reason
# 3) disable indexes on metadata json
# Property definitions for the Weaviate class created by create_schema().
NODE_SCHEMA: List[Dict] = [
    {
        # Back-reference to the source document this node was chunked from.
        "name": "ref_doc_id",
        "dataType": ["text"],
        "description": "The ref_doc_id of the Node"
    },
    {
        # Serialized node payload; filtering/searching disabled since it is opaque JSON.
        "name": "_node_content",
        "dataType": ["text"],
        "description": "Node content (in serialized JSON)",
        "indexFilterable": False,
        "indexSearchable": False,
        "tokenization": 'word'
    },
    {
        # The chunk text itself — word tokenization for full-text search.
        "name": "text",
        "dataType": ["text"],
        "description": "Full text of the node",
        "tokenization": 'word'
    },
    {
        "name": "title",
        "dataType": ["text"],
        "description": "The title of the document",
        "tokenization": 'word'
    },
    {
        # 'field' tokenization keeps the URL as a single token (exact-match only).
        "name": "link",
        "dataType": ["text"],
        "description": "HTTP link to the source document",
        "tokenization": 'field'
    },
    {
        # 'field' tokenization: the source label is matched as one exact value.
        "name": "source",
        "dataType": ["text"],
        "description": "Data source for the source document",
        "tokenization": 'field'
    }
]
def create_schema(client: Any, class_prefix: str) -> None:
"""Create schema."""
# first check if schema exists
schema = client.schema.get()
classes = schema["classes"]
existing_class_names = {c["class"] for c in classes}
# if schema already exists, don't create
class_name = _class_name(class_prefix)
if class_name in existing_class_names:
return
properties = NODE_SCHEMA
class_obj = {
"class": _class_name(class_prefix), # <= note the capital "A".
"description": f"Class for {class_name}",
"properties": properties,
}
client.schema.create_class(class_obj)
def _class_name(class_prefix: str) -> str:
"""Return class name."""
return f"{class_prefix}_Node"
class Indexer():
    """Embeds documents with OpenAI and persists the vectors into a Weaviate class."""

    def __init__(self, weaviate_url, class_prefix, delete_database):
        # Base URL of the Weaviate instance, e.g. "http://localhost:8080".
        self.weaviate_url = weaviate_url
        # Prefix for the Weaviate class name ("<prefix>_Node").
        self.class_prefix = class_prefix
        # When True, drop and recreate the class before indexing.
        self.delete_database = delete_database

    def index(self, documents):
        """Embed *documents* and store them in Weaviate.

        Exits the process (status 1) when the Weaviate instance is not reachable.
        """
        # Connect to Weaviate database
        client = weaviate.Client(self.weaviate_url)
        if not client.is_live():
            logger.error(f"Weaviate is not live at {self.weaviate_url}")
            sys.exit(1)
        # BUG FIX: the original checked is_live() twice; this second probe is the
        # readiness check that matches its "not ready" error message.
        if not client.is_ready():
            logger.error(f"Weaviate is not ready at {self.weaviate_url}")
            sys.exit(1)
        logger.info(f"Connected to Weaviate at {self.weaviate_url} (Version {client.get_meta()['version']})")

        # Delete existing data in Weaviate
        class_prefix = self.class_prefix
        # Compute the class name unconditionally so the "Creating ..." log line
        # below cannot hit a NameError when delete_database is False.
        class_name = _class_name(class_prefix)
        if self.delete_database:
            logger.warning(f"Deleting {class_name} class in Weaviate")
            client.schema.delete_class(class_name)
        logger.info(f"Creating {class_name} class in Weaviate")
        create_schema(client, class_prefix)

        # Create LLM embedding model
        embed_model = OpenAIEmbedding(embed_batch_size=20, model=EMBED_MODEL_NAME)
        prompt_helper = PromptHelper(CONTEXT_WINDOW, NUM_OUTPUT, CHUNK_OVERLAP_RATIO)
        service_context = ServiceContext.from_defaults(embed_model=embed_model, prompt_helper=prompt_helper)

        # Embed the documents and persist the embeddings into Weaviate
        logger.info("Creating GPT vector store index")
        vector_store = WeaviateVectorStore(weaviate_client=client, class_prefix=class_prefix)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        GPTVectorStoreIndex.from_documents(documents, storage_context=storage_context, service_context=service_context)
        logger.info(f"Completed indexing into '{class_prefix}_Node'")