Merge pull request #39 from StampyAI/update-sql-and-pinecone
Update sql and pinecone
Showing 10 changed files with 434 additions and 574 deletions.
@@ -0,0 +1,3 @@
OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
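These are placeholder values. Since `dataset/settings.py` below reads the Pinecone keys via `os.environ` at import time, they must be in the process environment first. A minimal sketch of one way to do that, assuming python-dotenv (the repository may load them differently):

```python
# Sketch: load .env into the environment before importing dataset.settings.
# Assumes python-dotenv (pip install python-dotenv); the repo may not use it.
import os

from dotenv import load_dotenv

load_dotenv()  # reads KEY=value pairs from ./.env into os.environ

# dataset.settings does os.environ["PINECONE_API_KEY"] at import time,
# so the variables must be set before this import runs.
from dataset import settings
```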
This file was deleted; its contents are not shown in this view.
@@ -0,0 +1,100 @@
# dataset/pinecone_db_handler.py

import os
import json
import pinecone

from .settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_ENTRIES, PINECONE_API_KEY, PINECONE_ENVIRONMENT

import logging
logger = logging.getLogger(__name__)


class PineconeDB:
    def __init__(
        self,
        create_index: bool = False,
    ):
        self.index_name = PINECONE_INDEX_NAME

        pinecone.init(
            api_key=PINECONE_API_KEY,
            environment=PINECONE_ENVIRONMENT,
        )

        if create_index:
            self.create_index()

        self.index = pinecone.Index(index_name=self.index_name)

    def __str__(self) -> str:
        index_stats_response = self.index.describe_index_stats()
        return f"{self.index_name}:\n{json.dumps(index_stats_response, indent=4)}"

    def upsert_entry(self, entry, chunks, embeddings, upsert_size=100):
        # Chunk IDs take the form "<entry_id>_<zero-padded index>", so every
        # vector carries a stable, entry-scoped identifier.
        self.index.upsert(
            vectors=list(
                zip(
                    [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))],
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        } for chunk in chunks
                    ]
                )
            ),
            batch_size=upsert_size
        )

    def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100):
        # NOTE: the nested comprehension builds one metadata dict per
        # (entry, chunk) pair, and zip() silently truncates the pairing to
        # len(chunks_ids_batch); metadata only lines up with the chunk IDs
        # when that cross-product ordering matches them.
        self.index.upsert(
            vectors=list(
                zip(
                    chunks_ids_batch,
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        }
                        for entry in entries_batch
                        for chunk in chunks_batch
                    ]
                )
            ),
            batch_size=upsert_size
        )

    def delete_entry(self, id):
        self.index.delete(
            filter={"entry_id": {"$eq": id}}
        )

    def delete_entries(self, ids):
        self.index.delete(
            filter={"entry_id": {"$in": ids}}
        )

    def create_index(self, replace_current_index: bool = True):
        if replace_current_index:
            self.delete_index()

        pinecone.create_index(
            name=self.index_name,
            dimension=PINECONE_VALUES_DIMS,
            metric=PINECONE_METRIC,
            metadata_config={"indexed": PINECONE_METADATA_ENTRIES}
        )

    def delete_index(self):
        if self.index_name in pinecone.list_indexes():
            logger.info(f"Deleting index '{self.index_name}'.")
            pinecone.delete_index(self.index_name)
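A usage sketch for the class above (entry fields and embedding shape are invented; it assumes the index already exists and that embeddings arrive as a numpy array, which matches the `embeddings.tolist()` call):

```python
# Illustrative only: invented entry data, random 1536-dim vectors.
import numpy as np

from dataset.pinecone_db_handler import PineconeDB

db = PineconeDB(create_index=False)  # connects to the existing index

entry = {
    "id": "abc123",
    "source": "arxiv",
    "title": "An Example Paper",
    "authors": "Jane Doe, John Smith",
}
chunks = ["first chunk of text", "second chunk of text"]
embeddings = np.random.rand(len(chunks), 1536)  # one vector per chunk

db.upsert_entry(entry, chunks, embeddings)
print(db)  # index name plus describe_index_stats() output

db.delete_entry("abc123")  # removes all vectors whose entry_id matches
```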
@@ -1,12 +1,31 @@
+# dataset/settings.py
+
 import os
+import torch
 from pathlib import Path

-EMBEDDING_MODEL = "text-embedding-ada-002"
-COMPLETIONS_MODEL = "gpt-3.5-turbo"
+### FILE PATHS ###
+current_file_path = Path(__file__).resolve()
+SQL_DB_PATH = str(current_file_path.parent / 'data' / 'ARD.db')
+
-LEN_EMBEDDINGS = 1536
-MAX_LEN_PROMPT = 4095 # This may be 8191, unsure.
+### DATASET ###
+ARD_DATASET_NAME = "StampyAI/alignment-research-dataset"
+
-current_file_path = Path(__file__).resolve()
-PATH_TO_RAW_DATA = str(current_file_path.parent / 'data' / 'alignment_texts.jsonl')
-PATH_TO_DATASET_PKL = str(current_file_path.parent / 'data' / 'dataset.pkl')
-PATH_TO_DATASET_DICT_PKL = str(current_file_path.parent / 'data' / 'dataset_dict.pkl')
+### EMBEDDINGS ###
+USE_OPENAI_EMBEDDINGS = False
+OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
+EMBEDDINGS_DIMS = 1536
+OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
+SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+### PINECONE ###
+PINECONE_INDEX_NAME = "stampy-chat-embeddings-test"
+PINECONE_VALUES_DIMS = EMBEDDINGS_DIMS
+PINECONE_METRIC = "cosine"
+PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]
+PINECONE_API_KEY = os.environ["PINECONE_API_KEY"]
+PINECONE_ENVIRONMENT = os.environ["PINECONE_ENVIRONMENT"]
+
+### MISCELLANEOUS ###
+MAX_NUM_AUTHORS_IN_SIGNATURE = 3
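This diff does not show where `USE_OPENAI_EMBEDDINGS` is consumed, but the flag suggests a switch between the OpenAI API and a local sentence-transformers model along these lines (a hypothetical sketch using the pre-1.0 `openai` client to match this PR's era; the function name is invented):

```python
# Hypothetical consumer of the embedding settings; not part of this PR.
from dataset.settings import (
    USE_OPENAI_EMBEDDINGS,
    OPENAI_EMBEDDINGS_MODEL,
    SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
    DEVICE,
)

def get_embedder():
    """Return a callable mapping a list of texts to embedding vectors."""
    if USE_OPENAI_EMBEDDINGS:
        import openai  # pre-1.0 openai client API

        def embed(texts):
            resp = openai.Embedding.create(model=OPENAI_EMBEDDINGS_MODEL, input=texts)
            return [item["embedding"] for item in resp["data"]]

        return embed

    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, device=DEVICE)
    return lambda texts: model.encode(texts)  # returns a numpy array
```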
@@ -0,0 +1,104 @@
# dataset/sql_db_handler.py

from typing import List, Dict, Union
import sqlite3

from .settings import SQL_DB_PATH

import logging
logger = logging.getLogger(__name__)


class SQLDB:
    def __init__(self):
        self.db_name = SQL_DB_PATH

        self.create_tables()

    def create_tables(self, reset: bool = False):
        with sqlite3.connect(self.db_name) as conn:
            cursor = conn.cursor()
            try:
                if reset:
                    # Drop the tables if reset is True
                    cursor.execute("DROP TABLE IF EXISTS entry_database")
                    cursor.execute("DROP TABLE IF EXISTS chunk_database")

                # Create entry table
                query = """
                    CREATE TABLE IF NOT EXISTS entry_database (
                        id TEXT PRIMARY KEY,
                        source TEXT,
                        title TEXT,
                        text TEXT,
                        url TEXT,
                        date_published TEXT,
                        authors TEXT
                    )
                """
                cursor.execute(query)

                # Create chunk table
                query = """
                    CREATE TABLE IF NOT EXISTS chunk_database (
                        id TEXT PRIMARY KEY,
                        text TEXT,
                        entry_id TEXT,
                        FOREIGN KEY (entry_id) REFERENCES entry_database(id)
                    )
                """
                cursor.execute(query)

            except sqlite3.Error as e:
                logger.error(f"The error '{e}' occurred.")

    def upsert_entry(self, entry: Dict[str, Union[str, list]]) -> bool:
        with sqlite3.connect(self.db_name) as conn:
            cursor = conn.cursor()
            try:
                # Fetch existing data
                cursor.execute("SELECT * FROM entry_database WHERE id=?", (entry['id'],))
                existing_entry = cursor.fetchone()

                new_entry = (
                    entry['id'],
                    entry['source'],
                    entry['title'],
                    entry['text'],
                    entry['url'],
                    entry['date_published'],
                    ', '.join(entry['authors'])
                )

                # Only write when something actually changed; the return value
                # tells the caller whether an insert/replace happened.
                if existing_entry != new_entry:
                    query = """
                        INSERT OR REPLACE INTO entry_database
                        (id, source, title, text, url, date_published, authors)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                    """
                    cursor.execute(query, new_entry)
                    return True
                else:
                    return False

            except sqlite3.Error as e:
                logger.error(f"The error '{e}' occurred.")
                return False

            finally:
                conn.commit()

    def upsert_chunks(self, chunks_ids_batch: List[str], chunks_batch: List[str]) -> bool:
        with sqlite3.connect(self.db_name) as conn:
            cursor = conn.cursor()
            try:
                for chunk_id, chunk in zip(chunks_ids_batch, chunks_batch):
                    cursor.execute("""
                        INSERT OR REPLACE INTO chunk_database
                        (id, text)
                        VALUES (?, ?)
                    """, (chunk_id, chunk))
                return True  # added so the method matches its -> bool annotation
            except sqlite3.Error as e:
                logger.error(f"The error '{e}' occurred.")
                return False
            finally:
                conn.commit()
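A usage sketch for `SQLDB` (values invented; `authors` is a list because `upsert_entry` joins it with `', '`):

```python
# Illustrative only: the entry dict mirrors entry_database's columns.
from dataset.sql_db_handler import SQLDB

db = SQLDB()  # opens data/ARD.db and creates both tables if missing

entry = {
    "id": "abc123",
    "source": "arxiv",
    "title": "An Example Paper",
    "text": "Full text of the entry...",
    "url": "https://example.org/paper",
    "date_published": "2023-01-01",
    "authors": ["Jane Doe", "John Smith"],  # joined to "Jane Doe, John Smith"
}

changed = db.upsert_entry(entry)  # True if the row was inserted or updated

db.upsert_chunks(
    chunks_ids_batch=["abc123_000000", "abc123_000001"],
    chunks_batch=["first chunk of text", "second chunk of text"],
)
```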
@@ -1,3 +1,5 @@
# dataset/text_splitter.py

import re
from typing import List
import tiktoken
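The rest of `text_splitter.py` is truncated in this view. For context, a token-aware splitter built on `tiktoken` typically looks something like this sketch (hypothetical, not the repository's implementation):

```python
# Hypothetical sketch of token-window splitting; not this repo's code.
from typing import List
import tiktoken

def split_into_chunks(text: str, max_tokens: int = 200) -> List[str]:
    enc = tiktoken.get_encoding("cl100k_base")  # tokenizer used by ada-002
    tokens = enc.encode(text)
    # Decode fixed-size windows of tokens back into text chunks.
    return [
        enc.decode(tokens[i:i + max_tokens])
        for i in range(0, len(tokens), max_tokens)
    ]
```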