diff --git a/.env.example b/.env.example
index 057b5ba4..1c29520b 100644
--- a/.env.example
+++ b/.env.example
@@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_INDEX_NAME="stampy-chat-ard"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
+YOUTUBE_API_KEY=""
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..61a81a7b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.4.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+- repo: https://github.com/psf/black
+  rev: 23.7.0
+  hooks:
+  - id: black
+    language_version: python3.11
diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py
index b8c5103a..0aed124d 100644
--- a/align_data/analysis/analyse_jsonl_data.py
+++ b/align_data/analysis/analyse_jsonl_data.py
@@ -68,9 +68,7 @@ def process_jsonl_files(data_dir):
for id, duplicates in seen_urls.items():
if len(duplicates) > 1:
- list_of_duplicates = "\n".join(
- get_data_dict_str(duplicate) for duplicate in duplicates
- )
+ list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates)
print(
f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
)
diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py
index 0bd8f62d..78753169 100644
--- a/align_data/common/alignment_dataset.py
+++ b/align_data/common/alignment_dataset.py
@@ -164,14 +164,9 @@ def _load_outputted_items(self) -> Set[str]:
# This doesn't filter by self.name. The good thing about that is that it should handle a lot more
# duplicates. The bad thing is that this could potentially return a massive amount of data if there
# are lots of items.
- return set(
- session.scalars(select(getattr(Article, self.done_key))).all()
- )
+ return set(session.scalars(select(getattr(Article, self.done_key))).all())
# TODO: Properly handle this - it should create a proper SQL JSON select
- return {
- item.get(self.done_key)
- for item in session.scalars(select(Article.meta)).all()
- }
+ return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}
def not_processed(self, item):
# NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. If this starts to
@@ -214,7 +209,7 @@ def fetch_entries(self):
if self.COOLDOWN:
time.sleep(self.COOLDOWN)
- def process_entry(self, entry) -> Optional[Article]:
+ def process_entry(self, entry) -> Article | None:
"""Process a single entry."""
raise NotImplementedError
@@ -223,7 +218,7 @@ def _format_datetime(date) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")
@staticmethod
- def _get_published_date(date) -> Optional[datetime]:
+ def _get_published_date(date) -> datetime | None:
try:
# Totally ignore any timezone info, forcing everything to UTC
return parse(str(date)).replace(tzinfo=pytz.UTC)
@@ -239,7 +234,11 @@ def unprocessed_items(self, items=None) -> Iterable:
urls = map(self.get_item_key, items)
with make_session() as session:
- articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
+ articles = (
+ session.query(Article)
+ .options(joinedload(Article.summaries))
+ .filter(Article.url.in_(urls))
+ )
self.articles = {a.url: a for a in articles if a.url}
return items
@@ -249,9 +248,7 @@ def _load_outputted_items(self) -> Set[str]:
with make_session() as session:
return set(
session.scalars(
- select(Article.url)
- .join(Article.summaries)
- .filter(Summary.source == self.name)
+ select(Article.url).join(Article.summaries).filter(Summary.source == self.name)
)
)
diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py
index c90172a8..c6fbf797 100644
--- a/align_data/common/html_dataset.py
+++ b/align_data/common/html_dataset.py
@@ -75,7 +75,7 @@ def get_contents(self, article_url):
def process_entry(self, article):
article_url = self.get_item_key(article)
contents = self.get_contents(article_url)
- if not contents.get('text'):
+ if not contents.get("text"):
return None
return self.make_data_entry(contents)
diff --git a/align_data/db/models.py b/align_data/db/models.py
index c922c5b7..74817648 100644
--- a/align_data/db/models.py
+++ b/align_data/db/models.py
@@ -17,7 +17,6 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.hybrid import hybrid_property
-from align_data.settings import PINECONE_METADATA_KEYS
logger = logging.getLogger(__name__)
@@ -71,33 +70,26 @@ class Article(Base):
def __repr__(self) -> str:
return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})"
- def is_metadata_keys_equal(self, other):
- if not isinstance(other, Article):
- raise TypeError(
- f"Expected an instance of Article, got {type(other).__name__}"
- )
- return not any(
- getattr(self, key, None)
- != getattr(other, key, None) # entry_id is implicitly ignored
- for key in PINECONE_METADATA_KEYS
- )
-
def generate_id_string(self) -> bytes:
- return "".join(str(getattr(self, field)) for field in self.__id_fields).encode(
- "utf-8"
- )
+ return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8")
@property
def __id_fields(self):
- if self.source == 'aisafety.info':
- return ['url']
- if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']:
- return ['url', 'title', 'source']
+ if self.source == "aisafety.info":
+ return ["url"]
+ if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]:
+ return ["url", "title", "source"]
return ["url", "title"]
@property
def missing_fields(self):
- fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'}
+ fields = set(self.__id_fields) | {
+ "text",
+ "title",
+ "url",
+ "source",
+ "date_published",
+ }
return sorted([field for field in fields if not getattr(self, field, None)])
def verify_id(self):
@@ -133,13 +125,21 @@ def add_meta(self, key, val):
self.meta = {}
self.meta[key] = val
+ def append_comment(self, comment: str):
+ """Appends a comment to the article.comments field. You must run session.commit() to save the comment to the database."""
+ if self.comments is None:
+ self.comments = ""
+ self.comments = f"{self.comments}\n\n{comment}".strip()
+
@hybrid_property
def is_valid(self):
return (
- self.text and self.text.strip() and
- self.url and self.title and
- self.authors is not None and
- self.status == OK_STATUS
+ self.text
+ and self.text.strip()
+ and self.url
+ and self.title
+ and self.authors is not None
+ and self.status == OK_STATUS
)
@is_valid.expression
@@ -157,7 +157,7 @@ def before_write(cls, _mapper, _connection, target):
target.verify_id_fields()
if not target.status and target.missing_fields:
- target.status = 'Missing fields'
+ target.status = "Missing fields"
target.comments = f'missing fields: {", ".join(target.missing_fields)}'
if target.id:
diff --git a/align_data/db/session.py b/align_data/db/session.py
index 66c571ce..4aa23a87 100644
--- a/align_data/db/session.py
+++ b/align_data/db/session.py
@@ -10,24 +10,38 @@
logger = logging.getLogger(__name__)
+# We create a single engine for the entire application
+engine = create_engine(DB_CONNECTION_URI, echo=False)
+
@contextmanager
def make_session(auto_commit=False):
- engine = create_engine(DB_CONNECTION_URI, echo=False)
- with Session(engine).no_autoflush as session:
+ with Session(engine, autoflush=False) as session:
yield session
if auto_commit:
session.commit()
-def stream_pinecone_updates(session, custom_sources: List[str]):
+def stream_pinecone_updates(
+ session: Session, custom_sources: List[str], force_update: bool = False
+):
"""Yield Pinecone entries that require an update."""
yield from (
- session
- .query(Article)
- .filter(Article.pinecone_update_required.is_(True))
+ session.query(Article)
+ .filter(or_(Article.pinecone_update_required.is_(True), force_update))
.filter(Article.is_valid)
.filter(Article.source.in_(custom_sources))
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.yield_per(1000)
)
+
+
+def get_all_valid_article_ids(session: Session) -> List[str]:
+ """Return all valid article IDs."""
+ query_result = (
+ session.query(Article.id)
+ .filter(Article.is_valid)
+ .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
+ .all()
+ )
+ return [item[0] for item in query_result]
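+
+
+# Usage sketch (illustrative only; the source name is an example, the flow mirrors the functions above):
+#
+#     with make_session() as session:
+#         ids = get_all_valid_article_ids(session)
+#         for article in stream_pinecone_updates(session, ["aisafety.info"], force_update=False):
+#             ...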
diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py
new file mode 100644
index 00000000..6ef57b88
--- /dev/null
+++ b/align_data/embeddings/embedding_utils.py
@@ -0,0 +1,199 @@
+import logging
+from typing import List, Tuple, Dict, Any, Optional
+from functools import wraps
+
+import openai
+from langchain.embeddings import HuggingFaceEmbeddings
+from openai.error import (
+ OpenAIError,
+ RateLimitError,
+ APIError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_random_exponential,
+ retry_if_exception_type,
+ retry_if_exception,
+)
+
+from align_data.embeddings.pinecone.pinecone_models import MissingEmbeddingModelError
+from align_data.settings import (
+ USE_OPENAI_EMBEDDINGS,
+ OPENAI_EMBEDDINGS_MODEL,
+ EMBEDDING_LENGTH_BIAS,
+ SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
+ DEVICE,
+)
+
+
+# --------------------
+# CONSTANTS & CONFIGURATION
+# --------------------
+
+logger = logging.getLogger(__name__)
+
+hf_embedding_model = None
+if not USE_OPENAI_EMBEDDINGS:
+ hf_embedding_model = HuggingFaceEmbeddings(
+ model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
+ model_kwargs={"device": DEVICE},
+ encode_kwargs={"show_progress_bar": False},
+ )
+
+ModerationInfoType = Dict[str, Any]
+
+
+# --------------------
+# DECORATORS
+# --------------------
+
+
+def handle_openai_errors(func):
+ """Decorator to handle OpenAI-specific exceptions with retries."""
+
+ @wraps(func)
+ @retry(
+ wait=wait_random_exponential(multiplier=1, min=2, max=30),
+ stop=stop_after_attempt(6),
+ retry=retry_if_exception_type(RateLimitError)
+ | retry_if_exception_type(APIError)
+ | retry_if_exception(lambda e: "502" in str(e)),
+ )
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except RateLimitError as e:
+ logger.warning(f"OpenAI Rate limit error. Trying again. Error: {e}")
+ raise
+ except APIError as e:
+ if "502" in str(e):
+ logger.warning(f"OpenAI 502 Bad Gateway error. Trying again. Error: {e}")
+ else:
+ logger.error(f"OpenAI API Error encountered: {e}")
+ raise
+ except OpenAIError as e:
+ logger.error(f"OpenAI Error encountered: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Unexpected error encountered: {e}")
+ raise
+
+ return wrapper
+
+
+# --------------------
+# MAIN FUNCTIONS
+# --------------------
+
+
+@handle_openai_errors
+def moderation_check(texts: List[str]) -> List[ModerationInfoType]:
+ return openai.Moderation.create(input=texts)["results"]
+
+
+@handle_openai_errors
+def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]:
+    data = openai.Embedding.create(
+        input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs
+    ).data
+ return [d["embedding"] for d in data]
+
+
+def get_embeddings_without_moderation(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> List[List[float]]:
+ """
+ Obtain embeddings without moderation checks.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+ - List[List[float]]: List of embeddings for the provided texts.
+ """
+ if not texts:
+ return []
+
+ texts = [text.replace("\n", " ") for text in texts]
+ if USE_OPENAI_EMBEDDINGS:
+ embeddings = _compute_openai_embeddings(texts, **kwargs)
+ elif hf_embedding_model:
+ embeddings = hf_embedding_model.embed_documents(texts)
+ else:
+ raise MissingEmbeddingModelError("No embedding model available.")
+
+ # Bias adjustment
+ if source and (bias := EMBEDDING_LENGTH_BIAS.get(source, 1.0)):
+ embeddings = [[bias * e for e in embedding] for embedding in embeddings]
+
+ return embeddings
+
+
+def get_embeddings_or_none_if_flagged(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> Tuple[List[List[float]] | None, List[ModerationInfoType]]:
+ """
+ Obtain embeddings for the provided texts. If any text is flagged during moderation,
+ the function returns None for the embeddings while still providing the moderation results.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+    - Tuple[List[List[float]] | None, List[ModerationInfoType]]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
+ """
+ moderation_results = moderation_check(texts)
+ if any(result["flagged"] for result in moderation_results):
+ return None, moderation_results
+
+ embeddings = get_embeddings_without_moderation(texts, source, **kwargs)
+ return embeddings, moderation_results
+
+
+def get_embeddings(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> Tuple[List[List[float] | None], List[ModerationInfoType]]:
+ """
+ Obtain embeddings for the provided texts, replacing the embeddings of flagged texts with `None`.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+    - Tuple[List[List[float] | None], List[ModerationInfoType]]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results.
+ """
+ assert len(texts) <= 2048, "The batch size should not be larger than 2048."
+ assert all(texts), "No empty strings allowed in the input list."
+
+ # replace newlines, which can negatively affect performance
+ texts = [text.replace("\n", " ") for text in texts]
+
+ # Check all texts for moderation flags
+ moderation_results = moderation_check(texts)
+ flags = [result["flagged"] for result in moderation_results]
+
+ non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag]
+ non_flagged_embeddings = get_embeddings_without_moderation(
+ non_flagged_texts, source, **kwargs
+ )
+ embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags]
+ return embeddings, moderation_results
+
+
+def get_embedding(
+ text: str, source: Optional[str] = None, **kwargs
+) -> Tuple[List[float] | None, ModerationInfoType]:
+ """Obtain an embedding for a single text."""
+ embedding, moderation_result = get_embeddings([text], source, **kwargs)
+ return embedding[0], moderation_result[0]
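+
+
+# Usage sketch (assumes OPENAI_API_KEY is configured when USE_OPENAI_EMBEDDINGS is True; texts are examples):
+#
+#     embeddings, moderation_results = get_embeddings(
+#         ["What is AI alignment?", "Why might it be hard?"], source="aisafety.info"
+#     )
+#     # embeddings[i] is None for any text flagged by the moderation check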
diff --git a/align_data/embeddings/finetuning/data/best_finetuned_model.pth b/align_data/embeddings/finetuning/data/best_finetuned_model.pth
new file mode 100644
index 00000000..e05a5d52
Binary files /dev/null and b/align_data/embeddings/finetuning/data/best_finetuned_model.pth differ
diff --git a/align_data/embeddings/finetuning/data/finetuned_model.pth b/align_data/embeddings/finetuning/data/finetuned_model.pth
new file mode 100644
index 00000000..7ba44ac6
Binary files /dev/null and b/align_data/embeddings/finetuning/data/finetuned_model.pth differ
diff --git a/align_data/embeddings/finetuning/finetuning_dataset.py b/align_data/embeddings/finetuning/finetuning_dataset.py
new file mode 100644
index 00000000..8c5eec04
--- /dev/null
+++ b/align_data/embeddings/finetuning/finetuning_dataset.py
@@ -0,0 +1,105 @@
+import math
+import random
+from typing import List, Tuple, Generator
+from collections import deque
+
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.sql import func
+from sqlalchemy.orm import Session
+
+from align_data.db.session import make_session, get_all_valid_article_ids
+from align_data.embeddings.embedding_utils import get_embedding
+from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
+from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
+from align_data.embeddings.pinecone.update_pinecone import get_text_chunks
+from align_data.db.models import Article
+
+
+class FinetuningDataset(IterableDataset):
+ def __init__(self, num_batches_per_epoch: int, cache_size: int = 1280):
+ self.num_batches_per_epoch = num_batches_per_epoch
+ self.article_cache = deque(maxlen=cache_size)
+
+ self.text_splitter = ParagraphSentenceUnitTextSplitter()
+ self.pinecone_db = PineconeDB()
+
+ with make_session() as session:
+ self.all_article_ids = get_all_valid_article_ids(session)
+ self.total_articles = len(self.all_article_ids)
+
+ def __len__(self):
+ return self.num_batches_per_epoch
+
+ def __iter__(self):
+ start, end = 0, None
+ worker_info = get_worker_info()
+ if worker_info is not None: # Multi-process loading
+ per_worker = math.ceil(self.total_articles / worker_info.num_workers)
+ start = worker_info.id * per_worker
+ end = min(start + per_worker, self.total_articles)
+
+ with make_session() as session:
+ return self._generate_pairs(session, start, end)
+
+ def _fetch_random_articles(self, session: Session, batch_size: int = 1) -> List[Article]:
+ """Fetch a batch of random articles."""
+ # If the list has fewer IDs than needed, raise an exception
+ random_selected_ids = random.sample(self.all_article_ids, batch_size)
+ return session.query(Article).filter(Article.id.in_(random_selected_ids)).all()
+
+ def _get_random_chunks(self, article: Article, num_chunks: int = 2) -> List[Tuple[int, str]]:
+ chunked_text = get_text_chunks(article, self.text_splitter)
+
+ chunks = list(enumerate(chunked_text))
+ if len(chunks) < num_chunks:
+ return []
+
+ return random.sample(chunks, num_chunks)
+
+ def _get_embeddings(self, article: Article, chunks: List[Tuple[int, str]]) -> List[List[float]]:
+ full_ids = [f"{article.id}_{str(idx).zfill(6)}" for idx, _ in chunks]
+ _embeddings = self.pinecone_db.get_embeddings_by_ids(full_ids)
+
+ embeddings = []
+ for (_, chunk), (_, embedding) in zip(chunks, _embeddings):
+ if embedding is None:
+ embedding, _ = get_embedding(chunk, article.source)
+ embeddings.append(torch.tensor(embedding))
+
+ return embeddings
+
+ def _generate_pairs(
+ self, session, start=0, end=None, neg_pos_proportion=0.5
+ ) -> Generator[Tuple[List[float], List[float], int], None, None]:
+ end = end or self.total_articles
+
+ batches_yielded = 0
+ while start < end:
+ start += 1
+ if random.random() < neg_pos_proportion:
+ # Positive pairs
+ article = self._fetch_random_articles(session)[0]
+ chunks = self._get_random_chunks(article, 2)
+ if not chunks:
+ continue
+ embedding_1, embedding_2 = self._get_embeddings(article, chunks)
+ label = 1
+ else:
+ # Negative pairs
+ article1, article2 = self._fetch_random_articles(session, batch_size=2)
+ chunk1 = self._get_random_chunks(article1, 1)
+ chunk2 = self._get_random_chunks(article2, 1)
+ embedding_1, embedding_2 = (
+ self._get_embeddings(article1, chunk1)[0],
+ self._get_embeddings(article2, chunk2)[0],
+ )
+ label = 0
+            yield (
+                embedding_1.to(torch.float32),
+                embedding_2.to(torch.float32),
+                torch.tensor(label, dtype=torch.int64),
+            )
+ batches_yielded += 1
+
+ if self.num_batches_per_epoch and batches_yielded >= self.num_batches_per_epoch:
+ break
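+
+
+# Consumption sketch using torch.utils.data.DataLoader (training.py below wires this up the same way):
+#
+#     dataset = FinetuningDataset(num_batches_per_epoch=20)
+#     dataloader = DataLoader(dataset, batch_size=64, num_workers=5)
+#     for embedding_1, embedding_2, label in dataloader:
+#         ...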
diff --git a/align_data/embeddings/finetuning/training.py b/align_data/embeddings/finetuning/training.py
new file mode 100644
index 00000000..cd1d7845
--- /dev/null
+++ b/align_data/embeddings/finetuning/training.py
@@ -0,0 +1,177 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+
+from align_data.embeddings.finetuning.finetuning_dataset import FinetuningDataset
+from align_data.settings import (
+ PINECONE_VALUES_DIMS,
+ DEVICE,
+ OPENAI_FINETUNED_LAYER_PATH,
+ OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH,
+)
+
+
+class ContrastiveLoss(nn.Module):
+ def __init__(self, margin=2.0):
+ super(ContrastiveLoss, self).__init__()
+ self.margin = margin
+
+ def forward(self, output1, output2, label):
+ euclidean_distance = nn.functional.pairwise_distance(output1, output2)
+ loss_contrastive = torch.mean(
+ (1 - label) * torch.pow(euclidean_distance, 2)
+ + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
+ )
+
+ return loss_contrastive
+
+
+class NonLinearFineTuneModel(nn.Module):
+ def __init__(self, embedding_dim=PINECONE_VALUES_DIMS, hidden_dim=2000, dropout=0.5):
+        super(NonLinearFineTuneModel, self).__init__()
+
+ self.fc1 = nn.Linear(embedding_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, embedding_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = nn.functional.relu(self.fc1(x))
+ x = self.dropout(x)
+ x = self.fc2(x)
+ return x
+
+
+class FineTuneModel(nn.Module):
+ def __init__(self, embedding_dim=PINECONE_VALUES_DIMS):
+ super(FineTuneModel, self).__init__()
+
+ self.fc = nn.Linear(embedding_dim, embedding_dim)
+
+ def forward(self, x):
+ x = self.fc(x)
+ return x
+
+
+def train(model, dataloader, optimizer, criterion):
+ model.train()
+ total_loss = 0.0
+
+ for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader):
+ text1_embedding = text1_embedding.to(DEVICE)
+ text2_embedding = text2_embedding.to(DEVICE)
+ target = target.float().to(DEVICE)
+
+ optimizer.zero_grad()
+
+ output1 = model(text1_embedding)
+ output2 = model(text2_embedding)
+
+ loss = criterion(output1, output2, target)
+ loss.backward()
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+ optimizer.step()
+
+ total_loss += loss.item()
+
+ return total_loss / len(dataloader)
+
+
+def validate(model, dataloader, criterion):
+ model.eval()
+ total_loss = 0.0
+
+ with torch.no_grad():
+ for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader):
+ text1_embedding = text1_embedding.to(DEVICE)
+ text2_embedding = text2_embedding.to(DEVICE)
+ target = target.float().to(DEVICE)
+
+ output1 = model(text1_embedding)
+ output2 = model(text2_embedding)
+
+ loss = criterion(output1, output2, target)
+ total_loss += loss.item()
+
+ return total_loss / len(dataloader)
+
+
+def finetune_embeddings():
+ # Hyperparameters & Configuration
+ EPOCHS = 100
+ BATCH_PER_EPOCH = 20
+ BATCH_SIZE = 64
+ LEARNING_RATE = 5.0000e-02
+ MARGIN = 2.0
+
+ dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=5)
+
+ model = FineTuneModel().to(DEVICE)
+ model = load_best_model_if_exists(model)
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+ scheduler = ReduceLROnPlateau(optimizer, "min", patience=2, factor=0.5, verbose=True)
+ criterion = ContrastiveLoss(MARGIN)
+
+    # NOTE: this "validation" set is sampled from the same article pool as the training data,
+    # so it is not a true held-out split.
+ validation_dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH)
+ validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=5)
+ best_val_loss = validate(model, validation_dataloader, criterion)
+ print(f"Initial validation loss (from loaded model or new model): {best_val_loss:.4f}")
+
+ epochs_without_improvement = 0
+    max_epochs_without_improvement = 15  # stop after 15 epochs without improvement
+
+ for epoch in range(EPOCHS):
+ train_loss = train(model, dataloader, optimizer, criterion)
+ validate_loss = validate(model, validation_dataloader, criterion)
+
+ scheduler.step(validate_loss)
+ if validate_loss < best_val_loss:
+ best_val_loss = validate_loss
+ torch.save(model.state_dict(), OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH)
+ epochs_without_improvement = 0
+ else:
+ epochs_without_improvement += 1
+
+ print(
+ f"Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}"
+ )
+
+ if epochs_without_improvement >= max_epochs_without_improvement:
+ print("Early stopping due to no improvement in validation loss.")
+ break
+
+ torch.save(model.state_dict(), OPENAI_FINETUNED_LAYER_PATH)
+
+
+### HELPER FUNCTIONS ###
+
+
+def load_best_model_if_exists(model):
+ """
+ Load the best saved model if it exists.
+
+ Parameters:
+ - model (torch.nn.Module): The model architecture.
+
+ Returns:
+ - model (torch.nn.Module): The loaded model.
+ """
+ if os.path.exists(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH):
+ model.load_state_dict(
+ torch.load(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, map_location=DEVICE)
+ )
+ print(f"Loaded model from {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}.")
+ else:
+ print(
+ f"No saved model found at {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}. Starting from scratch."
+ )
+
+ return model
diff --git a/align_data/pinecone/__init__.py b/align_data/embeddings/pinecone/__init__.py
similarity index 100%
rename from align_data/pinecone/__init__.py
rename to align_data/embeddings/pinecone/__init__.py
diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py
new file mode 100644
index 00000000..dd2990d3
--- /dev/null
+++ b/align_data/embeddings/pinecone/pinecone_db_handler.py
@@ -0,0 +1,142 @@
+import logging
+from typing import List, Tuple
+
+import pinecone
+from pinecone.core.client.models import ScoredVector
+
+from align_data.embeddings.embedding_utils import get_embedding
+from align_data.embeddings.pinecone.pinecone_models import (
+ PineconeEntry,
+ PineconeMetadata,
+)
+from align_data.settings import (
+ PINECONE_INDEX_NAME,
+ PINECONE_VALUES_DIMS,
+ PINECONE_METRIC,
+ PINECONE_API_KEY,
+ PINECONE_ENVIRONMENT,
+ PINECONE_NAMESPACE,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class PineconeDB:
+ def __init__(
+ self,
+ index_name: str = PINECONE_INDEX_NAME,
+ values_dims: int = PINECONE_VALUES_DIMS,
+ metric: str = PINECONE_METRIC,
+ create_index: bool = False,
+ log_index_stats: bool = False,
+ ):
+ self.index_name = index_name
+ self.values_dims = values_dims
+ self.metric = metric
+
+ pinecone.init(
+ api_key=PINECONE_API_KEY,
+ environment=PINECONE_ENVIRONMENT,
+ )
+
+ if create_index:
+ self.create_index()
+
+ self.index = pinecone.Index(index_name=self.index_name)
+
+ if log_index_stats:
+ index_stats_response = self.index.describe_index_stats()
+ logger.info(f"{self.index_name}:\n{index_stats_response}")
+
+ def upsert_entry(self, pinecone_entry: PineconeEntry, upsert_size: int = 100):
+ vectors = pinecone_entry.create_pinecone_vectors()
+ self.index.upsert(vectors=vectors, batch_size=upsert_size, namespace=PINECONE_NAMESPACE)
+
+ def query_vector(
+ self,
+ query: List[float],
+ top_k: int = 10,
+ include_values: bool = False,
+ include_metadata: bool = True,
+ **kwargs,
+ ) -> List[ScoredVector]:
+ assert not isinstance(
+ query, str
+ ), "query must be a list of floats. Use query_PineconeDB_text for text queries"
+
+ query_response = self.index.query(
+ vector=query,
+ top_k=top_k,
+ include_values=include_values,
+ include_metadata=include_metadata,
+ **kwargs,
+ namespace=PINECONE_NAMESPACE,
+ )
+
+ return [
+ ScoredVector(
+ id=match["id"],
+ score=match["score"],
+ metadata=PineconeMetadata(**match["metadata"]),
+ )
+ for match in query_response["matches"]
+ ]
+
+ def query_text(
+ self,
+ query: str,
+ top_k: int = 10,
+ include_values: bool = False,
+ include_metadata: bool = True,
+ **kwargs,
+ ) -> List[ScoredVector]:
+ query_vector = get_embedding(query)[0]
+ return self.query_vector(
+ query=query_vector,
+ top_k=top_k,
+ include_values=include_values,
+ include_metadata=include_metadata,
+ **kwargs,
+ )
+
+ def delete_entries(self, ids):
+ self.index.delete(filter={"hash_id": {"$in": ids}})
+
+ def create_index(self, replace_current_index: bool = True):
+ if replace_current_index:
+ self.delete_index()
+
+ pinecone.create_index(
+ name=self.index_name,
+ dimension=self.values_dims,
+ metric=self.metric,
+ metadata_config={"indexed": list(PineconeMetadata.__annotations__.keys())},
+ )
+
+ def delete_index(self):
+ if self.index_name in pinecone.list_indexes():
+ logger.info(f"Deleting index '{self.index_name}'.")
+ pinecone.delete_index(self.index_name)
+
+ def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, List[float] | None]]:
+ """
+ Fetch embeddings for given entry IDs from Pinecone.
+
+ Args:
+ - ids (List[str]): List of entry IDs for which embeddings are to be fetched.
+
+ Returns:
+ - List[Tuple[str, List[float] | None]]: List of tuples containing ID and its corresponding embedding.
+ """
+ # TODO: check that this still works
+ vectors = self.index.fetch(
+ ids=ids,
+ namespace=PINECONE_NAMESPACE,
+ )["vectors"]
+ return [(id, vectors.get(id, {}).get("values", None)) for id in ids]
+
+
+def strip_block(text: str) -> str:
+ return "\n".join(text.split("\n")[1:])
diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py
new file mode 100644
index 00000000..fd7b67eb
--- /dev/null
+++ b/align_data/embeddings/pinecone/pinecone_models.py
@@ -0,0 +1,77 @@
+from typing import List, TypedDict
+
+from pydantic import BaseModel, validator
+from pinecone.core.client.models import Vector
+
+
+class MissingFieldsError(Exception):
+ pass
+
+
+class MissingEmbeddingModelError(Exception):
+ pass
+
+
+class PineconeMetadata(TypedDict):
+ hash_id: str
+ source: str
+ title: str
+ url: str
+ date_published: float
+ authors: List[str]
+ text: str
+
+
+class PineconeEntry(BaseModel):
+ hash_id: str
+ source: str
+ title: str
+ url: str
+ date_published: float
+ authors: List[str]
+ text_chunks: List[str]
+ embeddings: List[List[float]]
+
+ def __init__(self, **data):
+ """Check for missing (falsy) fields before initializing."""
+ missing_fields = [field for field, value in data.items() if not str(value).strip()]
+
+ if missing_fields:
+ raise MissingFieldsError(f"Missing fields: {missing_fields}")
+
+ super().__init__(**data)
+
+ def __repr__(self):
+ def make_small(chunk: str) -> str:
+ return (chunk[:45] + " [...] " + chunk[-45:]) if len(chunk) > 100 else chunk
+
+ def display_chunks(chunks_lst: List[str]) -> str:
+ chunks = ", ".join(f'"{make_small(chunk)}"' for chunk in chunks_lst)
+ return (
+ f"[{chunks[:450]} [...] {chunks[-450:]} ]" if len(chunks) > 1000 else f"[{chunks}]"
+ )
+
+ return f"PineconeEntry(hash_id={self.hash_id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={display_chunks(self.text_chunks)})"
+
+ @property
+ def chunk_num(self) -> int:
+ return len(self.text_chunks)
+
+ def create_pinecone_vectors(self) -> List[Vector]:
+ return [
+ Vector(
+ id=f"{self.hash_id}_{str(i).zfill(6)}",
+ values=self.embeddings[i],
+ metadata=PineconeMetadata(
+ hash_id=self.hash_id,
+ source=self.source,
+ title=self.title,
+ authors=self.authors,
+ url=self.url,
+ date_published=self.date_published,
+ text=self.text_chunks[i],
+ ),
+ )
+ for i in range(self.chunk_num)
+ if self.embeddings[i] # Skips flagged chunks
+ ]
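+
+
+# Construction sketch with made-up values (empty/falsy fields raise MissingFieldsError):
+#
+#     entry = PineconeEntry(
+#         hash_id="abc123",
+#         source="aisafety.info",
+#         title="Example article",
+#         url="https://example.com/article",
+#         date_published=1692000000.0,
+#         authors=["Jane Doe"],
+#         text_chunks=["first chunk", "second chunk"],
+#         embeddings=[[0.1, 0.2], [0.3, 0.4]],
+#     )
+#     vectors = entry.create_pinecone_vectors()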
diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py
new file mode 100644
index 00000000..b425ee9d
--- /dev/null
+++ b/align_data/embeddings/pinecone/update_pinecone.py
@@ -0,0 +1,129 @@
+from datetime import datetime
+import logging
+from itertools import islice
+from typing import Callable, List, Tuple, Generator, Iterator, Optional
+
+from sqlalchemy.orm import Session
+from pydantic import ValidationError
+
+from align_data.embeddings.embedding_utils import get_embeddings
+from align_data.db.models import Article
+from align_data.db.session import make_session, stream_pinecone_updates
+from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
+from align_data.embeddings.pinecone.pinecone_models import (
+    PineconeEntry,
+    MissingFieldsError,
+    MissingEmbeddingModelError,
+)
+from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
+
+
+logger = logging.getLogger(__name__)
+
+
+# Define type aliases for the Callables
+LengthFunctionType = Callable[[str], int]
+TruncateFunctionType = Callable[[str, int], str]
+
+
+class PineconeUpdater:
+ def __init__(self):
+ self.text_splitter = ParagraphSentenceUnitTextSplitter()
+ self.pinecone_db = PineconeDB()
+
+ def update(self, custom_sources: List[str], force_update: bool = False):
+ """
+ Update the given sources. If no sources are provided, updates all sources.
+
+        :param custom_sources: List of sources to update.
+        :param force_update: If True, also push valid articles that aren't marked as requiring a Pinecone update.
+        """
+ with make_session() as session:
+ articles_to_update_stream = stream_pinecone_updates(
+ session, custom_sources, force_update
+ )
+ for batch in self.batch_entries(articles_to_update_stream):
+ self.save_batch(session, batch)
+
+ def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]):
+ try:
+ for article, pinecone_entry in batch:
+ self.pinecone_db.upsert_entry(pinecone_entry)
+
+ article.pinecone_update_required = False
+ session.add(article)
+
+ session.commit()
+
+ except Exception as e:
+ # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking
+ logger.error(e)
+ session.rollback()
+
+ def batch_entries(
+ self, article_stream: Generator[Article, None, None]
+ ) -> Iterator[List[Tuple[Article, PineconeEntry]]]:
+ while batch := tuple(islice(article_stream, 10)):
+ yield [
+ (article, pinecone_entry)
+ for article in batch
+ if (pinecone_entry := self._make_pinecone_entry(article)) is not None
+ ]
+
+ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
+ try:
+ text_chunks = get_text_chunks(article, self.text_splitter)
+ embeddings, moderation_results = get_embeddings(text_chunks, article.source)
+
+            if any(result["flagged"] for result in moderation_results):
+                flagged_text_chunks = [
+                    f'Chunk {i}: "{text}"'
+                    for i, (text, result) in enumerate(zip(text_chunks, moderation_results))
+                    if result["flagged"]
+                ]
+                logger.warning(
+                    f"OpenAI moderation flagged text chunks for the following article: {article.id}"
+                )
+                article.append_comment(
+                    f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}"
+                )
+
+ return PineconeEntry(
+ hash_id=article.id, # the hash_id of the article
+ source=article.source,
+ title=article.title,
+ url=article.url,
+ date_published=article.date_published.timestamp(),
+ authors=[author.strip() for author in article.authors.split(",") if author.strip()],
+ text_chunks=text_chunks,
+ embeddings=embeddings,
+ )
+        except (
+            ValueError,
+            TypeError,
+            AttributeError,
+            ValidationError,
+            MissingFieldsError,
+            MissingEmbeddingModelError,
+        ) as e:
+ logger.warning(e)
+ article.append_comment(f"Error encountered while processing this article: {e}")
+ return None
+
+ except Exception as e:
+ logger.error(e)
+ raise
+
+
+def get_text_chunks(
+ article: Article, text_splitter: ParagraphSentenceUnitTextSplitter
+) -> List[str]:
+ title = article.title.replace("\n", " ")
+
+ authors_lst = [author.strip() for author in article.authors.split(",")]
+ authors = get_authors_str(authors_lst)
+
+ signature = f"Title: {title}; Author(s): {authors}."
+ text_chunks = text_splitter.split_text(article.text)
+ return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks]
+
+
+def get_authors_str(authors_lst: List[str]) -> str:
+ if not authors_lst:
+ return "n/a"
+
+ if len(authors_lst) == 1:
+ authors_str = authors_lst[0]
+ else:
+ authors_lst = authors_lst[:4]
+ authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
+
+ authors_str = authors_str.replace("\n", " ")
+
+ # Truncate if necessary
+ if len(authors_str) > 500:
+ authors_str = authors_str[:497] + "..."
+
+ return authors_str
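+
+
+# Typical invocation (a sketch; the actual entry point is wired up elsewhere in the project,
+# and the source name is an example):
+#
+#     PineconeUpdater().update(custom_sources=["aisafety.info"], force_update=False)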
diff --git a/align_data/pinecone/text_splitter.py b/align_data/embeddings/text_splitter.py
similarity index 89%
rename from align_data/pinecone/text_splitter.py
rename to align_data/embeddings/text_splitter.py
index b8af09a3..a364415e 100644
--- a/align_data/pinecone/text_splitter.py
+++ b/align_data/embeddings/text_splitter.py
@@ -1,5 +1,3 @@
-# dataset/text_splitter.py
-
from typing import List, Callable, Any
from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize
@@ -11,8 +9,6 @@
StrToIntFunction = Callable[[str], int]
StrIntBoolToStrFunction = Callable[[str, int, bool], str]
-def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
- return string[-length:] if from_end else string[:length]
def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
return string[-length:] if from_end else string[:length]
@@ -26,7 +22,7 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter):
@param length_function: A function that returns the length of a string in units. Defaults to len().
@param truncate_function: A function that truncates a string to a given unit length.
"""
-
+
DEFAULT_MIN_CHUNK_SIZE: int = 900
DEFAULT_MAX_CHUNK_SIZE: int = 1100
DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len
@@ -38,7 +34,7 @@ def __init__(
max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION,
truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION,
- **kwargs: Any
+ **kwargs: Any,
):
super().__init__(**kwargs)
self.min_chunk_size = min_chunk_size
@@ -49,6 +45,9 @@ def __init__(
def split_text(self, text: str) -> List[str]:
"""Split text into chunks of length between min_chunk_size and max_chunk_size."""
+ if not text:
+ return []
+
blocks: List[str] = []
current_block: str = ""
@@ -90,13 +89,11 @@ def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragra
def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str:
while self._length_function(current_block) > self.max_chunk_size:
# Truncate current_block to max size, set remaining text as current_block
- truncated_block = self._truncate_function(
- current_block, self.max_chunk_size
- )
+ truncated_block = self._truncate_function(current_block, self.max_chunk_size, False)
blocks.append(truncated_block)
- current_block = current_block[len(truncated_block):].lstrip()
-
+ current_block = current_block[len(truncated_block) :].lstrip()
+
return current_block
def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]:
@@ -107,9 +104,7 @@ def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str
if self.min_chunk_size - len_last_block > 0:
# Add text from previous block to last block if last_block is too short
part_prev_block = self._truncate_function(
- string=blocks[-1],
- length=self.min_chunk_size - len_last_block,
- from_end=True
+ blocks[-1], self.min_chunk_size - len_last_block, True
)
last_block = part_prev_block + last_block
diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py
deleted file mode 100644
index d8f565df..00000000
--- a/align_data/pinecone/pinecone_db_handler.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# dataset/pinecone_db_handler.py
-
-import logging
-from typing import Dict
-
-import pinecone
-
-from align_data.settings import (
- PINECONE_INDEX_NAME,
- PINECONE_VALUES_DIMS,
- PINECONE_METRIC,
- PINECONE_METADATA_KEYS,
- PINECONE_API_KEY,
- PINECONE_ENVIRONMENT,
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-class PineconeDB:
- def __init__(
- self,
- index_name: str = PINECONE_INDEX_NAME,
- values_dims: int = PINECONE_VALUES_DIMS,
- metric: str = PINECONE_METRIC,
- metadata_keys: list = PINECONE_METADATA_KEYS,
- create_index: bool = False,
- log_index_stats: bool = True,
- ):
- self.index_name = index_name
- self.values_dims = values_dims
- self.metric = metric
- self.metadata_keys = metadata_keys
-
- pinecone.init(
- api_key=PINECONE_API_KEY,
- environment=PINECONE_ENVIRONMENT,
- )
-
- if create_index:
- self.create_index()
-
- self.index = pinecone.Index(index_name=self.index_name)
-
- if log_index_stats:
- index_stats_response = self.index.describe_index_stats()
- logger.info(f"{self.index_name}:\n{index_stats_response}")
-
- def upsert_entry(self, entry: Dict, upsert_size=100):
- self.index.upsert(
- vectors=list(
- zip(
- [
- f"{entry['id']}_{str(i).zfill(6)}"
- for i in range(len(entry["text_chunks"]))
- ],
- entry["embeddings"].tolist(),
- [
- {
- "entry_id": entry["id"],
- "source": entry["source"],
- "title": entry["title"],
- "authors": entry["authors"],
- "text": text_chunk,
- }
- for text_chunk in entry["text_chunks"]
- ],
- )
- ),
- batch_size=upsert_size,
- )
-
- def delete_entries(self, ids):
- self.index.delete(filter={"entry_id": {"$in": ids}})
-
- def create_index(self, replace_current_index: bool = True):
- if replace_current_index:
- self.delete_index()
-
- pinecone.create_index(
- name=self.index_name,
- dimension=self.values_dims,
- metric=self.metric,
- metadata_config={"indexed": self.metadata_keys},
- )
-
- def delete_index(self):
- if self.index_name in pinecone.list_indexes():
- logger.info(f"Deleting index '{self.index_name}'.")
- pinecone.delete_index(self.index_name)
diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py
deleted file mode 100644
index 7e64e560..00000000
--- a/align_data/pinecone/update_pinecone.py
+++ /dev/null
@@ -1,193 +0,0 @@
-from datetime import datetime
-import logging
-import numpy as np
-import os
-from itertools import islice
-from typing import Callable, List, Tuple, Generator
-
-import openai
-from pydantic import BaseModel, ValidationError, validator
-
-from align_data.db.models import Article
-from align_data.db.session import make_session, stream_pinecone_updates
-from align_data.pinecone.pinecone_db_handler import PineconeDB
-from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter
-from align_data.settings import (
- USE_OPENAI_EMBEDDINGS,
- OPENAI_EMBEDDINGS_MODEL,
- OPENAI_EMBEDDINGS_DIMS,
- OPENAI_EMBEDDINGS_RATE_LIMIT,
- SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
- SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS,
- CHUNK_SIZE,
- MAX_NUM_AUTHORS_IN_SIGNATURE,
- EMBEDDING_LENGTH_BIAS,
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-# Define type aliases for the Callables
-LengthFunctionType = Callable[[str], int]
-TruncateFunctionType = Callable[[str, int], str]
-
-
-class PineconeEntry(BaseModel):
- id: str
- source: str
- title: str
- url: str
- date_published: datetime
- authors: List[str]
- text_chunks: List[str]
- embeddings: np.ndarray
-
- class Config:
- arbitrary_types_allowed = True
-
- def __repr__(self):
- return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})"
-
- @validator(
- "id",
- "source",
- "title",
- "url",
- "date_published",
- "authors",
- "text_chunks",
- pre=True,
- always=True,
- )
- def empty_strings_not_allowed(cls, value):
- if not str(value).strip():
- raise ValueError("Attribute should not be empty.")
- return value
-
-
-class PineconeUpdater:
- def __init__(
- self,
- min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
- max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
- length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
- truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
- ):
- self.min_chunk_size = min_chunk_size
- self.max_chunk_size = max_chunk_size
- self.length_function = length_function
- self.truncate_function = truncate_function
-
- self.text_splitter = ParagraphSentenceUnitTextSplitter(
- min_chunk_size=self.min_chunk_size,
- max_chunk_size=self.max_chunk_size,
- length_function=self.length_function,
- truncate_function=self.truncate_function,
- )
- self.pinecone_db = PineconeDB()
-
- if USE_OPENAI_EMBEDDINGS:
- import openai
-
- openai.api_key = os.environ["OPENAI_API_KEY"]
- else:
- import torch
- from langchain.embeddings import HuggingFaceEmbeddings
-
- self.hf_embeddings = HuggingFaceEmbeddings(
- model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
- model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
- encode_kwargs={"show_progress_bar": False},
- )
-
- def save_batch(self, session, batch):
- try:
- for article, pinecone_entry in batch:
- self.pinecone_db.upsert_entry(pinecone_entry.dict())
- article.pinecone_update_required = False
- session.add(article)
- session.commit()
- except Exception as e:
- # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking
- logger.error(e)
- session.rollback()
-
- def update(self, custom_sources: List[str]):
- """
- Update the given sources. If no sources are provided, updates all sources.
-
- :param custom_sources: List of sources to update.
- """
- with make_session() as session:
- entries_stream = stream_pinecone_updates(session, custom_sources)
- for batch in self.batch_entries(entries_stream):
- self.save_batch(session, batch)
-
- def _make_pinecone_update(self, article):
- try:
- text_chunks = self.get_text_chunks(article)
- return article, PineconeEntry(
- id=article.id,
- source=article.source,
- title=article.title,
- url=article.url,
- date_published=article.date_published,
- authors=[
- author.strip()
- for author in article.authors.split(",")
- if author.strip()
- ],
- text_chunks=text_chunks,
- embeddings=self.extract_embeddings(
- text_chunks, [article.source] * len(text_chunks)
- ),
- )
- except (ValueError, ValidationError) as e:
- logger.exception(e)
-
- def batch_entries(
- self, article_stream: Generator[Article, None, None]
- ) -> Generator[List[Tuple[Article, PineconeEntry]], None, None]:
- items = iter(article_stream)
- while batch := tuple(islice(items, 10)):
- yield list(filter(None, map(self._make_pinecone_update, batch)))
-
- def get_text_chunks(self, article: Article) -> List[str]:
- signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}"
- text_chunks = self.text_splitter.split_text(article.text)
- text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]
- return text_chunks
-
- def extract_embeddings(self, chunks_batch, sources_batch):
- if USE_OPENAI_EMBEDDINGS:
- return self.get_openai_embeddings(chunks_batch, sources_batch)
- else:
- return np.array(
- self.hf_embeddings.embed_documents(chunks_batch, sources_batch)
- )
-
- @staticmethod
- def get_openai_embeddings(chunks, sources=""):
- embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS))
-
- openai_output = openai.Embedding.create(
- model=OPENAI_EMBEDDINGS_MODEL, input=chunks
- )["data"]
-
- for i, (embedding, source) in enumerate(zip(openai_output, sources)):
- bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0)
- embeddings[i] = bias * np.array(embedding["embedding"])
-
- return embeddings
-
- @staticmethod
- def get_authors_str(authors_lst: List[str]) -> str:
- if authors_lst == []:
- return "n/a"
- if len(authors_lst) == 1:
- return authors_lst[0]
- else:
- authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE]
- authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
- return authors_str
diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py
index cb16b7a5..05e9dbde 100644
--- a/align_data/postprocess/postprocess.py
+++ b/align_data/postprocess/postprocess.py
@@ -30,13 +30,15 @@ def compute_statistics(self) -> None:
for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)):
with jsonlines.open(path) as reader:
for obj in reader:
- text = obj['text']
+ text = obj["text"]
source_stats = self.all_stats[source_name]
source_stats["num_entries"] += 1
source_stats["num_tokens"] += len(text.split()) # TODO: Use tokenizer
source_stats["num_chars"] += len(text)
source_stats["num_words"] += len(text.split())
- source_stats["num_sentences"] += len(text.split(".")) # TODO: Use NLTK/Spacy or similar
+ source_stats["num_sentences"] += len(
+ text.split(".")
+ ) # TODO: Use NLTK/Spacy or similar
source_stats["num_paragraphs"] += len(text.splitlines())
def plot_statistics(self) -> None:
diff --git a/align_data/settings.py b/align_data/settings.py
index fb8c8f45..2ae060e5 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -1,10 +1,13 @@
import os
import logging
+from typing import Dict
+import openai
+import torch
from dotenv import load_dotenv
load_dotenv()
-LOG_LEVEL = os.environ.get('LOG_LEVEL', 'WARNING').upper()
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOG_LEVEL)
### CODA ###
@@ -24,7 +27,7 @@
"METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4"
)
-### YouTube ###
+### YOUTUBE ###
YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY")
### MYSQL ###
@@ -37,13 +40,16 @@
### EMBEDDINGS ###
USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used.
-EMBEDDING_LENGTH_BIAS = {
- "aisafety.info": 1.05, # In search, favor AISafety.info entries.
+EMBEDDING_LENGTH_BIAS: Dict[str, float] = {
+    # TODO: Experiment with these values. For now, let's remove the bias.
+ # "aisafety.info": 1.05, # In search, favor AISafety.info entries.
}
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
+openai.api_key = os.environ.get("OPENAI_API_KEY", None)
+openai.organization = os.environ.get("OPENAI_ORGANIZATION", None)
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
@@ -53,15 +59,20 @@
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
PINECONE_VALUES_DIMS = (
- OPENAI_EMBEDDINGS_DIMS
- if USE_OPENAI_EMBEDDINGS
- else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
+ OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
)
PINECONE_METRIC = "dotproduct"
-PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"]
+PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "normal") # "normal" or "finetuned"
-### MISCELLANEOUS ###
-CHUNK_SIZE = 1750
-MAX_NUM_AUTHORS_IN_SIGNATURE = 3
+### FINE-TUNING ###
+OPENAI_FINETUNED_LAYER_PATH = os.environ.get(
+ "OPENAI_FINETUNED_LAYER_PATH", "align_data/finetuning/data/finetuned_model.pth"
+)
+OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH = os.environ.get(
+ "OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH",
+ "align_data/finetuning/data/best_finetuned_model.pth",
+)
+### MISCELLANEOUS ###
MIN_CONFIDENCE = 50
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py
index 4f080409..6fd45fbc 100644
--- a/align_data/sources/articles/__init__.py
+++ b/align_data/sources/articles/__init__.py
@@ -1,6 +1,12 @@
from align_data.sources.articles.datasets import (
- ArxivPapers, EbookArticles, DocArticles, HTMLArticles,
- MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles
+ ArxivPapers,
+ EbookArticles,
+ DocArticles,
+ HTMLArticles,
+ MarkdownArticles,
+ PDFArticles,
+ SpecialDocs,
+ XMLArticles,
)
from align_data.sources.articles.indices import IndicesDataset
from align_data.common.alignment_dataset import MultiDataset
@@ -38,9 +44,9 @@
sheet_id="1293295703",
),
SpecialDocs(
- 'special_docs',
- spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI',
- sheet_id='980957638',
+ "special_docs",
+ spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
+ sheet_id="980957638",
),
]
@@ -52,5 +58,5 @@
spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
sheet_id="655836697",
),
- IndicesDataset('indices'),
+ IndicesDataset("indices"),
]
diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py
index 39cc31b7..7db94a7b 100644
--- a/align_data/sources/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -105,8 +105,10 @@ def process_spreadsheets(source_sheet, output_sheets):
row["source_url"] = row["url"]
if row.get("source_url") in seen:
logger.info(f'skipping "{title}", as it has already been seen')
- elif row.get('status'):
- logger.info(f'skipping "{title}", as it has a status set - remove it for this row to be processed')
+ elif row.get("status"):
+ logger.info(
+ f'skipping "{title}", as it has a status set - remove it for this row to be processed'
+ )
else:
process_row(row, output_sheets)
@@ -114,9 +116,7 @@ def process_spreadsheets(source_sheet, output_sheets):
def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet):
"""Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one."""
source_sheet = get_sheet(source_spreadsheet, source_sheet)
- sheets = {
- sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()
- }
+ sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()}
return process_spreadsheets(source_sheet, sheets)
@@ -136,8 +136,7 @@ def check_new_articles(source_spreadsheet, source_sheet):
missing = [
item
for title, item in indices_items.items()
- if title not in current
- and not {item.get("url"), item.get("source_url")} & seen_urls
+ if title not in current and not {item.get("url"), item.get("source_url")} & seen_urls
]
if not missing:
logger.info("No new articles found")
@@ -153,14 +152,12 @@ def check_new_articles(source_spreadsheet, source_sheet):
"publication_title",
"source_type",
]
- res = source_sheet.append_rows(
- [[item.get(col) for col in columns] for item in missing]
- )
+ res = source_sheet.append_rows([[item.get(col) for col in columns] for item in missing])
updated = res["updates"]["updatedRows"]
logger.info("Added %s rows", updated)
return updated
def update_articles(csv_file, delimiter):
- dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter)
+ dataset = ReplacerDataset(name="updater", csv_path=csv_file, delimiter=delimiter)
dataset.add_entries(dataset.fetch_entries())
diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py
index 5c031fcf..37ab0940 100644
--- a/align_data/sources/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -16,10 +16,16 @@
from align_data.db.models import Article
from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown
from align_data.sources.articles.parsers import (
- HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain
+ HTML_PARSERS,
+ extract_gdrive_contents,
+ item_metadata,
+ parse_domain,
)
from align_data.sources.articles.pdf import read_pdf
-from align_data.sources.arxiv_papers import fetch_arxiv, canonical_url as arxiv_cannonical_url
+from align_data.sources.arxiv_papers import (
+ fetch_arxiv,
+ canonical_url as arxiv_cannonical_url,
+)
logger = logging.getLogger(__name__)
@@ -44,8 +50,8 @@ def get_item_key(self, item):
@property
def items_list(self):
- url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}'
- logger.info(f'Fetching {url}')
+ url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
+ logger.info(f"Fetching {url}")
df = pd.read_csv(url)
return (item for item in df.itertuples() if self.get_item_key(item))
@@ -85,7 +91,6 @@ def process_entry(self, item):
class SpecialDocs(SpreadsheetDataset):
-
@property
def _query_items(self):
special_docs_types = ["pdf", "html", "xml", "markdown", "docx"]
@@ -96,35 +101,39 @@ def get_contents(self, item) -> Dict:
if url := self.maybe(item, "source_url") or self.maybe(item, "url"):
contents = item_metadata(url)
- return dict(contents, **{
- 'url': self.maybe(item, "url"),
- 'title': self.maybe(item, "title") or contents.get('title'),
- 'source': contents.get('source_type') or self.name,
- 'source_url': self.maybe(item, "source_url"),
- 'source_type': contents.get('source_type') or self.maybe(item, "source_type"),
- 'date_published': self._get_published_date(self.maybe(item, 'date_published')) or contents.get('date_published'),
- 'authors': self.extract_authors(item) or contents.get('authors', []),
- 'text': contents.get('text'),
- 'status': 'Invalid' if contents.get('error') else None,
- 'comments': contents.get('error'),
- })
+ return dict(
+ contents,
+ **{
+ "url": self.maybe(item, "url"),
+ "title": self.maybe(item, "title") or contents.get("title"),
+ "source": contents.get("source_type") or self.name,
+ "source_url": self.maybe(item, "source_url"),
+ "source_type": contents.get("source_type") or self.maybe(item, "source_type"),
+ "date_published": self._get_published_date(self.maybe(item, "date_published"))
+ or contents.get("date_published"),
+ "authors": self.extract_authors(item) or contents.get("authors", []),
+ "text": contents.get("text"),
+ "status": "Invalid" if contents.get("error") else None,
+ "comments": contents.get("error"),
+ },
+ )
def not_processed(self, item):
- url = self.maybe(item, 'url')
- source_url = self.maybe(item, 'source_url')
+ url = self.maybe(item, "url")
+ source_url = self.maybe(item, "source_url")
return (
- self.get_item_key(item) not in self._outputted_items and
- url not in self._outputted_items and
- source_url not in self._outputted_items and
- (not url or arxiv_cannonical_url(url) not in self._outputted_items) and
- (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items)
+ self.get_item_key(item) not in self._outputted_items
+ and url not in self._outputted_items
+ and source_url not in self._outputted_items
+ and (not url or arxiv_cannonical_url(url) not in self._outputted_items)
+ and (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items)
)
def process_entry(self, item):
if ArxivPapers.is_arxiv(item.url):
contents = ArxivPapers.get_contents(item)
- contents['source'] = 'arxiv'
+ contents["source"] = "arxiv"
else:
contents = self.get_contents(item)
@@ -154,7 +163,7 @@ def _get_text(item):
domain = parse_domain(item.source_url)
if parser := HTML_PARSERS.get(domain):
res = parser(item.source_url)
- return res and res.get('text')
+ return res and res.get("text")
class EbookArticles(SpreadsheetDataset):
@@ -168,9 +177,7 @@ def setup(self):
def _get_text(self, item):
file_id = item.source_url.split("/")[-2]
- filename = download(
- output=str(self.files_path / f"{item.title}.epub"), id=file_id
- )
+ filename = download(output=str(self.files_path / f"{item.title}.epub"), id=file_id)
return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"])
@@ -221,12 +228,12 @@ def get_contents(cls, item) -> Dict:
contents = fetch_arxiv(item.url or item.source_url)
if cls.maybe(item, "authors") and item.authors.strip():
- contents['authors'] = [i.strip() for i in item.authors.split(',')]
+ contents["authors"] = [i.strip() for i in item.authors.split(",")]
if cls.maybe(item, "title"):
- contents['title'] = cls.maybe(item, "title")
+ contents["title"] = cls.maybe(item, "title")
- contents['date_published'] = cls._get_published_date(
- cls.maybe(item, "date_published") or contents.get('date_published')
+ contents["date_published"] = cls._get_published_date(
+ cls.maybe(item, "date_published") or contents.get("date_published")
)
return contents
diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py
index adb6fa61..ca310235 100644
--- a/align_data/sources/articles/google_cloud.py
+++ b/align_data/sources/articles/google_cloud.py
@@ -146,17 +146,17 @@ def fetch_markdown(file_id):
"source_type": "markdown",
}
except Exception as e:
- return {'error': str(e)}
+ return {"error": str(e)}
def parse_grobid(contents):
doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict()
- authors = [xx["full_name"].strip(' !') for xx in doc_dict.get("header", {}).get("authors", [])]
+ authors = [xx["full_name"].strip(" !") for xx in doc_dict.get("header", {}).get("authors", [])]
- if not doc_dict.get('body'):
+ if not doc_dict.get("body"):
return {
- 'error': 'No contents in XML file',
- 'source_type': 'xml',
+ "error": "No contents in XML file",
+ "source_type": "xml",
}
return {
@@ -169,67 +169,72 @@ def parse_grobid(contents):
def get_content_type(res):
- header = res.headers.get('Content-Type') or ''
- parts = [c_type.strip().lower() for c_type in header.split(';')]
+ header = res.headers.get("Content-Type") or ""
+ parts = [c_type.strip().lower() for c_type in header.split(";")]
return set(filter(None, parts))
def extract_gdrive_contents(link):
- file_id = link.split('/')[-2]
- url = f'https://drive.google.com/uc?id={file_id}'
- res = fetch(url, 'head')
+ file_id = link.split("/")[-2]
+ url = f"https://drive.google.com/uc?id={file_id}"
+ res = fetch(url, "head")
if res.status_code == 403:
- logger.error('Could not fetch the file at %s - 403 returned', link)
- return {'error': 'Could not read file from google drive - forbidden'}
+ logger.error("Could not fetch the file at %s - 403 returned", link)
+ return {"error": "Could not read file from google drive - forbidden"}
if res.status_code >= 400:
- logger.error('Could not fetch the file at %s - are you sure that link is correct?', link)
- return {'error': 'Could not read file from google drive'}
+ logger.error("Could not fetch the file at %s - are you sure that link is correct?", link)
+ return {"error": "Could not read file from google drive"}
result = {
- 'source_url': link,
- 'downloaded_from': 'google drive',
+ "source_url": link,
+ "downloaded_from": "google drive",
}
content_type = get_content_type(res)
if not content_type:
- result['error'] = 'no content type'
- elif content_type & {'application/octet-stream', 'application/pdf'}:
+ result["error"] = "no content type"
+ elif content_type & {"application/octet-stream", "application/pdf"}:
result.update(fetch_pdf(url))
- elif content_type & {'text/markdown'}:
+ elif content_type & {"text/markdown"}:
result.update(fetch_markdown(file_id))
- elif content_type & {'application/epub+zip', 'application/epub'}:
- result['source_type'] = 'ebook'
- elif content_type & {'text/html'}:
+ elif content_type & {"application/epub+zip", "application/epub"}:
+ result["source_type"] = "ebook"
+ elif content_type & {"text/html"}:
res = fetch(url)
- if 'Google Drive - Virus scan warning' in res.text:
+ if "Google Drive - Virus scan warning" in res.text:
soup = BeautifulSoup(res.content, "html.parser")
- res = fetch(soup.select_one('form').get('action'))
+ res = fetch(soup.select_one("form").get("action"))
content_type = get_content_type(res)
- if content_type & {'text/xml'}:
+ if content_type & {"text/xml"}:
result.update(parse_grobid(res.content))
- elif content_type & {'text/html'}:
+ elif content_type & {"text/html"}:
soup = BeautifulSoup(res.content, "html.parser")
- result.update({
- 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(),
- 'source_type': 'html',
- })
+ result.update(
+ {
+ "text": MarkdownConverter().convert_soup(soup.select_one("body")).strip(),
+ "source_type": "html",
+ }
+ )
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
return result
def google_doc(url: str) -> Dict:
"""Fetch the contents of the given gdoc url as markdown."""
- res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url)
+ res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url)
if not res:
return {}
doc_id = res.group(1)
- body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body')
+ body = fetch_element(f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body")
if body:
- return {'text': MarkdownConverter().convert_soup(body).strip(), 'source_url': url}
+ return {
+ "text": MarkdownConverter().convert_soup(body).strip(),
+ "source_url": url,
+ }
return {}
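
For context, a minimal usage sketch of the dict contract above (the drive link and file id are made up): callers are expected to check for an "error" key before trusting the parsed fields.

from align_data.sources.articles.google_cloud import extract_gdrive_contents

contents = extract_gdrive_contents("https://drive.google.com/file/d/abc123/view")  # hypothetical link
if contents.get("error"):
    print("could not parse gdrive file:", contents["error"])
else:
    text = contents.get("text")  # parsed text, when the file could be read
    source_type = contents.get("source_type")  # e.g. "pdf", "ebook", "markdown", "html"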
diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py
index 7df71860..d3c2490c 100644
--- a/align_data/sources/articles/html.py
+++ b/align_data/sources/articles/html.py
@@ -70,9 +70,9 @@ def getter(url):
for e in elem.select(sel):
e.extract()
return {
- 'text': MarkdownConverter().convert_soup(elem).strip(),
- 'source_url': url,
- 'source_type': 'html',
+ "text": MarkdownConverter().convert_soup(elem).strip(),
+ "source_url": url,
+ "source_type": "html",
}
return getter
diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py
index 60075ef2..5e7d5a05 100644
--- a/align_data/sources/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -239,49 +239,58 @@ def fetch_all():
class IndicesDataset(AlignmentDataset):
-
- done_key = 'url'
+ done_key = "url"
@property
def items_list(self):
return fetch_all().values()
def get_item_key(self, item):
- return item.get('url')
+ return item.get("url")
@staticmethod
def extract_authors(item):
- if authors := (item.get('authors') or '').strip():
- return [author.strip() for author in authors.split(',') if author.strip()]
+ if authors := (item.get("authors") or "").strip():
+ return [author.strip() for author in authors.split(",") if author.strip()]
return []
def process_entry(self, item):
contents = {}
- url = item.get('source_url') or item.get('url')
+ url = item.get("source_url") or item.get("url")
if url:
- contents= item_metadata(url)
-
- if not contents.get('text'):
- logger.error('Could not get text for %s (%s) - %s - skipping for now', item.get('title'), url, contents.get('error'))
+ contents = item_metadata(url)
+
+ if not contents.get("text"):
+ logger.error(
+ "Could not get text for %s (%s) - %s - skipping for now",
+ item.get("title"),
+ url,
+ contents.get("error"),
+ )
return None
# If the article is not an arxiv paper, just mark it as ignored - if in the future editors
# decide it's worth adding, it can be fetched then
- if parse_domain(url or '') != 'arxiv.org':
- return self.make_data_entry({
- 'source': self.name,
- 'url': self.get_item_key(item),
- 'title': item.get('title'),
- 'date_published': self._get_published_date(item.get('date_published')),
- 'authors': self.extract_authors(item),
- 'status': 'Ignored',
- 'comments': 'Added from indices',
- })
-
- return self.make_data_entry({
- 'source': 'arxiv',
- 'url': contents.get('url') or self.get_item_key(item),
- 'title': item.get('title'),
- 'date_published': self._get_published_date(item.get('date_published')),
- 'authors': self.extract_authors(item),
- }, **contents)
+ if parse_domain(url or "") != "arxiv.org":
+ return self.make_data_entry(
+ {
+ "source": self.name,
+ "url": self.get_item_key(item),
+ "title": item.get("title"),
+ "date_published": self._get_published_date(item.get("date_published")),
+ "authors": self.extract_authors(item),
+ "status": "Ignored",
+ "comments": "Added from indices",
+ }
+ )
+
+ return self.make_data_entry(
+ {
+ "source": "arxiv",
+ "url": contents.get("url") or self.get_item_key(item),
+ "title": item.get("title"),
+ "date_published": self._get_published_date(item.get("date_published")),
+ "authors": self.extract_authors(item),
+ },
+ **contents,
+ )
diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py
index 36736c95..c210e565 100644
--- a/align_data/sources/articles/parsers.py
+++ b/align_data/sources/articles/parsers.py
@@ -24,27 +24,29 @@ def get_pdf_from_page(*link_selectors: str):
:param List[str] link_selectors: CSS selector used to find the final download link
:returns: the contents of the pdf file as a string
"""
+
def getter(url: str):
link: str = url
for selector in link_selectors:
elem = fetch_element(link, selector)
if not elem:
- return {'error': f'Could not find pdf download link for {link} using \'{selector}\''}
+ return {"error": f"Could not find pdf download link for {link} using '{selector}'"}
- link = elem.get('href')
- if not link.startswith('http') or not link.startswith('//'):
+ link = elem.get("href")
+ if not link.startswith("http") or not link.startswith("//"):
link = urljoin(url, link)
# Some pages keep link to google drive previews of pdf files, which need to be
# mangled to get the URL of the actual pdf file
- if 'drive.google.com' in link and '/view' in link:
+ if "drive.google.com" in link and "/view" in link:
return extract_gdrive_contents(link)
if parse_domain(link) == "arxiv.org":
return fetch_arxiv(link)
if pdf := fetch_pdf(link):
return pdf
- return {'error': f'Could not fetch pdf from {link}'}
+ return {"error": f"Could not fetch pdf from {link}"}
+
return getter
@@ -77,10 +79,11 @@ def __call__(self, url):
def error(error_msg):
"""Returns a url handler function that just logs the provided `error` string."""
+
def func(url):
if error_msg:
logger.error(error_msg)
- return {'error': error_msg, 'source_url': url}
+ return {"error": error_msg, "source_url": url}
return func
@@ -129,17 +132,17 @@ def getter(url):
"mediangroup.org": element_extractor("div.entry-content"),
"www.alexirpan.com": element_extractor("article"),
"www.incompleteideas.net": element_extractor("body"),
- "ai-alignment.com": MediumParser(name='html', url='ai-alignment.com'),
+ "ai-alignment.com": MediumParser(name="html", url="ai-alignment.com"),
"aisrp.org": element_extractor("article"),
"bounded-regret.ghost.io": element_extractor("div.post-content"),
"carnegieendowment.org": element_extractor(
"div.article-body", remove=[".no-print", ".related-pubs"]
),
- "casparoesterheld.com": element_extractor(
- ".entry-content", remove=["div.sharedaddy"]
- ),
+ "casparoesterheld.com": element_extractor(".entry-content", remove=["div.sharedaddy"]),
"cullenokeefe.com": element_extractor("div.sqs-block-content"),
- "deepmindsafetyresearch.medium.com": MediumParser(name='html', url='deepmindsafetyresearch.medium.com'),
+ "deepmindsafetyresearch.medium.com": MediumParser(
+ name="html", url="deepmindsafetyresearch.medium.com"
+ ),
"docs.google.com": google_doc,
"docs.microsoft.com": element_extractor("div.content"),
"digichina.stanford.edu": element_extractor("div.h_editor-content"),
@@ -154,7 +157,7 @@ def getter(url):
"link.springer.com": element_extractor("article.c-article-body"),
"longtermrisk.org": element_extractor("div.entry-content"),
"lukemuehlhauser.com": element_extractor("div.entry-content"),
- "medium.com": MediumParser(name='html', url='medium.com'),
+ "medium.com": MediumParser(name="html", url="medium.com"),
"openai.com": element_extractor("#content"),
"ought.org": element_extractor("div.BlogPostBodyContainer"),
"sideways-view.com": element_extractor("article", remove=["header"]),
@@ -169,10 +172,8 @@ def getter(url):
),
"theconversation.com": element_extractor("div.content-body"),
"thegradient.pub": element_extractor("div.c-content"),
- "towardsdatascience.com": MediumParser(name='html', url='towardsdatascience.com'),
- "unstableontology.com": element_extractor(
- ".entry-content", remove=["div.sharedaddy"]
- ),
+ "towardsdatascience.com": MediumParser(name="html", url="towardsdatascience.com"),
+ "unstableontology.com": element_extractor(".entry-content", remove=["div.sharedaddy"]),
"waitbutwhy.com": element_extractor("article", remove=[".entry-header"]),
"weightagnostic.github.io": element_extractor(
"dt-article", remove=["#authors_section", "dt-byline"]
@@ -180,9 +181,7 @@ def getter(url):
"cnas.org": element_extractor("#mainbar-toc"),
"econlib.org": element_extractor("div.post-content"),
"humanityplus.org": element_extractor("div.content"),
- "gleech.org": element_extractor(
- "article.post-content", remove=["center", "div.accordion"]
- ),
+ "gleech.org": element_extractor("article.post-content", remove=["center", "div.accordion"]),
"ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly...
"microsoft.com": element_extractor("div.content-container"),
"mdpi.com": element_extractor(
@@ -259,32 +258,30 @@ def getter(url):
"jstor.org": doi_getter,
"ri.cmu.edu": get_pdf_from_page("a.pub-link"),
"risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'),
- "ssrn.com": get_pdf_from_page(
- '.abstract-buttons a.button-link:-soup-contains("Download")'
- ),
+ "ssrn.com": get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'),
"yjolt.org": get_pdf_from_page("span.file a"),
}
def parse_domain(url: str) -> str:
- return url and urlparse(url).netloc.lstrip('www.')
+ return url and urlparse(url).netloc.lstrip("www.")
def item_metadata(url) -> Dict[str, any]:
domain = parse_domain(url)
try:
- res = fetch(url, 'head')
+ res = fetch(url, "head")
except (MissingSchema, InvalidSchema, ConnectionError) as e:
- return {'error': str(e)}
+ return {"error": str(e)}
- content_type = {item.strip() for item in res.headers.get('Content-Type', '').split(';')}
+ content_type = {item.strip() for item in res.headers.get("Content-Type", "").split(";")}
if content_type & {"text/html", "text/xml"}:
# If the url points to a html webpage, then it either contains the text as html, or
# there is a link to a pdf on it
if parser := HTML_PARSERS.get(domain):
res = parser(url)
- if res and 'error' not in res:
+ if res and "error" not in res:
# Proper contents were found on the page, so use them
return res
@@ -296,13 +293,11 @@ def item_metadata(url) -> Dict[str, any]:
if parser := UNIMPLEMENTED_PARSERS.get(domain):
return parser(url)
- if domain not in (
- HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()
- ):
+ if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()):
return {"error": "No domain handler defined"}
return {"error": "could not parse url"}
elif content_type & {"application/octet-stream", "application/pdf"}:
- if domain == 'arxiv.org':
+ if domain == "arxiv.org":
return fetch_arxiv(url)
# just download it as a pdf
return fetch_pdf(url)
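
A rough usage sketch (placeholder URL): `item_metadata` returns either parsed contents or a dict carrying an "error" key, which is how IndicesDataset and the updater consume it elsewhere in this diff.

from align_data.sources.articles.parsers import item_metadata

metadata = item_metadata("https://example.org/some-article")  # placeholder URL
if not metadata.get("text"):
    print("skipping -", metadata.get("error"))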
diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py
index b0e31951..43495142 100644
--- a/align_data/sources/articles/pdf.py
+++ b/align_data/sources/articles/pdf.py
@@ -54,9 +54,7 @@ def fetch_pdf(link):
link,
)
- content_type = {
- c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")
- }
+ content_type = {c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")}
if not content_type & {"application/octet-stream", "application/pdf"}:
return {
"error": f"Wrong content type retrieved: {content_type} - {link}",
@@ -71,8 +69,8 @@ def fetch_pdf(link):
"source_type": "pdf",
}
except (TypeError, PdfReadError) as e:
- logger.error('Could not read PDF file: %s', e)
- return {'error': str(e)}
+ logger.error("Could not read PDF file: %s", e)
+ return {"error": str(e)}
filenames = [
i.strip().split("=")[1]
@@ -96,11 +94,7 @@ def get_arxiv_link(doi):
if res.status_code != 200:
return None
- vals = [
- val
- for val in response.json().get("values")
- if val.get("type", "").upper() == "URL"
- ]
+ vals = [val for val in response.json().get("values") if val.get("type", "").upper() == "URL"]
if not vals:
return None
@@ -135,7 +129,7 @@ def doi_getter(url):
def parse_vanity(url) -> Dict[str, Any]:
contents = fetch_element(url, "article")
if not contents:
- return {'error': 'Could not fetch from arxiv vanity'}
+ return {"error": "Could not fetch from arxiv vanity"}
if title := contents.select_one("h1.ltx_title"):
title = title.text
diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py
index 861d0af6..f453b5ae 100644
--- a/align_data/sources/articles/updater.py
+++ b/align_data/sources/articles/updater.py
@@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)
-Item = namedtuple('Item', ['updates', 'article'])
+Item = namedtuple("Item", ["updates", "article"])
@dataclass
@@ -33,52 +33,50 @@ def maybe(item, key):
def items_list(self):
df = pd.read_csv(self.csv_path, delimiter=self.delimiter)
self.csv_items = [
- item for item in df.itertuples()
- if self.maybe(item, 'id') or self.maybe(item, 'hash_id')
+ item
+ for item in df.itertuples()
+ if self.maybe(item, "id") or self.maybe(item, "hash_id")
]
- by_id = {i.id: i for i in self.csv_items if self.maybe(i, 'id')}
- by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, 'hash_id')}
+ by_id = {i.id: i for i in self.csv_items if self.maybe(i, "id")}
+ by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, "hash_id")}
- return [
- Item(by_id.get(a._id) or by_hash_id.get(a.id), a)
- for a in self.read_entries()
- ]
+ return [Item(by_id.get(a._id) or by_hash_id.get(a.id), a) for a in self.read_entries()]
@property
def _query_items(self):
- ids = [i.id for i in self.csv_items if self.maybe(i, 'id')]
- hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')]
+ ids = [i.id for i in self.csv_items if self.maybe(i, "id")]
+ hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")]
return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids)))
def update_text(self, updates, article):
# If the url is the same as it was before, and there isn't a source url provided, assume that the
# previous text is still valid
- if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'):
+ if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"):
return
# If no url found, then don't bother fetching the text - assume it was successfully fetched previously
- url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url')
+ url = self.maybe(updates, "source_url") or self.maybe(updates, "url")
if not url:
return
if article.url != url:
- article.add_meta('source_url', url)
+ article.add_meta("source_url", url)
metadata = item_metadata(url)
# Only change the text if it could be fetched - better to have outdated values than none
- if metadata.get('text'):
- article.text = metadata.get('text')
- article.status = metadata.get('error')
+ if metadata.get("text"):
+ article.text = metadata.get("text")
+ article.status = metadata.get("error")
def process_entry(self, item):
updates, article = item
- for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']:
+ for key in ["url", "title", "source", "authors", "comment", "confidence"]:
value = self.maybe(updates, key)
if value and getattr(article, key, None) != value:
setattr(article, key, value)
- if date := getattr(updates, 'date_published', None):
+ if date := getattr(updates, "date_published", None):
article.date_published = self._get_published_date(date)
self.update_text(updates, article)
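
For reference, a hedged sketch of how the updater is meant to be driven - the column names come from `process_entry` and `items_list` above; the file name and row values are invented, and the import path assumes `ReplacerDataset` lives in the module shown in this diff.

from align_data.sources.articles.updater import ReplacerDataset

# updates.csv (hypothetical contents); every row needs an "id" or "hash_id"
# to match an existing article:
#   id,url,title,authors,date_published
#   123,https://example.org/post,New title,Jane Doe,2023-01-01
dataset = ReplacerDataset(name="updater", csv_path="updates.csv", delimiter=",")
dataset.add_entries(dataset.fetch_entries())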
diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py
index 2b98223f..2fb64377 100644
--- a/align_data/sources/arxiv_papers.py
+++ b/align_data/sources/arxiv_papers.py
@@ -28,37 +28,37 @@ def get_arxiv_metadata(paper_id) -> arxiv.Result:
return None
-def get_id(url: str) -> Optional[str]:
+def get_id(url: str) -> str | None:
if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url):
return res.group(1)
def canonical_url(url: str) -> str:
if paper_id := get_id(url):
- return f'https://arxiv.org/abs/{paper_id}'
+ return f"https://arxiv.org/abs/{paper_id}"
return url
def get_contents(paper_id: str) -> Dict[str, Any]:
arxiv_vanity = parse_vanity(f"https://www.arxiv-vanity.com/papers/{paper_id}")
- if 'error' not in arxiv_vanity:
+ if "error" not in arxiv_vanity:
return arxiv_vanity
ar5iv = parse_vanity(f"https://ar5iv.org/abs/{paper_id}")
- if 'error' not in ar5iv:
+ if "error" not in ar5iv:
return ar5iv
return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf")
-def get_version(id: str) -> Optional[str]:
- if res := re.search(r'.*v(\d+)$', id):
+def get_version(id: str) -> str | None:
+ if res := re.search(r".*v(\d+)$", id):
return res.group(1)
def is_withdrawn(url: str):
- if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'):
- return elem.text.strip().lower() == 'withdrawn'
+ if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"):
+ return elem.text.strip().lower() == "withdrawn"
return None
@@ -84,18 +84,21 @@ def add_metadata(data, paper_id):
def fetch_arxiv(url) -> Dict:
paper_id = get_id(url)
if not paper_id:
- return {'error': 'Could not extract arxiv id'}
+ return {"error": "Could not extract arxiv id"}
if is_withdrawn(url):
- paper = {'status': 'Withdrawn'}
+ paper = {"status": "Withdrawn"}
else:
paper = get_contents(paper_id)
- data = add_metadata({
- "url": canonical_url(url),
- "source_type": paper.get('data_source'),
- }, paper_id)
- authors = data.get('authors') or paper.get("authors") or []
- data['authors'] = [str(a).strip() for a in authors]
+ data = add_metadata(
+ {
+ "url": canonical_url(url),
+ "source_type": paper.get("data_source"),
+ },
+ paper_id,
+ )
+ authors = data.get("authors") or paper.get("authors") or []
+ data["authors"] = [str(a).strip() for a in authors]
return merge_dicts(data, paper)
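
A small behaviour sketch of `canonical_url`, consistent with the parametrised test cases later in this diff:

from align_data.sources.arxiv_papers import canonical_url

canonical_url("https://arxiv.org/pdf/2001.11038v3.pdf")  # -> "https://arxiv.org/abs/2001.11038"
canonical_url("http://bla.bla")  # -> returned unchanged, since no arxiv id can be extracted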
diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py
index 3fee9ca2..05831f3e 100644
--- a/align_data/sources/blogs/__init__.py
+++ b/align_data/sources/blogs/__init__.py
@@ -26,9 +26,7 @@
url="https://deepmindsafetyresearch.medium.com/",
authors=["DeepMind Safety Research"],
),
- GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- ),
+ GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]),
ColdTakes(
name="cold_takes",
url="https://www.cold-takes.com/",
diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py
index fcf97891..8c2a2b94 100644
--- a/align_data/sources/blogs/blogs.py
+++ b/align_data/sources/blogs/blogs.py
@@ -51,12 +51,7 @@ def _get_published_date(self, contents):
return ""
def extract_authors(self, article):
- return (
- article.select_one("header .post-meta")
- .text.split("·")[1]
- .strip()
- .split(", ")
- )
+ return article.select_one("header .post-meta").text.split("·")[1].strip().split(", ")
class OpenAIResearch(HTMLDataset):
@@ -82,9 +77,7 @@ def extract_authors(self, article):
authors = []
if authors_div:
authors = [
- i.split("(")[0].strip()
- for i in authors_div.select_one("p").children
- if not i.name
+ i.split("(")[0].strip() for i in authors_div.select_one("p").children if not i.name
]
return authors or ["OpenAI Research"]
@@ -114,9 +107,7 @@ def items_list(self):
page += 1
# update the tqdm progress bar
- pbar.set_postfix_str(
- f"page {page}", refresh=True
- ) # Set postfix to "page X"
+ pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
logger.info("Got %s pages", len(articles))
diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py
index c0132301..cd409d98 100644
--- a/align_data/sources/blogs/wp_blog.py
+++ b/align_data/sources/blogs/wp_blog.py
@@ -42,9 +42,7 @@ def items_list(self):
self.items[item["link"]] = item
# update the tqdm progress bar
- pbar.set_postfix_str(
- f"page {page_number}", refresh=True
- ) # Set postfix to "page X"
+ pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
logger.info(f"Got {len(self.items)} pages")
diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py
index 3a514a01..8709b154 100644
--- a/align_data/sources/distill/distill.py
+++ b/align_data/sources/distill/distill.py
@@ -6,7 +6,9 @@ class Distill(RSSDataset):
done_key = "url"
def extract_authors(self, item):
- return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"]
+ return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or [
+ "Distill"
+ ]
def _get_text(self, item):
article = item["soup"].find("d-article") or item["soup"].find("dt-article")
diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py
index 0055f5e0..7fdcd729 100644
--- a/align_data/sources/ebooks/__init__.py
+++ b/align_data/sources/ebooks/__init__.py
@@ -1,7 +1,5 @@
from .agentmodels import AgentModels
EBOOK_REGISTRY = [
- AgentModels(
- name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"
- ),
+ AgentModels(name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"),
]
diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py
index cfd68a79..65b52502 100644
--- a/align_data/sources/ebooks/agentmodels.py
+++ b/align_data/sources/ebooks/agentmodels.py
@@ -27,9 +27,7 @@ def setup(self):
self.files_path = self.base_dir / "chapters"
def _get_published_date(self, filename):
- last_commit = next(
- self.repository.iter_commits(paths=f"chapters/{filename.name}")
- )
+ last_commit = next(self.repository.iter_commits(paths=f"chapters/{filename.name}"))
return last_commit.committed_datetime.astimezone(timezone.utc)
def process_entry(self, filename):
diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py
index deb6d4d1..5f32b68f 100644
--- a/align_data/sources/greaterwrong/greaterwrong.py
+++ b/align_data/sources/greaterwrong/greaterwrong.py
@@ -77,10 +77,7 @@ def setup(self):
self.ai_tags = get_allowed_tags(self.base_url, self.name)
def tags_ok(self, post):
- return (
- not self.ai_tags
- or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags
- )
+ return not self.ai_tags or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags
def get_item_key(self, item):
return item["pageUrl"]
@@ -174,7 +171,7 @@ def process_entry(self, item):
authors = item["coauthors"]
if item["user"]:
authors = [item["user"]] + authors
- authors = [a["displayName"] for a in authors] or ['anonymous']
+ authors = [a["displayName"] for a in authors] or ["anonymous"]
return self.make_data_entry(
{
"title": item["title"],
diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py
index 88a7149e..49f57620 100644
--- a/align_data/sources/stampy/stampy.py
+++ b/align_data/sources/stampy/stampy.py
@@ -45,13 +45,9 @@ def _get_published_date(self, entry):
def process_entry(self, entry):
def clean_text(text):
text = html.unescape(text)
- return re.sub(
- r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text
- )
+ return re.sub(r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text)
- question = clean_text(
- entry["Question"]
- ) # raise an error if the entry has no question
+ question = clean_text(entry["Question"]) # raise an error if the entry has no question
answer = clean_text(entry["Rich Text"])
url = "https://aisafety.info?state=" + entry["UI ID"]
diff --git a/main.py b/main.py
index 67478a12..2c1d8da2 100644
--- a/main.py
+++ b/main.py
@@ -7,8 +7,13 @@
from align_data import ALL_DATASETS, get_dataset
from align_data.analysis.count_tokens import count_token
-from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles
-from align_data.pinecone.update_pinecone import PineconeUpdater
+from align_data.sources.articles.articles import (
+ update_new_items,
+ check_new_articles,
+ update_articles,
+)
+from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater
+from align_data.embeddings.finetuning.training import finetune_embeddings
from align_data.settings import (
METADATA_OUTPUT_SPREADSHEET,
METADATA_SOURCE_SHEET,
@@ -76,12 +81,10 @@ def count_tokens(self, merged_dataset_path: str) -> None:
This function counts the number of tokens, words, and characters in the dataset
:return: None
"""
- assert os.path.exists(
- merged_dataset_path
- ), "The path to the merged dataset does not exist"
+ assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist"
count_token(merged_dataset_path)
- def update(self, csv_path, delimiter=','):
+ def update(self, csv_path, delimiter=","):
"""Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided.
:param str csv_path: The path to the csv file to be processed
@@ -115,7 +118,7 @@ def fetch_new_articles(
"""
return check_new_articles(source_spreadsheet, source_sheet)
- def pinecone_update(self, *names) -> None:
+ def pinecone_update(self, *names, force_update=False) -> None:
"""
This function updates the Pinecone vector DB.
@@ -125,14 +128,20 @@ def pinecone_update(self, *names) -> None:
names = ALL_DATASETS
missing = {name for name in names if name not in ALL_DATASETS}
assert not missing, f"{missing} are not valid dataset names"
- PineconeUpdater().update(names)
+ PineconeUpdater().update(names, force_update)
- def pinecone_update_all(self, *skip) -> None:
+ def pinecone_update_all(self, *skip, force_update=False) -> None:
"""
This function updates the Pinecone vector DB.
"""
names = [name for name in ALL_DATASETS if name not in skip]
- PineconeUpdater().update(names)
+ PineconeUpdater().update(names, force_update)
+
+ def train_finetuning_layer(self) -> None:
+ """
+ This function trains a finetuning layer on top of the OpenAI embeddings.
+ """
+ finetune_embeddings()
if __name__ == "__main__":
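
Roughly what the new CLI options map to at the Python level - a sketch only, with `force_update` passed positionally to mirror the calls above:

from align_data import ALL_DATASETS
from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater
from align_data.embeddings.finetuning.training import finetune_embeddings

PineconeUpdater().update(ALL_DATASETS, True)  # force re-embedding of every selected article
finetune_embeddings()  # train the finetuning layer on top of the OpenAI embeddings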
diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
index 7a8485fe..e5b9a303 100644
--- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
+++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
@@ -18,9 +18,7 @@
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column(
- "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)
- )
+ op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False))
# ### end Alembic commands ###
diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py
index 76c89ee0..d93a8c86 100644
--- a/migrations/versions/f5a2bcfa6b2c_add_status_column.py
+++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py
@@ -10,17 +10,17 @@
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'f5a2bcfa6b2c'
-down_revision = '59ac3cb671e3'
+revision = "f5a2bcfa6b2c"
+down_revision = "59ac3cb671e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
- op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True))
- op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True))
+ op.add_column("articles", sa.Column("status", sa.String(length=256), nullable=True))
+ op.add_column("articles", sa.Column("comments", mysql.LONGTEXT(), nullable=True))
def downgrade() -> None:
- op.drop_column('articles', 'comments')
- op.drop_column('articles', 'status')
+ op.drop_column("articles", "comments")
+ op.drop_column("articles", "status")
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..aa4949aa
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 100
diff --git a/requirements.txt b/requirements.txt
index b7985f3f..88ce2517 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,3 +36,5 @@ openai
langchain
nltk
pinecone-client
+
+torch
diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py
index a0ea9fc1..5b73e05b 100644
--- a/tests/align_data/articles/test_datasets.py
+++ b/tests/align_data/articles/test_datasets.py
@@ -48,7 +48,7 @@ def mock_arxiv():
journal_ref="sdf",
primary_category="cat",
)
- metadata.get_short_id.return_value = '2001.11038'
+ metadata.get_short_id.return_value = "2001.11038"
arxiv = Mock()
arxiv.Search.return_value.results.return_value = iter([metadata])
@@ -124,30 +124,24 @@ def test_pdf_articles_process_item(articles):
"text": "pdf contents [bla](asd.com)",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
def test_html_articles_get_text():
def parser(url):
assert url == "http://example.org/bla.bla"
- return {'text': "html contents"}
+ return {"text": "html contents"}
- with patch(
- "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}
- ):
+ with patch("align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}):
assert (
- HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
- == "html contents"
+ HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) == "html contents"
)
def test_html_articles_get_text_no_parser():
with patch("align_data.sources.articles.datasets.HTML_PARSERS", {}):
- assert (
- HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
- is None
- )
+ assert HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) is None
def test_html_articles_process_entry(articles):
@@ -156,7 +150,9 @@ def test_html_articles_process_entry(articles):
item = list(dataset.items_list)[0]
parsers = {
- "example.com": lambda _: {'text': ' html contents with proper elements ble ble '}
+ "example.com": lambda _: {
+ "text": ' html contents with proper elements ble ble '
+ }
}
with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers):
assert dataset.process_entry(item).to_dict() == {
@@ -170,7 +166,7 @@ def test_html_articles_process_entry(articles):
"text": "html contents with [proper elements](bla.com) ble ble",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -201,9 +197,7 @@ def test_ebook_articles_process_entry(articles):
contents = '<div> html contents with <a href="bla.com">proper elements</a> ble ble </div>'
with patch("align_data.sources.articles.datasets.download"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value=contents
- ):
+ with patch("align_data.sources.articles.datasets.convert_file", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"authors": ["John Snow", "mr Blobby"],
"date_published": "2023-01-01T12:32:11Z",
@@ -215,7 +209,7 @@ def test_ebook_articles_process_entry(articles):
"text": "html contents with [proper elements](bla.com) ble ble",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -248,7 +242,7 @@ def test_xml_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -281,19 +275,15 @@ def test_markdown_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
def test_doc_articles_get_text():
dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456")
with patch("align_data.sources.articles.datasets.fetch_file"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
- ):
- assert (
- dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
- )
+ with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"):
+ assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
def test_doc_articles_process_entry(articles):
@@ -302,9 +292,7 @@ def test_doc_articles_process_entry(articles):
item = list(dataset.items_list)[0]
with patch("align_data.sources.articles.datasets.fetch_file"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
- ):
+ with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"):
assert dataset.process_entry(item).to_dict() == {
"authors": ["John Snow", "mr Blobby"],
"date_published": "2023-01-01T12:32:11Z",
@@ -316,11 +304,11 @@ def test_doc_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
-@patch('requests.get', return_value=Mock(content=''))
+@patch("requests.get", return_value=Mock(content=""))
def test_arxiv_process_entry(_, mock_arxiv):
dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da")
item = Mock(
@@ -335,9 +323,7 @@ def test_arxiv_process_entry(_, mock_arxiv):
"authors": ["mr blobby"],
"source_type": "html",
}
- with patch(
- "align_data.sources.arxiv_papers.parse_vanity", return_value=contents
- ):
+ with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"comment": "no comment",
"authors": ["mr blobby"],
@@ -377,9 +363,9 @@ def test_arxiv_process_entry_retracted(mock_arxiv):
"""
- with patch('requests.get', return_value=Mock(content=response)):
+ with patch("requests.get", return_value=Mock(content=response)):
article = dataset.process_entry(item)
- assert article.status == 'Withdrawn'
+ assert article.status == "Withdrawn"
assert article.to_dict() == {
"comment": "no comment",
"authors": [],
@@ -407,7 +393,7 @@ def test_special_docs_process_entry():
authors="mr. blobby",
date_published="2023-10-02T01:23:45",
source_type=None,
- source_url="https://ble.ble.com"
+ source_url="https://ble.ble.com",
)
contents = {
"text": "this is the text",
@@ -418,20 +404,20 @@ def test_special_docs_process_entry():
with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['mr. blobby'],
- 'date_published': '2023-10-02T01:23:45Z',
- 'id': None,
- 'source': 'html',
- 'source_url': "https://ble.ble.com",
- 'source_type': 'html',
- 'summaries': [],
- 'text': 'this is the text',
- 'title': 'this is the title',
- 'url': 'https://bla.bla.bla',
+ "authors": ["mr. blobby"],
+ "date_published": "2023-10-02T01:23:45Z",
+ "id": None,
+ "source": "html",
+ "source_url": "https://ble.ble.com",
+ "source_type": "html",
+ "summaries": [],
+ "text": "this is the text",
+ "title": "this is the title",
+ "url": "https://bla.bla.bla",
}
-@patch('requests.get', return_value=Mock(content=''))
+@patch("requests.get", return_value=Mock(content=""))
def test_special_docs_process_entry_arxiv(_, mock_arxiv):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
item = Mock(
@@ -447,9 +433,7 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv):
"source_type": "pdf",
}
- with patch(
- "align_data.sources.arxiv_papers.parse_vanity", return_value=contents
- ):
+ with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"comment": "no comment",
"authors": ["mr blobby"],
@@ -469,16 +453,22 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv):
}
-@pytest.mark.parametrize('url, expected', (
- ("http://bla.bla", "http://bla.bla"),
- ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"),
-))
+@pytest.mark.parametrize(
+ "url, expected",
+ (
+ ("http://bla.bla", "http://bla.bla"),
+ ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"),
+ (
+ "https://arxiv.org/abs/math/2001.11038",
+ "https://arxiv.org/abs/math/2001.11038",
+ ),
+ ),
+)
def test_special_docs_not_processed_true(url, expected):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
dataset._outputted_items = [url, expected]
@@ -486,13 +476,15 @@ def test_special_docs_not_processed_true(url, expected):
assert not dataset.not_processed(Mock(url=None, source_url=url))
-@pytest.mark.parametrize('url', (
- "http://bla.bla"
- "http://arxiv.org/abs/2001.11038",
- "https://arxiv.org/abs/2001.11038",
- "https://arxiv.org/abs/2001.11038/",
- "https://arxiv.org/pdf/2001.11038",
-))
+@pytest.mark.parametrize(
+ "url",
+ (
+ "http://bla.bla" "http://arxiv.org/abs/2001.11038",
+ "https://arxiv.org/abs/2001.11038",
+ "https://arxiv.org/abs/2001.11038/",
+ "https://arxiv.org/pdf/2001.11038",
+ ),
+)
def test_special_docs_not_processed_false(url):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
dataset._outputted_items = []
diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py
index 39cacce3..bc814fe1 100644
--- a/tests/align_data/articles/test_google_cloud.py
+++ b/tests/align_data/articles/test_google_cloud.py
@@ -1,7 +1,12 @@
from unittest.mock import Mock, patch
import pytest
-from align_data.sources.articles.google_cloud import extract_gdrive_contents, get_content_type, google_doc, parse_grobid
+from align_data.sources.articles.google_cloud import (
+ extract_gdrive_contents,
+ get_content_type,
+ google_doc,
+ parse_grobid,
+)
SAMPLE_XML = """
@@ -45,8 +50,12 @@
def test_google_doc():
def fetcher(url, *args, **kwargs):
- assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html'
- return Mock(content="""
+ assert (
+ url
+ == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html"
+ )
+ return Mock(
+ content="""