Commit

Merge pull request #39 from StampyAI/update-sql-and-pinecone
Update sql and pinecone
henri123lemoine authored Jul 5, 2023
2 parents bae0fcb + 511f231 commit d9abc7c
Showing 10 changed files with 434 additions and 574 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -129,6 +129,8 @@ dmypy.json
.pyre/

# Other
*test.py
*test.ipynb
*alignment_texts.jsonl
*config.py
*.DS_Store
@@ -145,3 +147,8 @@ api/dataset_big.pkl
api/dataset_300.pkl

api/.env.backup

src/dataset_tests.ipynb
src/ARD_LangChain_QA_Chat.ipynb

src/dataset/data/*
3 changes: 3 additions & 0 deletions src/.env.example
@@ -0,0 +1,3 @@
OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
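
These keys are read from the environment at import time by the new settings.py (os.environ["PINECONE_API_KEY"], os.environ["PINECONE_ENVIRONMENT"]), so they need to be set before the dataset modules are imported. A minimal sketch of one way to do that follows; loading the .env file with python-dotenv and the import path are assumptions for illustration, not something this PR adds:

# hypothetical setup snippet; python-dotenv and the "dataset" import path are assumptions
from dotenv import load_dotenv

load_dotenv("src/.env")  # populates os.environ with OPENAI_API_KEY, PINECONE_API_KEY, PINECONE_ENVIRONMENT

from dataset import settings  # must come after load_dotenv, since settings.py reads os.environ on import
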
401 changes: 0 additions & 401 deletions src/dataset/create_dataset.py

This file was deleted.

100 changes: 100 additions & 0 deletions src/dataset/pinecone_db_handler.py
@@ -0,0 +1,100 @@
# dataset/pinecone_db_handler.py

import os
import json
import pinecone

from .settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_ENTRIES, PINECONE_API_KEY, PINECONE_ENVIRONMENT

import logging
logger = logging.getLogger(__name__)


class PineconeDB:
def __init__(
self,
create_index: bool = False,
):
self.index_name = PINECONE_INDEX_NAME

pinecone.init(
api_key = PINECONE_API_KEY,
environment = PINECONE_ENVIRONMENT,
)

if create_index:
self.create_index()

self.index = pinecone.Index(index_name=self.index_name)

def __str__(self) -> str:
index_stats_response = self.index.describe_index_stats()
return f"{self.index_name}:\n{json.dumps(index_stats_response, indent=4)}"

def upsert_entry(self, entry, chunks, embeddings, upsert_size=100):
self.index.upsert(
vectors=list(
zip(
[f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))],
embeddings.tolist(),
[
{
'entry_id': entry['id'],
'source': entry['source'],
'title': entry['title'],
'authors': entry['authors'],
'text': chunk,
} for chunk in chunks
]
)
),
batch_size=upsert_size
)

def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100):
self.index.upsert(
vectors=list(
zip(
chunks_ids_batch,
embeddings.tolist(),
[
{
'entry_id': entry['id'],
'source': entry['source'],
'title': entry['title'],
'authors': entry['authors'],
'text': chunk,
}
for entry in entries_batch
for chunk in chunks_batch
]
)
),
batch_size=upsert_size
)

def delete_entry(self, id):
self.index.delete(
filter={"entry_id": {"$eq": id}}
)

def delete_entries(self, ids):
self.index.delete(
filter={"entry_id": {"$in": ids}}
)

def create_index(self, replace_current_index: bool = True):
if replace_current_index:
self.delete_index()

pinecone.create_index(
name=self.index_name,
dimension=PINECONE_VALUES_DIMS,
metric=PINECONE_METRIC,
metadata_config = {"indexed": PINECONE_METADATA_ENTRIES}
)

def delete_index(self):
if self.index_name in pinecone.list_indexes():
logger.info(f"Deleting index '{self.index_name}'.")
pinecone.delete_index(self.index_name)
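
A rough usage sketch of the new handler; the entry fields, chunks, and embedding values below are illustrative placeholders, not taken from this PR. One vector is stored per chunk, with IDs derived from the entry ID and metadata limited to the fields in PINECONE_METADATA_ENTRIES:

# illustrative usage; assumes the PINECONE_* settings and environment variables are configured
import numpy as np
from dataset.pinecone_db_handler import PineconeDB

db = PineconeDB(create_index=False)  # pass create_index=True to (re)build the index first

entry = {"id": "abc123", "source": "arxiv", "title": "Example paper", "authors": "A. Author"}
chunks = ["first chunk of text", "second chunk of text"]
embeddings = np.random.rand(len(chunks), 1536)  # placeholder; real vectors come from the embedding model

db.upsert_entry(entry, chunks, embeddings)  # vector IDs become "abc123_000000", "abc123_000001"
db.delete_entry("abc123")                   # removes every vector whose entry_id metadata matches
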
35 changes: 27 additions & 8 deletions src/dataset/settings.py
@@ -1,12 +1,31 @@
# dataset/settings.py

import os
import torch
from pathlib import Path

EMBEDDING_MODEL = "text-embedding-ada-002"
COMPLETIONS_MODEL = "gpt-3.5-turbo"
### FILE PATHS ###
current_file_path = Path(__file__).resolve()
SQL_DB_PATH = str(current_file_path.parent / 'data' / 'ARD.db')

LEN_EMBEDDINGS = 1536
MAX_LEN_PROMPT = 4095 # This may be 8191, unsure.
### DATASET ###
ARD_DATASET_NAME = "StampyAI/alignment-research-dataset"

current_file_path = Path(__file__).resolve()
PATH_TO_RAW_DATA = str(current_file_path.parent / 'data' / 'alignment_texts.jsonl')
PATH_TO_DATASET_PKL = str(current_file_path.parent / 'data' / 'dataset.pkl')
PATH_TO_DATASET_DICT_PKL = str(current_file_path.parent / 'data' / 'dataset_dict.pkl')
### EMBEDDINGS ###
USE_OPENAI_EMBEDDINGS = False
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### PINECONE ###
PINECONE_INDEX_NAME = "stampy-chat-embeddings-test"
PINECONE_VALUES_DIMS = EMBEDDINGS_DIMS
PINECONE_METRIC = "cosine"
PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]
PINECONE_API_KEY = os.environ["PINECONE_API_KEY"]
PINECONE_ENVIRONMENT = os.environ["PINECONE_ENVIRONMENT"]

### MISCELLANEOUS ###
MAX_NUM_AUTHORS_IN_SIGNATURE = 3
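
USE_OPENAI_EMBEDDINGS toggles between OpenAI's text-embedding-ada-002 and a local sentence-transformers model running on DEVICE; the embedding code itself is not part of this diff. The sketch below only illustrates, under that assumption, how these settings are meant to be consumed (embed_texts is a hypothetical helper, not a function added by this PR):

# illustrative only; the real embedding module is not shown in this diff
from sentence_transformers import SentenceTransformer

from dataset.settings import (
    USE_OPENAI_EMBEDDINGS,
    OPENAI_EMBEDDINGS_MODEL,
    SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
    DEVICE,
)

def embed_texts(texts):
    if USE_OPENAI_EMBEDDINGS:
        import openai  # assumes the pre-1.0 openai client interface
        response = openai.Embedding.create(model=OPENAI_EMBEDDINGS_MODEL, input=texts)
        return [item["embedding"] for item in response["data"]]
    model = SentenceTransformer(SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, device=DEVICE)
    return model.encode(texts)
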
104 changes: 104 additions & 0 deletions src/dataset/sql_db_handler.py
@@ -0,0 +1,104 @@
# dataset/sql_db_handler.py

from typing import List, Dict, Union
import sqlite3

from .settings import SQL_DB_PATH

import logging
logger = logging.getLogger(__name__)


class SQLDB:
def __init__(self):
self.db_name = SQL_DB_PATH

self.create_tables()

def create_tables(self, reset: bool = False):
with sqlite3.connect(self.db_name) as conn:
cursor = conn.cursor()
try:
if reset:
# Drop the tables if reset is True
cursor.execute("DROP TABLE IF EXISTS entry_database")
cursor.execute("DROP TABLE IF EXISTS chunk_database")

# Create entry table
query = """
CREATE TABLE IF NOT EXISTS entry_database (
id TEXT PRIMARY KEY,
source TEXT,
title TEXT,
text TEXT,
url TEXT,
date_published TEXT,
authors TEXT
)
"""
cursor.execute(query)

# Create chunk table
query = """
CREATE TABLE IF NOT EXISTS chunk_database (
id TEXT PRIMARY KEY,
text TEXT,
entry_id TEXT,
FOREIGN KEY (entry_id) REFERENCES entry_database(id)
)
"""
cursor.execute(query)

except sqlite3.Error as e:
logger.error(f"The error '{e}' occurred.")

def upsert_entry(self, entry: Dict[str, Union[str, list]]) -> bool:
with sqlite3.connect(self.db_name) as conn:
cursor = conn.cursor()
try:
# Fetch existing data
cursor.execute("SELECT * FROM entry_database WHERE id=?", (entry['id'],))
existing_entry = cursor.fetchone()

new_entry = (
entry['id'],
entry['source'],
entry['title'],
entry['text'],
entry['url'],
entry['date_published'],
', '.join(entry['authors'])
)

if existing_entry != new_entry:
query = """
INSERT OR REPLACE INTO entry_database
(id, source, title, text, url, date_published, authors)
VALUES (?, ?, ?, ?, ?, ?, ?)
"""
cursor.execute(query, new_entry)
return True
else:
return False

except sqlite3.Error as e:
logger.error(f"The error '{e}' occurred.")
return False

finally:
conn.commit()

def upsert_chunks(self, chunks_ids_batch: List[str], chunks_batch: List[str]) -> bool:
with sqlite3.connect(self.db_name) as conn:
cursor = conn.cursor()
try:
for chunk_id, chunk in zip(chunks_ids_batch, chunks_batch):
cursor.execute("""
INSERT OR REPLACE INTO chunk_database
(id, text)
VALUES (?, ?)
""", (chunk_id, chunk))
except sqlite3.Error as e:
logger.error(f"The error '{e}' occurred.")
finally:
conn.commit()
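
A minimal usage sketch of the SQLite handler; the entry values are made up for illustration. Both tables are created on construction, upsert_entry returns True only when the row was inserted or actually changed, and upsert_chunks writes chunk text keyed by chunk ID:

# illustrative usage of the new SQLDB handler
from dataset.sql_db_handler import SQLDB

db = SQLDB()  # creates entry_database and chunk_database at SQL_DB_PATH if missing

entry = {
    "id": "abc123",
    "source": "arxiv",
    "title": "Example paper",
    "text": "Full text of the entry...",
    "url": "https://example.org/paper",
    "date_published": "2023-07-05",
    "authors": ["A. Author", "B. Author"],  # joined into a comma-separated string before insertion
}

changed = db.upsert_entry(entry)  # True if inserted or updated, False if identical to the stored row
db.upsert_chunks(["abc123_000000", "abc123_000001"], ["first chunk", "second chunk"])
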
2 changes: 2 additions & 0 deletions src/dataset/text_splitter.py
@@ -1,3 +1,5 @@
# dataset/text_splitter.py

import re
from typing import List
import tiktoken
