Pinecone fix vector search #150

Merged: 57 commits, Aug 23, 2023
Changes from 50 commits

Commits (57)
c3b3539
added finetune_model.py
henri123lemoine Aug 16, 2023
7bd986a
updated finetune_model.py
henri123lemoine Aug 16, 2023
877bcfe
added finetune_embeddings_tests to main for testing
henri123lemoine Aug 16, 2023
1e444cd
set up iterable dataset for finetuning
henri123lemoine Aug 16, 2023
03d0d54
Merge remote-tracking branch 'origin' into finetune-embeddings
henri123lemoine Aug 16, 2023
041d48f
minor refactor
henri123lemoine Aug 16, 2023
abee3f8
added model.py
henri123lemoine Aug 18, 2023
c41de70
added get_embeddings_by_ids method to pinecone handler
henri123lemoine Aug 18, 2023
5a629cb
simplified update_pinecone.py
henri123lemoine Aug 18, 2023
a0977fb
added (probably unnecessary) check on text splitter
henri123lemoine Aug 18, 2023
5472613
added dataset.py
henri123lemoine Aug 18, 2023
27629a8
added utils functions
henri123lemoine Aug 18, 2023
83b6f79
updated settings
henri123lemoine Aug 18, 2023
8ac789f
(incorrectly) trained finetuning layer. Best so far, but BAD
henri123lemoine Aug 18, 2023
b5c0457
changed the session maker to reuse engines
henri123lemoine Aug 18, 2023
b65db76
added train_finetuning_layer method to main
henri123lemoine Aug 18, 2023
7458b4d
added youtube api key to env.example
henri123lemoine Aug 18, 2023
891808d
simplified dataset.py
henri123lemoine Aug 21, 2023
d2b0ff7
update_pinecone refactor
henri123lemoine Aug 21, 2023
706aa9b
added force_update to pinecone_update methods
henri123lemoine Aug 21, 2023
57dc4e7
small refactor+removed bias+set namespace
henri123lemoine Aug 21, 2023
141ded5
added query method to pineconedb
henri123lemoine Aug 21, 2023
950b69f
Merge remote-tracking branch 'origin/main' into finetune-embeddings
henri123lemoine Aug 21, 2023
15e4bec
minor refactor
henri123lemoine Aug 21, 2023
f23d3b1
reformat chunk headers
henri123lemoine Aug 21, 2023
840ca8c
changes_to_pinecone
Thomas-Lemoine Aug 22, 2023
faff6cd
added moderation to get_embeddings util
henri123lemoine Aug 22, 2023
a749248
remove duplicate function
henri123lemoine Aug 22, 2023
ddb5f8e
renamed embed_query to get_embedding
henri123lemoine Aug 22, 2023
075d409
black reformatting
henri123lemoine Aug 22, 2023
e4cf11c
set black line-length to 100
henri123lemoine Aug 22, 2023
c206daf
renamed finetuning files
henri123lemoine Aug 22, 2023
34bac92
dealt with metadata keys
henri123lemoine Aug 22, 2023
db6a52d
updated utils
henri123lemoine Aug 22, 2023
6c07f7e
added comments when flagged
henri123lemoine Aug 22, 2023
cc86538
add comments and ~fix hash_id shenanigans
henri123lemoine Aug 22, 2023
05bf335
moved embedding utils out of common/utils.py
Thomas-Lemoine Aug 22, 2023
513b87a
hf_embeddings slight refactor
Thomas-Lemoine Aug 22, 2023
03f641b
engine rename and autoflush inside session init for better type signa…
Thomas-Lemoine Aug 22, 2023
b6a8c8c
simplified openai error types
Thomas-Lemoine Aug 22, 2023
c229cc2
restructured finetuning and pinecone dirs
henri123lemoine Aug 22, 2023
afe91af
added auto code formatting with black on push/PR
henri123lemoine Aug 22, 2023
ed034da
Testing pre-commit hook
henri123lemoine Aug 22, 2023
e6864b7
PR fixes
henri123lemoine Aug 23, 2023
d5b7657
moved text_splitter.py, fixed minor typing issues
henri123lemoine Aug 23, 2023
bee0106
minor bug fix
henri123lemoine Aug 23, 2023
78812f1
minor typing and imports bug-fix
henri123lemoine Aug 23, 2023
3d8a9d9
simplify get random chunks
Thomas-Lemoine Aug 23, 2023
18fe229
removed EmbeddingType
Thomas-Lemoine Aug 23, 2023
78192e3
moved sources tests to test/align_data/sources
henri123lemoine Aug 23, 2023
b44fe88
Tidy up (#144)
mruwnik Aug 21, 2023
ebcbeec
disable pinecone encoding in actions (#148)
mruwnik Aug 21, 2023
0b940d9
Transformer-circuits blog (#146)
mruwnik Aug 22, 2023
e7e19b9
write to correct database when updating metadata (#149)
mruwnik Aug 22, 2023
42e75f0
Bunch up blogs, special_docs and youtube (#147)
mruwnik Aug 22, 2023
cc9df48
Merge branch 'main' into pinecone-fix-vector-search
henri123lemoine Aug 23, 2023
176b052
skip entries with falsy field values (#154)
Thomas-Lemoine Aug 23, 2023
1 change: 1 addition & 0 deletions .env.example
@@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
 PINECONE_INDEX_NAME="stampy-chat-ard"
 PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
 PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
+YOUTUBE_API_KEY=""
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+  - repo: https://github.com/psf/black
+    rev: 23.7.0
+    hooks:
+      - id: black
+        language_version: python3.11
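Usage note (not part of the diff): with this config merged, contributors would typically run `pre-commit install` once in the repo root, assuming the pre-commit package is installed locally, so black and the whitespace fixers run automatically on every commit.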
4 changes: 1 addition & 3 deletions align_data/analysis/analyse_jsonl_data.py
@@ -68,9 +68,7 @@ def process_jsonl_files(data_dir):

     for id, duplicates in seen_urls.items():
         if len(duplicates) > 1:
-            list_of_duplicates = "\n".join(
-                get_data_dict_str(duplicate) for duplicate in duplicates
-            )
+            list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates)
             print(
                 f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
             )
23 changes: 10 additions & 13 deletions align_data/common/alignment_dataset.py
@@ -164,14 +164,9 @@ def _load_outputted_items(self) -> Set[str]:
             # This doesn't filter by self.name. The good thing about that is that it should handle a lot more
             # duplicates. The bad thing is that this could potentially return a massive amount of data if there
             # are lots of items.
-            return set(
-                session.scalars(select(getattr(Article, self.done_key))).all()
-            )
+            return set(session.scalars(select(getattr(Article, self.done_key))).all())
         # TODO: Properly handle this - it should create a proper SQL JSON select
-        return {
-            item.get(self.done_key)
-            for item in session.scalars(select(Article.meta)).all()
-        }
+        return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}

     def not_processed(self, item):
         # NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. If this starts to
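As an aside on the TODO above: a minimal sketch of what a "proper SQL JSON select" might look like, assuming `Article.meta` is a SQLAlchemy JSON column (the bracket-index and `as_string()` JSON accessors are standard SQLAlchemy; the method name here is hypothetical and not part of this PR):

    from sqlalchemy import select

    def _load_outputted_items_via_json(self) -> Set[str]:
        # Extract done_key from the JSON meta column server-side,
        # instead of pulling every meta blob into Python first.
        with make_session() as session:
            return set(
                session.scalars(select(Article.meta[self.done_key].as_string())).all()
            )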
@@ -214,7 +209,7 @@ def fetch_entries(self):
             if self.COOLDOWN:
                 time.sleep(self.COOLDOWN)

-    def process_entry(self, entry) -> Optional[Article]:
+    def process_entry(self, entry) -> Article | None:
         """Process a single entry."""
         raise NotImplementedError
@@ -223,7 +218,7 @@ def _format_datetime(date) -> str:
         return date.strftime("%Y-%m-%dT%H:%M:%SZ")

     @staticmethod
-    def _get_published_date(date) -> Optional[datetime]:
+    def _get_published_date(date) -> datetime | None:
         try:
             # Totally ignore any timezone info, forcing everything to UTC
             return parse(str(date)).replace(tzinfo=pytz.UTC)
@@ -239,7 +234,11 @@ def unprocessed_items(self, items=None) -> Iterable:

         urls = map(self.get_item_key, items)
         with make_session() as session:
-            articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
+            articles = (
+                session.query(Article)
+                .options(joinedload(Article.summaries))
+                .filter(Article.url.in_(urls))
+            )
             self.articles = {a.url: a for a in articles if a.url}

         return items
@@ -249,9 +248,7 @@ def _load_outputted_items(self) -> Set[str]:
         with make_session() as session:
             return set(
                 session.scalars(
-                    select(Article.url)
-                    .join(Article.summaries)
-                    .filter(Summary.source == self.name)
+                    select(Article.url).join(Article.summaries).filter(Summary.source == self.name)
                 )
             )

2 changes: 1 addition & 1 deletion align_data/common/html_dataset.py
@@ -75,7 +75,7 @@ def get_contents(self, article_url):
     def process_entry(self, article):
         article_url = self.get_item_key(article)
         contents = self.get_contents(article_url)
-        if not contents.get('text'):
+        if not contents.get("text"):
             return None

         return self.make_data_entry(contents)
44 changes: 19 additions & 25 deletions align_data/db/models.py
@@ -17,7 +17,6 @@
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 from sqlalchemy.dialects.mysql import LONGTEXT
 from sqlalchemy.ext.hybrid import hybrid_property
-from align_data.settings import PINECONE_METADATA_KEYS


 logger = logging.getLogger(__name__)
@@ -71,33 +70,26 @@ class Article(Base):
     def __repr__(self) -> str:
         return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})"

-    def is_metadata_keys_equal(self, other):
-        if not isinstance(other, Article):
-            raise TypeError(
-                f"Expected an instance of Article, got {type(other).__name__}"
-            )
-        return not any(
-            getattr(self, key, None)
-            != getattr(other, key, None)  # entry_id is implicitly ignored
-            for key in PINECONE_METADATA_KEYS
-        )
-
     def generate_id_string(self) -> bytes:
-        return "".join(str(getattr(self, field)) for field in self.__id_fields).encode(
-            "utf-8"
-        )
+        return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8")

     @property
     def __id_fields(self):
-        if self.source == 'aisafety.info':
-            return ['url']
-        if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']:
-            return ['url', 'title', 'source']
+        if self.source == "aisafety.info":
+            return ["url"]
+        if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]:
+            return ["url", "title", "source"]
         return ["url", "title"]

     @property
     def missing_fields(self):
-        fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'}
+        fields = set(self.__id_fields) | {
+            "text",
+            "title",
+            "url",
+            "source",
+            "date_published",
+        }
         return sorted([field for field in fields if not getattr(self, field, None)])

     def verify_id(self):
@@ -136,10 +128,12 @@ def add_meta(self, key, val):
     @hybrid_property
     def is_valid(self):
         return (
-            self.text and self.text.strip() and
-            self.url and self.title and
-            self.authors is not None and
-            self.status == OK_STATUS
+            self.text
+            and self.text.strip()
+            and self.url
+            and self.title
+            and self.authors is not None
+            and self.status == OK_STATUS
         )

     @is_valid.expression
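Note: the collapsed `@is_valid.expression` block below supplies the SQL-side counterpart of this Python-side property; that pairing is what lets queries elsewhere in this PR, such as those in align_data/db/session.py, filter with `.filter(Article.is_valid)` directly.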
@@ -157,7 +151,7 @@ def before_write(cls, _mapper, _connection, target):
         target.verify_id_fields()

         if not target.status and target.missing_fields:
-            target.status = 'Missing fields'
+            target.status = "Missing fields"
             target.comments = f'missing fields: {", ".join(target.missing_fields)}'

         if target.id:
26 changes: 20 additions & 6 deletions align_data/db/session.py
@@ -10,24 +10,38 @@

 logger = logging.getLogger(__name__)

+# We create a single engine for the entire application
+engine = create_engine(DB_CONNECTION_URI, echo=False)
+

 @contextmanager
 def make_session(auto_commit=False):
-    engine = create_engine(DB_CONNECTION_URI, echo=False)
-    with Session(engine).no_autoflush as session:
+    with Session(engine, autoflush=False) as session:
         yield session
         if auto_commit:
             session.commit()


-def stream_pinecone_updates(session, custom_sources: List[str]):
+def stream_pinecone_updates(
+    session: Session, custom_sources: List[str], force_update: bool = False
+):
     """Yield Pinecone entries that require an update."""
     yield from (
-        session
-        .query(Article)
-        .filter(Article.pinecone_update_required.is_(True))
+        session.query(Article)
+        .filter(or_(Article.pinecone_update_required.is_(True), force_update))
         .filter(Article.is_valid)
         .filter(Article.source.in_(custom_sources))
         .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
         .yield_per(1000)
     )
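Since `force_update` is a plain Python bool, `or_(Article.pinecone_update_required.is_(True), force_update)` renders as an always-true condition when the flag is set, so every article that passes the remaining filters is yielded. A minimal usage sketch (the source names here are made up for illustration):

    with make_session() as session:
        for article in stream_pinecone_updates(session, ["arxiv", "blogs"], force_update=True):
            ...  # re-embed the article and upsert it into Pinecone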


def get_all_valid_article_ids(session: Session) -> List[str]:
"""Return all valid article IDs."""
query_result = (
session.query(Article.id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

session.scalars should do the trick, without having to manually extract them

.filter(Article.is_valid)
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.all()
)
return [item[0] for item in query_result]
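A minimal sketch of the reviewer's suggestion, assuming `select` is imported from sqlalchemy: `session.scalars` returns the id values directly, so the manual tuple unwrapping goes away:

    from sqlalchemy import select

    def get_all_valid_article_ids(session: Session) -> List[str]:
        """Return all valid article IDs."""
        return list(
            session.scalars(
                select(Article.id)
                .filter(Article.is_valid)
                .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
            )
        )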