diff --git a/.env.example b/.env.example index 057b5ba4..1c29520b 100644 --- a/.env.example +++ b/.env.example @@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" PINECONE_INDEX_NAME="stampy-chat-ard" PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" PINECONE_ENVIRONMENT="xx-xxxxx-gcp" +YOUTUBE_API_KEY="" \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..61a81a7b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer +- repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + language_version: python3.11 diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py index b8c5103a..0aed124d 100644 --- a/align_data/analysis/analyse_jsonl_data.py +++ b/align_data/analysis/analyse_jsonl_data.py @@ -68,9 +68,7 @@ def process_jsonl_files(data_dir): for id, duplicates in seen_urls.items(): if len(duplicates) > 1: - list_of_duplicates = "\n".join( - get_data_dict_str(duplicate) for duplicate in duplicates - ) + list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates) print( f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n" ) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 0bd8f62d..78753169 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -164,14 +164,9 @@ def _load_outputted_items(self) -> Set[str]: # This doesn't filter by self.name. The good thing about that is that it should handle a lot more # duplicates. The bad thing is that this could potentially return a massive amount of data if there # are lots of items. - return set( - session.scalars(select(getattr(Article, self.done_key))).all() - ) + return set(session.scalars(select(getattr(Article, self.done_key))).all()) # TODO: Properly handle this - it should create a proper SQL JSON select - return { - item.get(self.done_key) - for item in session.scalars(select(Article.meta)).all() - } + return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()} def not_processed(self, item): # NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. 
If this starts to @@ -214,7 +209,7 @@ def fetch_entries(self): if self.COOLDOWN: time.sleep(self.COOLDOWN) - def process_entry(self, entry) -> Optional[Article]: + def process_entry(self, entry) -> Article | None: """Process a single entry.""" raise NotImplementedError @@ -223,7 +218,7 @@ def _format_datetime(date) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") @staticmethod - def _get_published_date(date) -> Optional[datetime]: + def _get_published_date(date) -> datetime | None: try: # Totally ignore any timezone info, forcing everything to UTC return parse(str(date)).replace(tzinfo=pytz.UTC) @@ -239,7 +234,11 @@ def unprocessed_items(self, items=None) -> Iterable: urls = map(self.get_item_key, items) with make_session() as session: - articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls)) + articles = ( + session.query(Article) + .options(joinedload(Article.summaries)) + .filter(Article.url.in_(urls)) + ) self.articles = {a.url: a for a in articles if a.url} return items @@ -249,9 +248,7 @@ def _load_outputted_items(self) -> Set[str]: with make_session() as session: return set( session.scalars( - select(Article.url) - .join(Article.summaries) - .filter(Summary.source == self.name) + select(Article.url).join(Article.summaries).filter(Summary.source == self.name) ) ) diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index c90172a8..c6fbf797 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -75,7 +75,7 @@ def get_contents(self, article_url): def process_entry(self, article): article_url = self.get_item_key(article) contents = self.get_contents(article_url) - if not contents.get('text'): + if not contents.get("text"): return None return self.make_data_entry(contents) diff --git a/align_data/db/models.py b/align_data/db/models.py index c922c5b7..74817648 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -17,7 +17,6 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.ext.hybrid import hybrid_property -from align_data.settings import PINECONE_METADATA_KEYS logger = logging.getLogger(__name__) @@ -71,33 +70,26 @@ class Article(Base): def __repr__(self) -> str: return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})" - def is_metadata_keys_equal(self, other): - if not isinstance(other, Article): - raise TypeError( - f"Expected an instance of Article, got {type(other).__name__}" - ) - return not any( - getattr(self, key, None) - != getattr(other, key, None) # entry_id is implicitly ignored - for key in PINECONE_METADATA_KEYS - ) - def generate_id_string(self) -> bytes: - return "".join(str(getattr(self, field)) for field in self.__id_fields).encode( - "utf-8" - ) + return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") @property def __id_fields(self): - if self.source == 'aisafety.info': - return ['url'] - if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']: - return ['url', 'title', 'source'] + if self.source == "aisafety.info": + return ["url"] + if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]: + return ["url", "title", "source"] return ["url", "title"] @property def missing_fields(self): - fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 
'date_published'} + fields = set(self.__id_fields) | { + "text", + "title", + "url", + "source", + "date_published", + } return sorted([field for field in fields if not getattr(self, field, None)]) def verify_id(self): @@ -133,13 +125,21 @@ def add_meta(self, key, val): self.meta = {} self.meta[key] = val + def append_comment(self, comment: str): + """Appends a comment to the article.comments field. You must run session.commit() to save the comment to the database.""" + if self.comments is None: + self.comments = "" + self.comments = f"{self.comments}\n\n{comment}".strip() + @hybrid_property def is_valid(self): return ( - self.text and self.text.strip() and - self.url and self.title and - self.authors is not None and - self.status == OK_STATUS + self.text + and self.text.strip() + and self.url + and self.title + and self.authors is not None + and self.status == OK_STATUS ) @is_valid.expression @@ -157,7 +157,7 @@ def before_write(cls, _mapper, _connection, target): target.verify_id_fields() if not target.status and target.missing_fields: - target.status = 'Missing fields' + target.status = "Missing fields" target.comments = f'missing fields: {", ".join(target.missing_fields)}' if target.id: diff --git a/align_data/db/session.py b/align_data/db/session.py index 66c571ce..4aa23a87 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -10,24 +10,38 @@ logger = logging.getLogger(__name__) +# We create a single engine for the entire application +engine = create_engine(DB_CONNECTION_URI, echo=False) + @contextmanager def make_session(auto_commit=False): - engine = create_engine(DB_CONNECTION_URI, echo=False) - with Session(engine).no_autoflush as session: + with Session(engine, autoflush=False) as session: yield session if auto_commit: session.commit() -def stream_pinecone_updates(session, custom_sources: List[str]): +def stream_pinecone_updates( + session: Session, custom_sources: List[str], force_update: bool = False +): """Yield Pinecone entries that require an update.""" yield from ( - session - .query(Article) - .filter(Article.pinecone_update_required.is_(True)) + session.query(Article) + .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) .yield_per(1000) ) + + +def get_all_valid_article_ids(session: Session) -> List[str]: + """Return all valid article IDs.""" + query_result = ( + session.query(Article.id) + .filter(Article.is_valid) + .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + .all() + ) + return [item[0] for item in query_result] diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py new file mode 100644 index 00000000..6ef57b88 --- /dev/null +++ b/align_data/embeddings/embedding_utils.py @@ -0,0 +1,199 @@ +import logging +from typing import List, Tuple, Dict, Any, Optional +from functools import wraps + +import openai +from langchain.embeddings import HuggingFaceEmbeddings +from openai.error import ( + OpenAIError, + RateLimitError, + APIError, +) +from tenacity import ( + retry, + stop_after_attempt, + wait_random_exponential, + retry_if_exception_type, + retry_if_exception, +) + +from align_data.embeddings.pinecone.pinecone_models import MissingEmbeddingModelError +from align_data.settings import ( + USE_OPENAI_EMBEDDINGS, + OPENAI_EMBEDDINGS_MODEL, + EMBEDDING_LENGTH_BIAS, + SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + DEVICE, 
+) + + +# -------------------- +# CONSTANTS & CONFIGURATION +# -------------------- + +logger = logging.getLogger(__name__) + +hf_embedding_model = None +if not USE_OPENAI_EMBEDDINGS: + hf_embedding_model = HuggingFaceEmbeddings( + model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + model_kwargs={"device": DEVICE}, + encode_kwargs={"show_progress_bar": False}, + ) + +ModerationInfoType = Dict[str, Any] + + +# -------------------- +# DECORATORS +# -------------------- + + +def handle_openai_errors(func): + """Decorator to handle OpenAI-specific exceptions with retries.""" + + @wraps(func) + @retry( + wait=wait_random_exponential(multiplier=1, min=2, max=30), + stop=stop_after_attempt(6), + retry=retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIError) + | retry_if_exception(lambda e: "502" in str(e)), + ) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RateLimitError as e: + logger.warning(f"OpenAI Rate limit error. Trying again. Error: {e}") + raise + except APIError as e: + if "502" in str(e): + logger.warning(f"OpenAI 502 Bad Gateway error. Trying again. Error: {e}") + else: + logger.error(f"OpenAI API Error encountered: {e}") + raise + except OpenAIError as e: + logger.error(f"OpenAI Error encountered: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error encountered: {e}") + raise + + return wrapper + + +# -------------------- +# MAIN FUNCTIONS +# -------------------- + + +@handle_openai_errors +def moderation_check(texts: List[str]) -> List[ModerationInfoType]: + return openai.Moderation.create(input=texts)["results"] + + +@handle_openai_errors +def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]: + data = openai.Embedding.create(input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + return [d["embedding"] for d in data] + + +def get_embeddings_without_moderation( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> List[List[float]]: + """ + Obtain embeddings without moderation checks. + + Parameters: + - texts (List[str]): List of texts to be embedded. + - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. + + Returns: + - List[List[float]]: List of embeddings for the provided texts. + """ + if not texts: + return [] + + texts = [text.replace("\n", " ") for text in texts] + if USE_OPENAI_EMBEDDINGS: + embeddings = _compute_openai_embeddings(texts, **kwargs) + elif hf_embedding_model: + embeddings = hf_embedding_model.embed_documents(texts) + else: + raise MissingEmbeddingModelError("No embedding model available.") + + # Bias adjustment + if source and (bias := EMBEDDING_LENGTH_BIAS.get(source, 1.0)): + embeddings = [[bias * e for e in embedding] for embedding in embeddings] + + return embeddings + + +def get_embeddings_or_none_if_flagged( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> Tuple[List[List[float]] | None, List[ModerationInfoType]]: + """ + Obtain embeddings for the provided texts. If any text is flagged during moderation, + the function returns None for the embeddings while still providing the moderation results. + + Parameters: + - texts (List[str]): List of texts to be embedded. + - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. 
+ + Returns: + - Tuple[Optional[List[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results. + """ + moderation_results = moderation_check(texts) + if any(result["flagged"] for result in moderation_results): + return None, moderation_results + + embeddings = get_embeddings_without_moderation(texts, source, **kwargs) + return embeddings, moderation_results + + +def get_embeddings( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> Tuple[List[List[float] | None], List[ModerationInfoType]]: + """ + Obtain embeddings for the provided texts, replacing the embeddings of flagged texts with `None`. + + Parameters: + - texts (List[str]): List of texts to be embedded. + - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. + + Returns: + - Tuple[List[Optional[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results. + """ + assert len(texts) <= 2048, "The batch size should not be larger than 2048." + assert all(texts), "No empty strings allowed in the input list." + + # replace newlines, which can negatively affect performance + texts = [text.replace("\n", " ") for text in texts] + + # Check all texts for moderation flags + moderation_results = moderation_check(texts) + flags = [result["flagged"] for result in moderation_results] + + non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag] + non_flagged_embeddings = get_embeddings_without_moderation( + non_flagged_texts, source, **kwargs + ) + embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags] + return embeddings, moderation_results + + +def get_embedding( + text: str, source: Optional[str] = None, **kwargs +) -> Tuple[List[float] | None, ModerationInfoType]: + """Obtain an embedding for a single text.""" + embedding, moderation_result = get_embeddings([text], source, **kwargs) + return embedding[0], moderation_result[0] diff --git a/align_data/embeddings/finetuning/data/best_finetuned_model.pth b/align_data/embeddings/finetuning/data/best_finetuned_model.pth new file mode 100644 index 00000000..e05a5d52 Binary files /dev/null and b/align_data/embeddings/finetuning/data/best_finetuned_model.pth differ diff --git a/align_data/embeddings/finetuning/data/finetuned_model.pth b/align_data/embeddings/finetuning/data/finetuned_model.pth new file mode 100644 index 00000000..7ba44ac6 Binary files /dev/null and b/align_data/embeddings/finetuning/data/finetuned_model.pth differ diff --git a/align_data/embeddings/finetuning/finetuning_dataset.py b/align_data/embeddings/finetuning/finetuning_dataset.py new file mode 100644 index 00000000..8c5eec04 --- /dev/null +++ b/align_data/embeddings/finetuning/finetuning_dataset.py @@ -0,0 +1,105 @@ +import math +import random +from typing import List, Tuple, Generator +from collections import deque + +import torch +from torch.utils.data import IterableDataset, get_worker_info +from sqlalchemy.exc import OperationalError +from sqlalchemy.sql import func +from sqlalchemy.orm import Session + +from align_data.db.session import make_session, get_all_valid_article_ids +from align_data.embeddings.embedding_utils import get_embedding +from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB +from align_data.embeddings.text_splitter import 
ParagraphSentenceUnitTextSplitter +from align_data.embeddings.pinecone.update_pinecone import get_text_chunks +from align_data.db.models import Article + + +class FinetuningDataset(IterableDataset): + def __init__(self, num_batches_per_epoch: int, cache_size: int = 1280): + self.num_batches_per_epoch = num_batches_per_epoch + self.article_cache = deque(maxlen=cache_size) + + self.text_splitter = ParagraphSentenceUnitTextSplitter() + self.pinecone_db = PineconeDB() + + with make_session() as session: + self.all_article_ids = get_all_valid_article_ids(session) + self.total_articles = len(self.all_article_ids) + + def __len__(self): + return self.num_batches_per_epoch + + def __iter__(self): + start, end = 0, None + worker_info = get_worker_info() + if worker_info is not None: # Multi-process loading + per_worker = math.ceil(self.total_articles / worker_info.num_workers) + start = worker_info.id * per_worker + end = min(start + per_worker, self.total_articles) + + with make_session() as session: + return self._generate_pairs(session, start, end) + + def _fetch_random_articles(self, session: Session, batch_size: int = 1) -> List[Article]: + """Fetch a batch of random articles.""" + # If the list has fewer IDs than needed, raise an exception + random_selected_ids = random.sample(self.all_article_ids, batch_size) + return session.query(Article).filter(Article.id.in_(random_selected_ids)).all() + + def _get_random_chunks(self, article: Article, num_chunks: int = 2) -> List[Tuple[int, str]]: + chunked_text = get_text_chunks(article, self.text_splitter) + + chunks = list(enumerate(chunked_text)) + if len(chunks) < num_chunks: + return [] + + return random.sample(chunks, num_chunks) + + def _get_embeddings(self, article: Article, chunks: List[Tuple[int, str]]) -> List[List[float]]: + full_ids = [f"{article.id}_{str(idx).zfill(6)}" for idx, _ in chunks] + _embeddings = self.pinecone_db.get_embeddings_by_ids(full_ids) + + embeddings = [] + for (_, chunk), (_, embedding) in zip(chunks, _embeddings): + if embedding is None: + embedding, _ = get_embedding(chunk, article.source) + embeddings.append(torch.tensor(embedding)) + + return embeddings + + def _generate_pairs( + self, session, start=0, end=None, neg_pos_proportion=0.5 + ) -> Generator[Tuple[List[float], List[float], int], None, None]: + end = end or self.total_articles + + batches_yielded = 0 + while start < end: + start += 1 + if random.random() < neg_pos_proportion: + # Positive pairs + article = self._fetch_random_articles(session)[0] + chunks = self._get_random_chunks(article, 2) + if not chunks: + continue + embedding_1, embedding_2 = self._get_embeddings(article, chunks) + label = 1 + else: + # Negative pairs + article1, article2 = self._fetch_random_articles(session, batch_size=2) + chunk1 = self._get_random_chunks(article1, 1) + chunk2 = self._get_random_chunks(article2, 1) + embedding_1, embedding_2 = ( + self._get_embeddings(article1, chunk1)[0], + self._get_embeddings(article2, chunk2)[0], + ) + label = 0 + yield torch.tensor(embedding_1, dtype=torch.int64), torch.tensor( + embedding_2, dtype=torch.int64 + ), torch.tensor(label, dtype=torch.int64) + batches_yielded += 1 + + if self.num_batches_per_epoch and batches_yielded >= self.num_batches_per_epoch: + break diff --git a/align_data/embeddings/finetuning/training.py b/align_data/embeddings/finetuning/training.py new file mode 100644 index 00000000..cd1d7845 --- /dev/null +++ b/align_data/embeddings/finetuning/training.py @@ -0,0 +1,177 @@ +import os + +import torch +import torch.nn 
as nn +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader + +from align_data.embeddings.finetuning.finetuning_dataset import FinetuningDataset +from align_data.settings import ( + PINECONE_VALUES_DIMS, + DEVICE, + OPENAI_FINETUNED_LAYER_PATH, + OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, +) + + +class ContrastiveLoss(nn.Module): + def __init__(self, margin=2.0): + super(ContrastiveLoss, self).__init__() + self.margin = margin + + def forward(self, output1, output2, label): + euclidean_distance = nn.functional.pairwise_distance(output1, output2) + loss_contrastive = torch.mean( + (1 - label) * torch.pow(euclidean_distance, 2) + + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2) + ) + + return loss_contrastive + + +class NonLinearFineTuneModel(nn.Module): + def __init__(self, embedding_dim=PINECONE_VALUES_DIMS, hidden_dim=2000, dropout=0.5): + super(NonLinearFineTuneModel, self).__init__() + + self.fc1 = nn.Linear(embedding_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, embedding_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = nn.functional.relu(self.fc1(x)) + x = self.dropout(x) + x = self.fc2(x) + return x + + +class FineTuneModel(nn.Module): + def __init__(self, embedding_dim=PINECONE_VALUES_DIMS): + super(FineTuneModel, self).__init__() + + self.fc = nn.Linear(embedding_dim, embedding_dim) + + def forward(self, x): + x = self.fc(x) + return x + + +def train(model, dataloader, optimizer, criterion): + model.train() + total_loss = 0.0 + + for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader): + text1_embedding = text1_embedding.to(DEVICE) + text2_embedding = text2_embedding.to(DEVICE) + target = target.float().to(DEVICE) + + optimizer.zero_grad() + + output1 = model(text1_embedding) + output2 = model(text2_embedding) + + loss = criterion(output1, output2, target) + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate(model, dataloader, criterion): + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader): + text1_embedding = text1_embedding.to(DEVICE) + text2_embedding = text2_embedding.to(DEVICE) + target = target.float().to(DEVICE) + + output1 = model(text1_embedding) + output2 = model(text2_embedding) + + loss = criterion(output1, output2, target) + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def finetune_embeddings(): + # Hyperparameters & Configuration + EPOCHS = 100 + BATCH_PER_EPOCH = 20 + BATCH_SIZE = 64 + LEARNING_RATE = 5.0000e-02 + MARGIN = 2.0 + + dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=5) + + model = FineTuneModel().to(DEVICE) + model = load_best_model_if_exists(model) + optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) + scheduler = ReduceLROnPlateau(optimizer, "min", patience=2, factor=0.5, verbose=True) + criterion = ContrastiveLoss(MARGIN) + + # Assuming you've split your data and have a separate validation set + validation_dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH) + validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=5) + best_val_loss = validate(model, validation_dataloader, criterion) +
print(f"Initial validation loss (from loaded model or new model): {best_val_loss:.4f}") + + epochs_without_improvement = 0 + max_epochs_without_improvement = 15 # stop after 15 epochs without improvement + + for epoch in range(EPOCHS): + train_loss = train(model, dataloader, optimizer, criterion) + validate_loss = validate(model, validation_dataloader, criterion) + + scheduler.step(validate_loss) + if validate_loss < best_val_loss: + best_val_loss = validate_loss + torch.save(model.state_dict(), OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH) + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + + print( + f"Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}" + ) + + if epochs_without_improvement >= max_epochs_without_improvement: + print("Early stopping due to no improvement in validation loss.") + break + + torch.save(model.state_dict(), OPENAI_FINETUNED_LAYER_PATH) + + +### HELPER FUNCTIONS ### + + +def load_best_model_if_exists(model): + """ + Load the best saved model if it exists. + + Parameters: + - model (torch.nn.Module): The model architecture. + + Returns: + - model (torch.nn.Module): The loaded model. + """ + if os.path.exists(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH): + model.load_state_dict( + torch.load(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, map_location=DEVICE) + ) + print(f"Loaded model from {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}.") + else: + print( + f"No saved model found at {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}. Starting from scratch." + ) + + return model diff --git a/align_data/pinecone/__init__.py b/align_data/embeddings/pinecone/__init__.py similarity index 100% rename from align_data/pinecone/__init__.py rename to align_data/embeddings/pinecone/__init__.py diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py new file mode 100644 index 00000000..dd2990d3 --- /dev/null +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -0,0 +1,142 @@ +# dataset/pinecone_db_handler.py +import logging +from typing import List, Tuple + +import pinecone +from pinecone.core.client.models import ScoredVector + +from align_data.embeddings.embedding_utils import get_embedding +from align_data.embeddings.pinecone.pinecone_models import ( + PineconeEntry, + PineconeMetadata, +) +from align_data.settings import ( + PINECONE_INDEX_NAME, + PINECONE_VALUES_DIMS, + PINECONE_METRIC, + PINECONE_API_KEY, + PINECONE_ENVIRONMENT, + PINECONE_NAMESPACE, +) + + +logger = logging.getLogger(__name__) + + +class PineconeDB: + def __init__( + self, + index_name: str = PINECONE_INDEX_NAME, + values_dims: int = PINECONE_VALUES_DIMS, + metric: str = PINECONE_METRIC, + create_index: bool = False, + log_index_stats: bool = False, + ): + self.index_name = index_name + self.values_dims = values_dims + self.metric = metric + + pinecone.init( + api_key=PINECONE_API_KEY, + environment=PINECONE_ENVIRONMENT, + ) + + if create_index: + self.create_index() + + self.index = pinecone.Index(index_name=self.index_name) + + if log_index_stats: + index_stats_response = self.index.describe_index_stats() + logger.info(f"{self.index_name}:\n{index_stats_response}") + + def upsert_entry(self, pinecone_entry: PineconeEntry, upsert_size: int = 100): + vectors = pinecone_entry.create_pinecone_vectors() + self.index.upsert(vectors=vectors, batch_size=upsert_size, namespace=PINECONE_NAMESPACE) + + def query_vector( + self, + query: List[float], + top_k: int = 10, + include_values: bool
= False, + include_metadata: bool = True, + **kwargs, + ) -> List[ScoredVector]: + assert not isinstance( + query, str + ), "query must be a list of floats. Use query_PineconeDB_text for text queries" + + query_response = self.index.query( + vector=query, + top_k=top_k, + include_values=include_values, + include_metadata=include_metadata, + **kwargs, + namespace=PINECONE_NAMESPACE, + ) + + return [ + ScoredVector( + id=match["id"], + score=match["score"], + metadata=PineconeMetadata(**match["metadata"]), + ) + for match in query_response["matches"] + ] + + def query_text( + self, + query: str, + top_k: int = 10, + include_values: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> List[ScoredVector]: + query_vector = get_embedding(query)[0] + return self.query_vector( + query=query_vector, + top_k=top_k, + include_values=include_values, + include_metadata=include_metadata, + **kwargs, + ) + + def delete_entries(self, ids): + self.index.delete(filter={"hash_id": {"$in": ids}}) + + def create_index(self, replace_current_index: bool = True): + if replace_current_index: + self.delete_index() + + pinecone.create_index( + name=self.index_name, + dimension=self.values_dims, + metric=self.metric, + metadata_config={"indexed": list(PineconeMetadata.__annotations__.keys())}, + ) + + def delete_index(self): + if self.index_name in pinecone.list_indexes(): + logger.info(f"Deleting index '{self.index_name}'.") + pinecone.delete_index(self.index_name) + + def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, List[float] | None]]: + """ + Fetch embeddings for given entry IDs from Pinecone. + + Args: + - ids (List[str]): List of entry IDs for which embeddings are to be fetched. + + Returns: + - List[Tuple[str, List[float] | None]]: List of tuples containing ID and its corresponding embedding. + """ + # TODO: check that this still works + vectors = self.index.fetch( + ids=ids, + namespace=PINECONE_NAMESPACE, + )["vectors"] + return [(id, vectors.get(id, {}).get("values", None)) for id in ids] + + +def strip_block(text: str) -> str: + return "\n".join(text.split("\n")[1:]) diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py new file mode 100644 index 00000000..fd7b67eb --- /dev/null +++ b/align_data/embeddings/pinecone/pinecone_models.py @@ -0,0 +1,77 @@ +from typing import List, TypedDict + +from pydantic import BaseModel, validator +from pinecone.core.client.models import Vector + + +class MissingFieldsError(Exception): + pass + + +class MissingEmbeddingModelError(Exception): + pass + + +class PineconeMetadata(TypedDict): + hash_id: str + source: str + title: str + url: str + date_published: float + authors: List[str] + text: str + + +class PineconeEntry(BaseModel): + hash_id: str + source: str + title: str + url: str + date_published: float + authors: List[str] + text_chunks: List[str] + embeddings: List[List[float]] + + def __init__(self, **data): + """Check for missing (falsy) fields before initializing.""" + missing_fields = [field for field, value in data.items() if not str(value).strip()] + + if missing_fields: + raise MissingFieldsError(f"Missing fields: {missing_fields}") + + super().__init__(**data) + + def __repr__(self): + def make_small(chunk: str) -> str: + return (chunk[:45] + " [...] " + chunk[-45:]) if len(chunk) > 100 else chunk + + def display_chunks(chunks_lst: List[str]) -> str: + chunks = ", ".join(f'"{make_small(chunk)}"' for chunk in chunks_lst) + return ( + f"[{chunks[:450]} [...] 
{chunks[-450:]} ]" if len(chunks) > 1000 else f"[{chunks}]" + ) + + return f"PineconeEntry(hash_id={self.hash_id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={display_chunks(self.text_chunks)})" + + @property + def chunk_num(self) -> int: + return len(self.text_chunks) + + def create_pinecone_vectors(self) -> List[Vector]: + return [ + Vector( + id=f"{self.hash_id}_{str(i).zfill(6)}", + values=self.embeddings[i], + metadata=PineconeMetadata( + hash_id=self.hash_id, + source=self.source, + title=self.title, + authors=self.authors, + url=self.url, + date_published=self.date_published, + text=self.text_chunks[i], + ), + ) + for i in range(self.chunk_num) + if self.embeddings[i] # Skips flagged chunks + ] diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py new file mode 100644 index 00000000..b425ee9d --- /dev/null +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -0,0 +1,129 @@ +from datetime import datetime +import logging +from itertools import islice +from typing import Callable, List, Tuple, Generator, Iterator, Optional + +from sqlalchemy.orm import Session +from pydantic import ValidationError + +from align_data.embeddings.embedding_utils import get_embeddings +from align_data.db.models import Article +from align_data.db.session import make_session, stream_pinecone_updates +from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB +from align_data.embeddings.pinecone.pinecone_models import ( + PineconeEntry, MissingFieldsError, MissingEmbeddingModelError +) +from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter + + +logger = logging.getLogger(__name__) + + +# Define type aliases for the Callables +LengthFunctionType = Callable[[str], int] +TruncateFunctionType = Callable[[str, int], str] + + +class PineconeUpdater: + def __init__(self): + self.text_splitter = ParagraphSentenceUnitTextSplitter() + self.pinecone_db = PineconeDB() + + def update(self, custom_sources: List[str], force_update: bool = False): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + with make_session() as session: + articles_to_update_stream = stream_pinecone_updates( + session, custom_sources, force_update + ) + for batch in self.batch_entries(articles_to_update_stream): + self.save_batch(session, batch) + + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): + try: + for article, pinecone_entry in batch: + self.pinecone_db.upsert_entry(pinecone_entry) + + article.pinecone_update_required = False + session.add(article) + + session.commit() + + except Exception as e: + # Rollback on any kind of error. 
The next run will redo this batch, but in the meantime keep trucking + logger.error(e) + session.rollback() + + def batch_entries( + self, article_stream: Generator[Article, None, None] + ) -> Iterator[List[Tuple[Article, PineconeEntry]]]: + while batch := tuple(islice(article_stream, 10)): + yield [ + (article, pinecone_entry) + for article in batch + if (pinecone_entry := self._make_pinecone_entry(article)) is not None + ] + + def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: + try: + text_chunks = get_text_chunks(article, self.text_splitter) + embeddings, moderation_results = get_embeddings(text_chunks, article.source) + + if any(result['flagged'] for result in moderation_results): + flagged_text_chunks = [f"Chunk {i}: \"{text}\"" for i, (text, result) in enumerate(zip(text_chunks, moderation_results)) if result["flagged"]] + logger.warning(f"OpenAI moderation flagged text chunks for the following article: {article.id}") + article.append_comment(f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}") + + return PineconeEntry( + hash_id=article.id, # the hash_id of the article + source=article.source, + title=article.title, + url=article.url, + date_published=article.date_published.timestamp(), + authors=[author.strip() for author in article.authors.split(",") if author.strip()], + text_chunks=text_chunks, + embeddings=embeddings, + ) + except (ValueError, TypeError, AttributeError, ValidationError, MissingFieldsError, MissingEmbeddingModelError) as e: + logger.warning(e) + article.append_comment(f"Error encountered while processing this article: {e}") + return None + + except Exception as e: + logger.error(e) + raise + + +def get_text_chunks( + article: Article, text_splitter: ParagraphSentenceUnitTextSplitter +) -> List[str]: + title = article.title.replace("\n", " ") + + authors_lst = [author.strip() for author in article.authors.split(",")] + authors = get_authors_str(authors_lst) + + signature = f"Title: {title}; Author(s): {authors}." + text_chunks = text_splitter.split_text(article.text) + return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks] + + +def get_authors_str(authors_lst: List[str]) -> str: + if not authors_lst: + return "n/a" + + if len(authors_lst) == 1: + authors_str = authors_lst[0] + else: + authors_lst = authors_lst[:4] + authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" + + authors_str = authors_str.replace("\n", " ") + + # Truncate if necessary + if len(authors_str) > 500: + authors_str = authors_str[:497] + "..." 
+ + return authors_str diff --git a/align_data/pinecone/text_splitter.py b/align_data/embeddings/text_splitter.py similarity index 89% rename from align_data/pinecone/text_splitter.py rename to align_data/embeddings/text_splitter.py index b8af09a3..a364415e 100644 --- a/align_data/pinecone/text_splitter.py +++ b/align_data/embeddings/text_splitter.py @@ -1,5 +1,3 @@ -# dataset/text_splitter.py - from typing import List, Callable, Any from langchain.text_splitter import TextSplitter from nltk.tokenize import sent_tokenize @@ -11,8 +9,6 @@ StrToIntFunction = Callable[[str], int] StrIntBoolToStrFunction = Callable[[str, int, bool], str] -def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: - return string[-length:] if from_end else string[:length] def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: return string[-length:] if from_end else string[:length] @@ -26,7 +22,7 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter): @param length_function: A function that returns the length of a string in units. Defaults to len(). @param truncate_function: A function that truncates a string to a given unit length. """ - + DEFAULT_MIN_CHUNK_SIZE: int = 900 DEFAULT_MAX_CHUNK_SIZE: int = 1100 DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len @@ -38,7 +34,7 @@ def __init__( max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION, truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION, - **kwargs: Any + **kwargs: Any, ): super().__init__(**kwargs) self.min_chunk_size = min_chunk_size @@ -49,6 +45,9 @@ def __init__( def split_text(self, text: str) -> List[str]: """Split text into chunks of length between min_chunk_size and max_chunk_size.""" + if not text: + return [] + blocks: List[str] = [] current_block: str = "" @@ -90,13 +89,11 @@ def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragra def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str: while self._length_function(current_block) > self.max_chunk_size: # Truncate current_block to max size, set remaining text as current_block - truncated_block = self._truncate_function( - current_block, self.max_chunk_size - ) + truncated_block = self._truncate_function(current_block, self.max_chunk_size, False) blocks.append(truncated_block) - current_block = current_block[len(truncated_block):].lstrip() - + current_block = current_block[len(truncated_block) :].lstrip() + return current_block def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]: @@ -107,9 +104,7 @@ def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str if self.min_chunk_size - len_last_block > 0: # Add text from previous block to last block if last_block is too short part_prev_block = self._truncate_function( - string=blocks[-1], - length=self.min_chunk_size - len_last_block, - from_end=True + blocks[-1], self.min_chunk_size - len_last_block, True ) last_block = part_prev_block + last_block diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py deleted file mode 100644 index d8f565df..00000000 --- a/align_data/pinecone/pinecone_db_handler.py +++ /dev/null @@ -1,91 +0,0 @@ -# dataset/pinecone_db_handler.py - -import logging -from typing import Dict - -import pinecone - -from align_data.settings import ( - PINECONE_INDEX_NAME, - PINECONE_VALUES_DIMS, - PINECONE_METRIC, - PINECONE_METADATA_KEYS, - PINECONE_API_KEY, - 
PINECONE_ENVIRONMENT, -) - - -logger = logging.getLogger(__name__) - - -class PineconeDB: - def __init__( - self, - index_name: str = PINECONE_INDEX_NAME, - values_dims: int = PINECONE_VALUES_DIMS, - metric: str = PINECONE_METRIC, - metadata_keys: list = PINECONE_METADATA_KEYS, - create_index: bool = False, - log_index_stats: bool = True, - ): - self.index_name = index_name - self.values_dims = values_dims - self.metric = metric - self.metadata_keys = metadata_keys - - pinecone.init( - api_key=PINECONE_API_KEY, - environment=PINECONE_ENVIRONMENT, - ) - - if create_index: - self.create_index() - - self.index = pinecone.Index(index_name=self.index_name) - - if log_index_stats: - index_stats_response = self.index.describe_index_stats() - logger.info(f"{self.index_name}:\n{index_stats_response}") - - def upsert_entry(self, entry: Dict, upsert_size=100): - self.index.upsert( - vectors=list( - zip( - [ - f"{entry['id']}_{str(i).zfill(6)}" - for i in range(len(entry["text_chunks"])) - ], - entry["embeddings"].tolist(), - [ - { - "entry_id": entry["id"], - "source": entry["source"], - "title": entry["title"], - "authors": entry["authors"], - "text": text_chunk, - } - for text_chunk in entry["text_chunks"] - ], - ) - ), - batch_size=upsert_size, - ) - - def delete_entries(self, ids): - self.index.delete(filter={"entry_id": {"$in": ids}}) - - def create_index(self, replace_current_index: bool = True): - if replace_current_index: - self.delete_index() - - pinecone.create_index( - name=self.index_name, - dimension=self.values_dims, - metric=self.metric, - metadata_config={"indexed": self.metadata_keys}, - ) - - def delete_index(self): - if self.index_name in pinecone.list_indexes(): - logger.info(f"Deleting index '{self.index_name}'.") - pinecone.delete_index(self.index_name) diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py deleted file mode 100644 index 7e64e560..00000000 --- a/align_data/pinecone/update_pinecone.py +++ /dev/null @@ -1,193 +0,0 @@ -from datetime import datetime -import logging -import numpy as np -import os -from itertools import islice -from typing import Callable, List, Tuple, Generator - -import openai -from pydantic import BaseModel, ValidationError, validator - -from align_data.db.models import Article -from align_data.db.session import make_session, stream_pinecone_updates -from align_data.pinecone.pinecone_db_handler import PineconeDB -from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter -from align_data.settings import ( - USE_OPENAI_EMBEDDINGS, - OPENAI_EMBEDDINGS_MODEL, - OPENAI_EMBEDDINGS_DIMS, - OPENAI_EMBEDDINGS_RATE_LIMIT, - SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, - SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, - CHUNK_SIZE, - MAX_NUM_AUTHORS_IN_SIGNATURE, - EMBEDDING_LENGTH_BIAS, -) - - -logger = logging.getLogger(__name__) - - -# Define type aliases for the Callables -LengthFunctionType = Callable[[str], int] -TruncateFunctionType = Callable[[str, int], str] - - -class PineconeEntry(BaseModel): - id: str - source: str - title: str - url: str - date_published: datetime - authors: List[str] - text_chunks: List[str] - embeddings: np.ndarray - - class Config: - arbitrary_types_allowed = True - - def __repr__(self): - return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})" - - @validator( - "id", - "source", - "title", - "url", - "date_published", - "authors", - 
"text_chunks", - pre=True, - always=True, - ) - def empty_strings_not_allowed(cls, value): - if not str(value).strip(): - raise ValueError("Attribute should not be empty.") - return value - - -class PineconeUpdater: - def __init__( - self, - min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE, - max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE, - length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION, - truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION, - ): - self.min_chunk_size = min_chunk_size - self.max_chunk_size = max_chunk_size - self.length_function = length_function - self.truncate_function = truncate_function - - self.text_splitter = ParagraphSentenceUnitTextSplitter( - min_chunk_size=self.min_chunk_size, - max_chunk_size=self.max_chunk_size, - length_function=self.length_function, - truncate_function=self.truncate_function, - ) - self.pinecone_db = PineconeDB() - - if USE_OPENAI_EMBEDDINGS: - import openai - - openai.api_key = os.environ["OPENAI_API_KEY"] - else: - import torch - from langchain.embeddings import HuggingFaceEmbeddings - - self.hf_embeddings = HuggingFaceEmbeddings( - model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, - model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}, - encode_kwargs={"show_progress_bar": False}, - ) - - def save_batch(self, session, batch): - try: - for article, pinecone_entry in batch: - self.pinecone_db.upsert_entry(pinecone_entry.dict()) - article.pinecone_update_required = False - session.add(article) - session.commit() - except Exception as e: - # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking - logger.error(e) - session.rollback() - - def update(self, custom_sources: List[str]): - """ - Update the given sources. If no sources are provided, updates all sources. - - :param custom_sources: List of sources to update. 
- """ - with make_session() as session: - entries_stream = stream_pinecone_updates(session, custom_sources) - for batch in self.batch_entries(entries_stream): - self.save_batch(session, batch) - - def _make_pinecone_update(self, article): - try: - text_chunks = self.get_text_chunks(article) - return article, PineconeEntry( - id=article.id, - source=article.source, - title=article.title, - url=article.url, - date_published=article.date_published, - authors=[ - author.strip() - for author in article.authors.split(",") - if author.strip() - ], - text_chunks=text_chunks, - embeddings=self.extract_embeddings( - text_chunks, [article.source] * len(text_chunks) - ), - ) - except (ValueError, ValidationError) as e: - logger.exception(e) - - def batch_entries( - self, article_stream: Generator[Article, None, None] - ) -> Generator[List[Tuple[Article, PineconeEntry]], None, None]: - items = iter(article_stream) - while batch := tuple(islice(items, 10)): - yield list(filter(None, map(self._make_pinecone_update, batch))) - - def get_text_chunks(self, article: Article) -> List[str]: - signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}" - text_chunks = self.text_splitter.split_text(article.text) - text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks] - return text_chunks - - def extract_embeddings(self, chunks_batch, sources_batch): - if USE_OPENAI_EMBEDDINGS: - return self.get_openai_embeddings(chunks_batch, sources_batch) - else: - return np.array( - self.hf_embeddings.embed_documents(chunks_batch, sources_batch) - ) - - @staticmethod - def get_openai_embeddings(chunks, sources=""): - embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS)) - - openai_output = openai.Embedding.create( - model=OPENAI_EMBEDDINGS_MODEL, input=chunks - )["data"] - - for i, (embedding, source) in enumerate(zip(openai_output, sources)): - bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0) - embeddings[i] = bias * np.array(embedding["embedding"]) - - return embeddings - - @staticmethod - def get_authors_str(authors_lst: List[str]) -> str: - if authors_lst == []: - return "n/a" - if len(authors_lst) == 1: - return authors_lst[0] - else: - authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE] - authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" - return authors_str diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py index cb16b7a5..05e9dbde 100644 --- a/align_data/postprocess/postprocess.py +++ b/align_data/postprocess/postprocess.py @@ -30,13 +30,15 @@ def compute_statistics(self) -> None: for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)): with jsonlines.open(path) as reader: for obj in reader: - text = obj['text'] + text = obj["text"] source_stats = self.all_stats[source_name] source_stats["num_entries"] += 1 source_stats["num_tokens"] += len(text.split()) # TODO: Use tokenizer source_stats["num_chars"] += len(text) source_stats["num_words"] += len(text.split()) - source_stats["num_sentences"] += len(text.split(".")) # TODO: Use NLTK/Spacy or similar + source_stats["num_sentences"] += len( + text.split(".") + ) # TODO: Use NLTK/Spacy or similar source_stats["num_paragraphs"] += len(text.splitlines()) def plot_statistics(self) -> None: diff --git a/align_data/settings.py b/align_data/settings.py index fb8c8f45..2ae060e5 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -1,10 +1,13 @@ import os import logging +from typing import Dict +import openai +import torch from 
dotenv import load_dotenv load_dotenv() -LOG_LEVEL = os.environ.get('LOG_LEVEL', 'WARNING').upper() +LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() logging.basicConfig(level=LOG_LEVEL) ### CODA ### @@ -24,7 +27,7 @@ "METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4" ) -### YouTube ### +### YOUTUBE ### YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY") ### MYSQL ### @@ -37,13 +40,16 @@ ### EMBEDDINGS ### USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used. -EMBEDDING_LENGTH_BIAS = { - "aisafety.info": 1.05, # In search, favor AISafety.info entries. +EMBEDDING_LENGTH_BIAS: Dict[str, float] = { + # TODO: Experiment with these values. For now, let's remove the bias. + # "aisafety.info": 1.05, # In search, favor AISafety.info entries. } OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002" OPENAI_EMBEDDINGS_DIMS = 1536 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 +openai.api_key = os.environ.get("OPENAI_API_KEY", None) +openai.organization = os.environ.get("OPENAI_ORGANIZATION", None) SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768 @@ -53,15 +59,20 @@ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) PINECONE_VALUES_DIMS = ( - OPENAI_EMBEDDINGS_DIMS - if USE_OPENAI_EMBEDDINGS - else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS + OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS ) PINECONE_METRIC = "dotproduct" -PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"] +PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "normal") # "normal" or "finetuned" -### MISCELLANEOUS ### -CHUNK_SIZE = 1750 -MAX_NUM_AUTHORS_IN_SIGNATURE = 3 +### FINE-TUNING ### +OPENAI_FINETUNED_LAYER_PATH = os.environ.get( + "OPENAI_FINETUNED_LAYER_PATH", "align_data/finetuning/data/finetuned_model.pth" +) +OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH = os.environ.get( + "OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH", + "align_data/finetuning/data/best_finetuned_model.pth", +) +### MISCELLANEOUS ### MIN_CONFIDENCE = 50 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index 4f080409..6fd45fbc 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,6 +1,12 @@ from align_data.sources.articles.datasets import ( - ArxivPapers, EbookArticles, DocArticles, HTMLArticles, - MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles + ArxivPapers, + EbookArticles, + DocArticles, + HTMLArticles, + MarkdownArticles, + PDFArticles, + SpecialDocs, + XMLArticles, ) from align_data.sources.articles.indices import IndicesDataset from align_data.common.alignment_dataset import MultiDataset @@ -38,9 +44,9 @@ sheet_id="1293295703", ), SpecialDocs( - 'special_docs', - spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI', - sheet_id='980957638', + "special_docs", + spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", + sheet_id="980957638", ), ] @@ -52,5 +58,5 @@ spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", sheet_id="655836697", ), - IndicesDataset('indices'), + IndicesDataset("indices"), ] diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 39cc31b7..7db94a7b 100644 --- a/align_data/sources/articles/articles.py +++
b/align_data/sources/articles/articles.py @@ -105,8 +105,10 @@ def process_spreadsheets(source_sheet, output_sheets): row["source_url"] = row["url"] if row.get("source_url") in seen: logger.info(f'skipping "{title}", as it has already been seen') - elif row.get('status'): - logger.info(f'skipping "{title}", as it has a status set - remove it for this row to be processed') + elif row.get("status"): + logger.info( + f'skipping "{title}", as it has a status set - remove it for this row to be processed' + ) else: process_row(row, output_sheets) @@ -114,9 +116,7 @@ def process_spreadsheets(source_sheet, output_sheets): def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet): """Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one.""" source_sheet = get_sheet(source_spreadsheet, source_sheet) - sheets = { - sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets() - } + sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()} return process_spreadsheets(source_sheet, sheets) @@ -136,8 +136,7 @@ def check_new_articles(source_spreadsheet, source_sheet): missing = [ item for title, item in indices_items.items() - if title not in current - and not {item.get("url"), item.get("source_url")} & seen_urls + if title not in current and not {item.get("url"), item.get("source_url")} & seen_urls ] if not missing: logger.info("No new articles found") @@ -153,14 +152,12 @@ def check_new_articles(source_spreadsheet, source_sheet): "publication_title", "source_type", ] - res = source_sheet.append_rows( - [[item.get(col) for col in columns] for item in missing] - ) + res = source_sheet.append_rows([[item.get(col) for col in columns] for item in missing]) updated = res["updates"]["updatedRows"] logger.info("Added %s rows", updated) return updated def update_articles(csv_file, delimiter): - dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter) + dataset = ReplacerDataset(name="updater", csv_path=csv_file, delimiter=delimiter) dataset.add_entries(dataset.fetch_entries()) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index 5c031fcf..37ab0940 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -16,10 +16,16 @@ from align_data.db.models import Article from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown from align_data.sources.articles.parsers import ( - HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain + HTML_PARSERS, + extract_gdrive_contents, + item_metadata, + parse_domain, ) from align_data.sources.articles.pdf import read_pdf -from align_data.sources.arxiv_papers import fetch_arxiv, canonical_url as arxiv_cannonical_url +from align_data.sources.arxiv_papers import ( + fetch_arxiv, + canonical_url as arxiv_cannonical_url, +) logger = logging.getLogger(__name__) @@ -44,8 +50,8 @@ def get_item_key(self, item): @property def items_list(self): - url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}' - logger.info(f'Fetching {url}') + url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}" + logger.info(f"Fetching {url}") df = pd.read_csv(url) return (item for item in df.itertuples() if self.get_item_key(item)) @@ -85,7 +91,6 @@ def process_entry(self, item): class SpecialDocs(SpreadsheetDataset): - @property def 
_query_items(self): special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] @@ -96,35 +101,39 @@ def get_contents(self, item) -> Dict: if url := self.maybe(item, "source_url") or self.maybe(item, "url"): contents = item_metadata(url) - return dict(contents, **{ - 'url': self.maybe(item, "url"), - 'title': self.maybe(item, "title") or contents.get('title'), - 'source': contents.get('source_type') or self.name, - 'source_url': self.maybe(item, "source_url"), - 'source_type': contents.get('source_type') or self.maybe(item, "source_type"), - 'date_published': self._get_published_date(self.maybe(item, 'date_published')) or contents.get('date_published'), - 'authors': self.extract_authors(item) or contents.get('authors', []), - 'text': contents.get('text'), - 'status': 'Invalid' if contents.get('error') else None, - 'comments': contents.get('error'), - }) + return dict( + contents, + **{ + "url": self.maybe(item, "url"), + "title": self.maybe(item, "title") or contents.get("title"), + "source": contents.get("source_type") or self.name, + "source_url": self.maybe(item, "source_url"), + "source_type": contents.get("source_type") or self.maybe(item, "source_type"), + "date_published": self._get_published_date(self.maybe(item, "date_published")) + or contents.get("date_published"), + "authors": self.extract_authors(item) or contents.get("authors", []), + "text": contents.get("text"), + "status": "Invalid" if contents.get("error") else None, + "comments": contents.get("error"), + }, + ) def not_processed(self, item): - url = self.maybe(item, 'url') - source_url = self.maybe(item, 'source_url') + url = self.maybe(item, "url") + source_url = self.maybe(item, "source_url") return ( - self.get_item_key(item) not in self._outputted_items and - url not in self._outputted_items and - source_url not in self._outputted_items and - (not url or arxiv_cannonical_url(url) not in self._outputted_items) and - (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items) + self.get_item_key(item) not in self._outputted_items + and url not in self._outputted_items + and source_url not in self._outputted_items + and (not url or arxiv_cannonical_url(url) not in self._outputted_items) + and (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items) ) def process_entry(self, item): if ArxivPapers.is_arxiv(item.url): contents = ArxivPapers.get_contents(item) - contents['source'] = 'arxiv' + contents["source"] = "arxiv" else: contents = self.get_contents(item) @@ -154,7 +163,7 @@ def _get_text(item): domain = parse_domain(item.source_url) if parser := HTML_PARSERS.get(domain): res = parser(item.source_url) - return res and res.get('text') + return res and res.get("text") class EbookArticles(SpreadsheetDataset): @@ -168,9 +177,7 @@ def setup(self): def _get_text(self, item): file_id = item.source_url.split("/")[-2] - filename = download( - output=str(self.files_path / f"{item.title}.epub"), id=file_id - ) + filename = download(output=str(self.files_path / f"{item.title}.epub"), id=file_id) return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"]) @@ -221,12 +228,12 @@ def get_contents(cls, item) -> Dict: contents = fetch_arxiv(item.url or item.source_url) if cls.maybe(item, "authors") and item.authors.strip(): - contents['authors'] = [i.strip() for i in item.authors.split(',')] + contents["authors"] = [i.strip() for i in item.authors.split(",")] if cls.maybe(item, "title"): - contents['title'] = cls.maybe(item, "title") + contents["title"] = cls.maybe(item, 
"title") - contents['date_published'] = cls._get_published_date( - cls.maybe(item, "date_published") or contents.get('date_published') + contents["date_published"] = cls._get_published_date( + cls.maybe(item, "date_published") or contents.get("date_published") ) return contents diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index adb6fa61..ca310235 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -146,17 +146,17 @@ def fetch_markdown(file_id): "source_type": "markdown", } except Exception as e: - return {'error': str(e)} + return {"error": str(e)} def parse_grobid(contents): doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict() - authors = [xx["full_name"].strip(' !') for xx in doc_dict.get("header", {}).get("authors", [])] + authors = [xx["full_name"].strip(" !") for xx in doc_dict.get("header", {}).get("authors", [])] - if not doc_dict.get('body'): + if not doc_dict.get("body"): return { - 'error': 'No contents in XML file', - 'source_type': 'xml', + "error": "No contents in XML file", + "source_type": "xml", } return { @@ -169,67 +169,72 @@ def parse_grobid(contents): def get_content_type(res): - header = res.headers.get('Content-Type') or '' - parts = [c_type.strip().lower() for c_type in header.split(';')] + header = res.headers.get("Content-Type") or "" + parts = [c_type.strip().lower() for c_type in header.split(";")] return set(filter(None, parts)) def extract_gdrive_contents(link): - file_id = link.split('/')[-2] - url = f'https://drive.google.com/uc?id={file_id}' - res = fetch(url, 'head') + file_id = link.split("/")[-2] + url = f"https://drive.google.com/uc?id={file_id}" + res = fetch(url, "head") if res.status_code == 403: - logger.error('Could not fetch the file at %s - 403 returned', link) - return {'error': 'Could not read file from google drive - forbidden'} + logger.error("Could not fetch the file at %s - 403 returned", link) + return {"error": "Could not read file from google drive - forbidden"} if res.status_code >= 400: - logger.error('Could not fetch the file at %s - are you sure that link is correct?', link) - return {'error': 'Could not read file from google drive'} + logger.error("Could not fetch the file at %s - are you sure that link is correct?", link) + return {"error": "Could not read file from google drive"} result = { - 'source_url': link, - 'downloaded_from': 'google drive', + "source_url": link, + "downloaded_from": "google drive", } content_type = get_content_type(res) if not content_type: - result['error'] = 'no content type' - elif content_type & {'application/octet-stream', 'application/pdf'}: + result["error"] = "no content type" + elif content_type & {"application/octet-stream", "application/pdf"}: result.update(fetch_pdf(url)) - elif content_type & {'text/markdown'}: + elif content_type & {"text/markdown"}: result.update(fetch_markdown(file_id)) - elif content_type & {'application/epub+zip', 'application/epub'}: - result['source_type'] = 'ebook' - elif content_type & {'text/html'}: + elif content_type & {"application/epub+zip", "application/epub"}: + result["source_type"] = "ebook" + elif content_type & {"text/html"}: res = fetch(url) - if 'Google Drive - Virus scan warning' in res.text: + if "Google Drive - Virus scan warning" in res.text: soup = BeautifulSoup(res.content, "html.parser") - res = fetch(soup.select_one('form').get('action')) + res = fetch(soup.select_one("form").get("action")) content_type = 
get_content_type(res) - if content_type & {'text/xml'}: + if content_type & {"text/xml"}: result.update(parse_grobid(res.content)) - elif content_type & {'text/html'}: + elif content_type & {"text/html"}: soup = BeautifulSoup(res.content, "html.parser") - result.update({ - 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(), - 'source_type': 'html', - }) + result.update( + { + "text": MarkdownConverter().convert_soup(soup.select_one("body")).strip(), + "source_type": "html", + } + ) else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" return result def google_doc(url: str) -> Dict: """Fetch the contents of the given gdoc url as markdown.""" - res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url) + res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url) if not res: return {} doc_id = res.group(1) - body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body') + body = fetch_element(f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body") if body: - return {'text': MarkdownConverter().convert_soup(body).strip(), 'source_url': url} + return { + "text": MarkdownConverter().convert_soup(body).strip(), + "source_url": url, + } return {} diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py index 7df71860..d3c2490c 100644 --- a/align_data/sources/articles/html.py +++ b/align_data/sources/articles/html.py @@ -70,9 +70,9 @@ def getter(url): for e in elem.select(sel): e.extract() return { - 'text': MarkdownConverter().convert_soup(elem).strip(), - 'source_url': url, - 'source_type': 'html', + "text": MarkdownConverter().convert_soup(elem).strip(), + "source_url": url, + "source_type": "html", } return getter diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index 60075ef2..5e7d5a05 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -239,49 +239,58 @@ def fetch_all(): class IndicesDataset(AlignmentDataset): - - done_key = 'url' + done_key = "url" @property def items_list(self): return fetch_all().values() def get_item_key(self, item): - return item.get('url') + return item.get("url") @staticmethod def extract_authors(item): - if authors := (item.get('authors') or '').strip(): - return [author.strip() for author in authors.split(',') if author.strip()] + if authors := (item.get("authors") or "").strip(): + return [author.strip() for author in authors.split(",") if author.strip()] return [] def process_entry(self, item): contents = {} - url = item.get('source_url') or item.get('url') + url = item.get("source_url") or item.get("url") if url: - contents= item_metadata(url) - - if not contents.get('text'): - logger.error('Could not get text for %s (%s) - %s - skipping for now', item.get('title'), url, contents.get('error')) + contents = item_metadata(url) + + if not contents.get("text"): + logger.error( + "Could not get text for %s (%s) - %s - skipping for now", + item.get("title"), + url, + contents.get("error"), + ) return None # If the article is not an arxiv paper, just mark it as ignored - if in the future editors # decide it's worth adding, it can be fetched then - if parse_domain(url or '') != 'arxiv.org': - return self.make_data_entry({ - 'source': 
self.name, - 'url': self.get_item_key(item), - 'title': item.get('title'), - 'date_published': self._get_published_date(item.get('date_published')), - 'authors': self.extract_authors(item), - 'status': 'Ignored', - 'comments': 'Added from indices', - }) - - return self.make_data_entry({ - 'source': 'arxiv', - 'url': contents.get('url') or self.get_item_key(item), - 'title': item.get('title'), - 'date_published': self._get_published_date(item.get('date_published')), - 'authors': self.extract_authors(item), - }, **contents) + if parse_domain(url or "") != "arxiv.org": + return self.make_data_entry( + { + "source": self.name, + "url": self.get_item_key(item), + "title": item.get("title"), + "date_published": self._get_published_date(item.get("date_published")), + "authors": self.extract_authors(item), + "status": "Ignored", + "comments": "Added from indices", + } + ) + + return self.make_data_entry( + { + "source": "arxiv", + "url": contents.get("url") or self.get_item_key(item), + "title": item.get("title"), + "date_published": self._get_published_date(item.get("date_published")), + "authors": self.extract_authors(item), + }, + **contents, + ) diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index 36736c95..c210e565 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -24,27 +24,29 @@ def get_pdf_from_page(*link_selectors: str): :param List[str] link_selectors: CSS selector used to find the final download link :returns: the contents of the pdf file as a string """ + def getter(url: str): link: str = url for selector in link_selectors: elem = fetch_element(link, selector) if not elem: - return {'error': f'Could not find pdf download link for {link} using \'{selector}\''} + return {"error": f"Could not find pdf download link for {link} using '{selector}'"} - link = elem.get('href') - if not link.startswith('http') or not link.startswith('//'): + link = elem.get("href") + if not link.startswith("http") or not link.startswith("//"): link = urljoin(url, link) # Some pages keep link to google drive previews of pdf files, which need to be # mangled to get the URL of the actual pdf file - if 'drive.google.com' in link and '/view' in link: + if "drive.google.com" in link and "/view" in link: return extract_gdrive_contents(link) if parse_domain(link) == "arxiv.org": return fetch_arxiv(link) if pdf := fetch_pdf(link): return pdf - return {'error': f'Could not fetch pdf from {link}'} + return {"error": f"Could not fetch pdf from {link}"} + return getter @@ -77,10 +79,11 @@ def __call__(self, url): def error(error_msg): """Returns a url handler function that just logs the provided `error` string.""" + def func(url): if error_msg: logger.error(error_msg) - return {'error': error_msg, 'source_url': url} + return {"error": error_msg, "source_url": url} return func @@ -129,17 +132,17 @@ def getter(url): "mediangroup.org": element_extractor("div.entry-content"), "www.alexirpan.com": element_extractor("article"), "www.incompleteideas.net": element_extractor("body"), - "ai-alignment.com": MediumParser(name='html', url='ai-alignment.com'), + "ai-alignment.com": MediumParser(name="html", url="ai-alignment.com"), "aisrp.org": element_extractor("article"), "bounded-regret.ghost.io": element_extractor("div.post-content"), "carnegieendowment.org": element_extractor( "div.article-body", remove=[".no-print", ".related-pubs"] ), - "casparoesterheld.com": element_extractor( - ".entry-content", remove=["div.sharedaddy"] - ), + 
"casparoesterheld.com": element_extractor(".entry-content", remove=["div.sharedaddy"]), "cullenokeefe.com": element_extractor("div.sqs-block-content"), - "deepmindsafetyresearch.medium.com": MediumParser(name='html', url='deepmindsafetyresearch.medium.com'), + "deepmindsafetyresearch.medium.com": MediumParser( + name="html", url="deepmindsafetyresearch.medium.com" + ), "docs.google.com": google_doc, "docs.microsoft.com": element_extractor("div.content"), "digichina.stanford.edu": element_extractor("div.h_editor-content"), @@ -154,7 +157,7 @@ def getter(url): "link.springer.com": element_extractor("article.c-article-body"), "longtermrisk.org": element_extractor("div.entry-content"), "lukemuehlhauser.com": element_extractor("div.entry-content"), - "medium.com": MediumParser(name='html', url='medium.com'), + "medium.com": MediumParser(name="html", url="medium.com"), "openai.com": element_extractor("#content"), "ought.org": element_extractor("div.BlogPostBodyContainer"), "sideways-view.com": element_extractor("article", remove=["header"]), @@ -169,10 +172,8 @@ def getter(url): ), "theconversation.com": element_extractor("div.content-body"), "thegradient.pub": element_extractor("div.c-content"), - "towardsdatascience.com": MediumParser(name='html', url='towardsdatascience.com'), - "unstableontology.com": element_extractor( - ".entry-content", remove=["div.sharedaddy"] - ), + "towardsdatascience.com": MediumParser(name="html", url="towardsdatascience.com"), + "unstableontology.com": element_extractor(".entry-content", remove=["div.sharedaddy"]), "waitbutwhy.com": element_extractor("article", remove=[".entry-header"]), "weightagnostic.github.io": element_extractor( "dt-article", remove=["#authors_section", "dt-byline"] @@ -180,9 +181,7 @@ def getter(url): "cnas.org": element_extractor("#mainbar-toc"), "econlib.org": element_extractor("div.post-content"), "humanityplus.org": element_extractor("div.content"), - "gleech.org": element_extractor( - "article.post-content", remove=["center", "div.accordion"] - ), + "gleech.org": element_extractor("article.post-content", remove=["center", "div.accordion"]), "ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly... 
"microsoft.com": element_extractor("div.content-container"), "mdpi.com": element_extractor( @@ -259,32 +258,30 @@ def getter(url): "jstor.org": doi_getter, "ri.cmu.edu": get_pdf_from_page("a.pub-link"), "risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'), - "ssrn.com": get_pdf_from_page( - '.abstract-buttons a.button-link:-soup-contains("Download")' - ), + "ssrn.com": get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'), "yjolt.org": get_pdf_from_page("span.file a"), } def parse_domain(url: str) -> str: - return url and urlparse(url).netloc.lstrip('www.') + return url and urlparse(url).netloc.lstrip("www.") def item_metadata(url) -> Dict[str, any]: domain = parse_domain(url) try: - res = fetch(url, 'head') + res = fetch(url, "head") except (MissingSchema, InvalidSchema, ConnectionError) as e: - return {'error': str(e)} + return {"error": str(e)} - content_type = {item.strip() for item in res.headers.get('Content-Type', '').split(';')} + content_type = {item.strip() for item in res.headers.get("Content-Type", "").split(";")} if content_type & {"text/html", "text/xml"}: # If the url points to a html webpage, then it either contains the text as html, or # there is a link to a pdf on it if parser := HTML_PARSERS.get(domain): res = parser(url) - if res and 'error' not in res: + if res and "error" not in res: # Proper contents were found on the page, so use them return res @@ -296,13 +293,11 @@ def item_metadata(url) -> Dict[str, any]: if parser := UNIMPLEMENTED_PARSERS.get(domain): return parser(url) - if domain not in ( - HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys() - ): + if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()): return {"error": "No domain handler defined"} return {"error": "could not parse url"} elif content_type & {"application/octet-stream", "application/pdf"}: - if domain == 'arxiv.org': + if domain == "arxiv.org": return fetch_arxiv(url) # just download it as a pdf return fetch_pdf(url) diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py index b0e31951..43495142 100644 --- a/align_data/sources/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -54,9 +54,7 @@ def fetch_pdf(link): link, ) - content_type = { - c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";") - } + content_type = {c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")} if not content_type & {"application/octet-stream", "application/pdf"}: return { "error": f"Wrong content type retrieved: {content_type} - {link}", @@ -71,8 +69,8 @@ def fetch_pdf(link): "source_type": "pdf", } except (TypeError, PdfReadError) as e: - logger.error('Could not read PDF file: %s', e) - return {'error': str(e)} + logger.error("Could not read PDF file: %s", e) + return {"error": str(e)} filenames = [ i.strip().split("=")[1] @@ -96,11 +94,7 @@ def get_arxiv_link(doi): if res.status_code != 200: return None - vals = [ - val - for val in response.json().get("values") - if val.get("type", "").upper() == "URL" - ] + vals = [val for val in response.json().get("values") if val.get("type", "").upper() == "URL"] if not vals: return None @@ -135,7 +129,7 @@ def doi_getter(url): def parse_vanity(url) -> Dict[str, Any]: contents = fetch_element(url, "article") if not contents: - return {'error': 'Could not fetch from arxiv vanity'} + return {"error": "Could not fetch from arxiv vanity"} if title := contents.select_one("h1.ltx_title"): 
title = title.text diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py index 861d0af6..f453b5ae 100644 --- a/align_data/sources/articles/updater.py +++ b/align_data/sources/articles/updater.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -Item = namedtuple('Item', ['updates', 'article']) +Item = namedtuple("Item", ["updates", "article"]) @dataclass @@ -33,52 +33,50 @@ def maybe(item, key): def items_list(self): df = pd.read_csv(self.csv_path, delimiter=self.delimiter) self.csv_items = [ - item for item in df.itertuples() - if self.maybe(item, 'id') or self.maybe(item, 'hash_id') + item + for item in df.itertuples() + if self.maybe(item, "id") or self.maybe(item, "hash_id") ] - by_id = {i.id: i for i in self.csv_items if self.maybe(i, 'id')} - by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, 'hash_id')} + by_id = {i.id: i for i in self.csv_items if self.maybe(i, "id")} + by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, "hash_id")} - return [ - Item(by_id.get(a._id) or by_hash_id.get(a.id), a) - for a in self.read_entries() - ] + return [Item(by_id.get(a._id) or by_hash_id.get(a.id), a) for a in self.read_entries()] @property def _query_items(self): - ids = [i.id for i in self.csv_items if self.maybe(i, 'id')] - hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')] + ids = [i.id for i in self.csv_items if self.maybe(i, "id")] + hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")] return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids))) def update_text(self, updates, article): # If the url is the same as it was before, and there isn't a source url provided, assume that the # previous text is still valid - if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'): + if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"): return # If no url found, then don't bother fetching the text - assume it was successfully fetched previously - url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url') + url = self.maybe(updates, "source_url") or self.maybe(updates, "url") if not url: return if article.url != url: - article.add_meta('source_url', url) + article.add_meta("source_url", url) metadata = item_metadata(url) # Only change the text if it could be fetched - better to have outdated values than none - if metadata.get('text'): - article.text = metadata.get('text') - article.status = metadata.get('error') + if metadata.get("text"): + article.text = metadata.get("text") + article.status = metadata.get("error") def process_entry(self, item): updates, article = item - for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']: + for key in ["url", "title", "source", "authors", "comment", "confidence"]: value = self.maybe(updates, key) if value and getattr(article, key, None) != value: setattr(article, key, value) - if date := getattr(updates, 'date_published', None): + if date := getattr(updates, "date_published", None): article.date_published = self._get_published_date(date) self.update_text(updates, article) diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py index 2b98223f..2fb64377 100644 --- a/align_data/sources/arxiv_papers.py +++ b/align_data/sources/arxiv_papers.py @@ -28,37 +28,37 @@ def get_arxiv_metadata(paper_id) -> arxiv.Result: return None -def get_id(url: str) -> Optional[str]: +def get_id(url: str) -> str | None: if 
res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url): return res.group(1) def canonical_url(url: str) -> str: if paper_id := get_id(url): - return f'https://arxiv.org/abs/{paper_id}' + return f"https://arxiv.org/abs/{paper_id}" return url def get_contents(paper_id: str) -> Dict[str, Any]: arxiv_vanity = parse_vanity(f"https://www.arxiv-vanity.com/papers/{paper_id}") - if 'error' not in arxiv_vanity: + if "error" not in arxiv_vanity: return arxiv_vanity ar5iv = parse_vanity(f"https://ar5iv.org/abs/{paper_id}") - if 'error' not in ar5iv: + if "error" not in ar5iv: return ar5iv return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") -def get_version(id: str) -> Optional[str]: - if res := re.search(r'.*v(\d+)$', id): +def get_version(id: str) -> str | None: + if res := re.search(r".*v(\d+)$", id): return res.group(1) def is_withdrawn(url: str): - if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'): - return elem.text.strip().lower() == 'withdrawn' + if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"): + return elem.text.strip().lower() == "withdrawn" return None @@ -84,18 +84,21 @@ def add_metadata(data, paper_id): def fetch_arxiv(url) -> Dict: paper_id = get_id(url) if not paper_id: - return {'error': 'Could not extract arxiv id'} + return {"error": "Could not extract arxiv id"} if is_withdrawn(url): - paper = {'status': 'Withdrawn'} + paper = {"status": "Withdrawn"} else: paper = get_contents(paper_id) - data = add_metadata({ - "url": canonical_url(url), - "source_type": paper.get('data_source'), - }, paper_id) - authors = data.get('authors') or paper.get("authors") or [] - data['authors'] = [str(a).strip() for a in authors] + data = add_metadata( + { + "url": canonical_url(url), + "source_type": paper.get("data_source"), + }, + paper_id, + ) + authors = data.get("authors") or paper.get("authors") or [] + data["authors"] = [str(a).strip() for a in authors] return merge_dicts(data, paper) diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py index 3fee9ca2..05831f3e 100644 --- a/align_data/sources/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -26,9 +26,7 @@ url="https://deepmindsafetyresearch.medium.com/", authors=["DeepMind Safety Research"], ), - GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ), + GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]), ColdTakes( name="cold_takes", url="https://www.cold-takes.com/", diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index fcf97891..8c2a2b94 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -51,12 +51,7 @@ def _get_published_date(self, contents): return "" def extract_authors(self, article): - return ( - article.select_one("header .post-meta") - .text.split("·")[1] - .strip() - .split(", ") - ) + return article.select_one("header .post-meta").text.split("·")[1].strip().split(", ") class OpenAIResearch(HTMLDataset): @@ -82,9 +77,7 @@ def extract_authors(self, article): authors = [] if authors_div: authors = [ - i.split("(")[0].strip() - for i in authors_div.select_one("p").children - if not i.name + i.split("(")[0].strip() for i in authors_div.select_one("p").children if not i.name ] return authors or ["OpenAI Research"] @@ -114,9 +107,7 @@ def items_list(self): page += 1 # update the tqdm progress bar - pbar.set_postfix_str( - f"page {page}", refresh=True - ) # 
Set postfix to "page X" + pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 logger.info("Got %s pages", len(articles)) diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py index c0132301..cd409d98 100644 --- a/align_data/sources/blogs/wp_blog.py +++ b/align_data/sources/blogs/wp_blog.py @@ -42,9 +42,7 @@ def items_list(self): self.items[item["link"]] = item # update the tqdm progress bar - pbar.set_postfix_str( - f"page {page_number}", refresh=True - ) # Set postfix to "page X" + pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 logger.info(f"Got {len(self.items)} pages") diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py index 3a514a01..8709b154 100644 --- a/align_data/sources/distill/distill.py +++ b/align_data/sources/distill/distill.py @@ -6,7 +6,9 @@ class Distill(RSSDataset): done_key = "url" def extract_authors(self, item): - return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"] + return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or [ + "Distill" + ] def _get_text(self, item): article = item["soup"].find("d-article") or item["soup"].find("dt-article") diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py index 0055f5e0..7fdcd729 100644 --- a/align_data/sources/ebooks/__init__.py +++ b/align_data/sources/ebooks/__init__.py @@ -1,7 +1,5 @@ from .agentmodels import AgentModels EBOOK_REGISTRY = [ - AgentModels( - name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git" - ), + AgentModels(name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"), ] diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py index cfd68a79..65b52502 100644 --- a/align_data/sources/ebooks/agentmodels.py +++ b/align_data/sources/ebooks/agentmodels.py @@ -27,9 +27,7 @@ def setup(self): self.files_path = self.base_dir / "chapters" def _get_published_date(self, filename): - last_commit = next( - self.repository.iter_commits(paths=f"chapters/{filename.name}") - ) + last_commit = next(self.repository.iter_commits(paths=f"chapters/{filename.name}")) return last_commit.committed_datetime.astimezone(timezone.utc) def process_entry(self, filename): diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index deb6d4d1..5f32b68f 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -77,10 +77,7 @@ def setup(self): self.ai_tags = get_allowed_tags(self.base_url, self.name) def tags_ok(self, post): - return ( - not self.ai_tags - or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags - ) + return not self.ai_tags or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags def get_item_key(self, item): return item["pageUrl"] @@ -174,7 +171,7 @@ def process_entry(self, item): authors = item["coauthors"] if item["user"]: authors = [item["user"]] + authors - authors = [a["displayName"] for a in authors] or ['anonymous'] + authors = [a["displayName"] for a in authors] or ["anonymous"] return self.make_data_entry( { "title": item["title"], diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py index 88a7149e..49f57620 100644 
--- a/align_data/sources/stampy/stampy.py +++ b/align_data/sources/stampy/stampy.py @@ -45,13 +45,9 @@ def _get_published_date(self, entry): def process_entry(self, entry): def clean_text(text): text = html.unescape(text) - return re.sub( - r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text - ) + return re.sub(r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text) - question = clean_text( - entry["Question"] - ) # raise an error if the entry has no question + question = clean_text(entry["Question"]) # raise an error if the entry has no question answer = clean_text(entry["Rich Text"]) url = "https://aisafety.info?state=" + entry["UI ID"] diff --git a/main.py b/main.py index 67478a12..2c1d8da2 100644 --- a/main.py +++ b/main.py @@ -7,8 +7,13 @@ from align_data import ALL_DATASETS, get_dataset from align_data.analysis.count_tokens import count_token -from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles -from align_data.pinecone.update_pinecone import PineconeUpdater +from align_data.sources.articles.articles import ( + update_new_items, + check_new_articles, + update_articles, +) +from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater +from align_data.embeddings.finetuning.training import finetune_embeddings from align_data.settings import ( METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, @@ -76,12 +81,10 @@ def count_tokens(self, merged_dataset_path: str) -> None: This function counts the number of tokens, words, and characters in the dataset :return: None """ - assert os.path.exists( - merged_dataset_path - ), "The path to the merged dataset does not exist" + assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist" count_token(merged_dataset_path) - def update(self, csv_path, delimiter=','): + def update(self, csv_path, delimiter=","): """Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided. :param str csv_path: The path to the csv file to be processed @@ -115,7 +118,7 @@ def fetch_new_articles( """ return check_new_articles(source_spreadsheet, source_sheet) - def pinecone_update(self, *names) -> None: + def pinecone_update(self, *names, force_update=False) -> None: """ This function updates the Pinecone vector DB. @@ -125,14 +128,20 @@ def pinecone_update(self, *names) -> None: names = ALL_DATASETS missing = {name for name in names if name not in ALL_DATASETS} assert not missing, f"{missing} are not valid dataset names" - PineconeUpdater().update(names) + PineconeUpdater().update(names, force_update) - def pinecone_update_all(self, *skip) -> None: + def pinecone_update_all(self, *skip, force_update=False) -> None: """ This function updates the Pinecone vector DB. """ names = [name for name in ALL_DATASETS if name not in skip] - PineconeUpdater().update(names) + PineconeUpdater().update(names, force_update) + + def train_finetuning_layer(self) -> None: + """ + This function trains a finetuning layer on top of the OpenAI embeddings. 
+ """ + finetune_embeddings() if __name__ == "__main__": diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py index 7a8485fe..e5b9a303 100644 --- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py +++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py @@ -18,9 +18,7 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False) - ) + op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) # ### end Alembic commands ### diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py index 76c89ee0..d93a8c86 100644 --- a/migrations/versions/f5a2bcfa6b2c_add_status_column.py +++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py @@ -10,17 +10,17 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = 'f5a2bcfa6b2c' -down_revision = '59ac3cb671e3' +revision = "f5a2bcfa6b2c" +down_revision = "59ac3cb671e3" branch_labels = None depends_on = None def upgrade() -> None: - op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True)) - op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True)) + op.add_column("articles", sa.Column("status", sa.String(length=256), nullable=True)) + op.add_column("articles", sa.Column("comments", mysql.LONGTEXT(), nullable=True)) def downgrade() -> None: - op.drop_column('articles', 'comments') - op.drop_column('articles', 'status') + op.drop_column("articles", "comments") + op.drop_column("articles", "status") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..aa4949aa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 diff --git a/requirements.txt b/requirements.txt index b7985f3f..88ce2517 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,5 @@ openai langchain nltk pinecone-client + +torch diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index a0ea9fc1..5b73e05b 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -48,7 +48,7 @@ def mock_arxiv(): journal_ref="sdf", primary_category="cat", ) - metadata.get_short_id.return_value = '2001.11038' + metadata.get_short_id.return_value = "2001.11038" arxiv = Mock() arxiv.Search.return_value.results.return_value = iter([metadata]) @@ -124,30 +124,24 @@ def test_pdf_articles_process_item(articles): "text": "pdf contents [bla](asd.com)", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } def test_html_articles_get_text(): def parser(url): assert url == "http://example.org/bla.bla" - return {'text': "html contents"} + return {"text": "html contents"} - with patch( - "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser} - ): + with patch("align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}): assert ( - HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) - == "html contents" + HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) == "html contents" ) def test_html_articles_get_text_no_parser(): with 
patch("align_data.sources.articles.datasets.HTML_PARSERS", {}): - assert ( - HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) - is None - ) + assert HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) is None def test_html_articles_process_entry(articles): @@ -156,7 +150,9 @@ def test_html_articles_process_entry(articles): item = list(dataset.items_list)[0] parsers = { - "example.com": lambda _: {'text': ' html contents with proper elements ble ble '} + "example.com": lambda _: { + "text": ' html contents with proper elements ble ble ' + } } with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers): assert dataset.process_entry(item).to_dict() == { @@ -170,7 +166,7 @@ def test_html_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -201,9 +197,7 @@ def test_ebook_articles_process_entry(articles): contents = ' html contents with proper elements ble ble ' with patch("align_data.sources.articles.datasets.download"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value=contents - ): + with patch("align_data.sources.articles.datasets.convert_file", return_value=contents): assert dataset.process_entry(item).to_dict() == { "authors": ["John Snow", "mr Blobby"], "date_published": "2023-01-01T12:32:11Z", @@ -215,7 +209,7 @@ def test_ebook_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -248,7 +242,7 @@ def test_xml_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -281,19 +275,15 @@ def test_markdown_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } def test_doc_articles_get_text(): dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456") with patch("align_data.sources.articles.datasets.fetch_file"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value="bla bla" - ): - assert ( - dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" - ) + with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"): + assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" def test_doc_articles_process_entry(articles): @@ -302,9 +292,7 @@ def test_doc_articles_process_entry(articles): item = list(dataset.items_list)[0] with patch("align_data.sources.articles.datasets.fetch_file"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value="bla bla" - ): + with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"): assert dataset.process_entry(item).to_dict() == { "authors": ["John Snow", "mr Blobby"], "date_published": "2023-01-01T12:32:11Z", @@ -316,11 +304,11 @@ def test_doc_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 
'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } -@patch('requests.get', return_value=Mock(content='')) +@patch("requests.get", return_value=Mock(content="")) def test_arxiv_process_entry(_, mock_arxiv): dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") item = Mock( @@ -335,9 +323,7 @@ def test_arxiv_process_entry(_, mock_arxiv): "authors": ["mr blobby"], "source_type": "html", } - with patch( - "align_data.sources.arxiv_papers.parse_vanity", return_value=contents - ): + with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents): assert dataset.process_entry(item).to_dict() == { "comment": "no comment", "authors": ["mr blobby"], @@ -377,9 +363,9 @@ def test_arxiv_process_entry_retracted(mock_arxiv): """ - with patch('requests.get', return_value=Mock(content=response)): + with patch("requests.get", return_value=Mock(content=response)): article = dataset.process_entry(item) - assert article.status == 'Withdrawn' + assert article.status == "Withdrawn" assert article.to_dict() == { "comment": "no comment", "authors": [], @@ -407,7 +393,7 @@ def test_special_docs_process_entry(): authors="mr. blobby", date_published="2023-10-02T01:23:45", source_type=None, - source_url="https://ble.ble.com" + source_url="https://ble.ble.com", ) contents = { "text": "this is the text", @@ -418,20 +404,20 @@ def test_special_docs_process_entry(): with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents): assert dataset.process_entry(item).to_dict() == { - 'authors': ['mr. blobby'], - 'date_published': '2023-10-02T01:23:45Z', - 'id': None, - 'source': 'html', - 'source_url': "https://ble.ble.com", - 'source_type': 'html', - 'summaries': [], - 'text': 'this is the text', - 'title': 'this is the title', - 'url': 'https://bla.bla.bla', + "authors": ["mr. 
blobby"], + "date_published": "2023-10-02T01:23:45Z", + "id": None, + "source": "html", + "source_url": "https://ble.ble.com", + "source_type": "html", + "summaries": [], + "text": "this is the text", + "title": "this is the title", + "url": "https://bla.bla.bla", } -@patch('requests.get', return_value=Mock(content='')) +@patch("requests.get", return_value=Mock(content="")) def test_special_docs_process_entry_arxiv(_, mock_arxiv): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") item = Mock( @@ -447,9 +433,7 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv): "source_type": "pdf", } - with patch( - "align_data.sources.arxiv_papers.parse_vanity", return_value=contents - ): + with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents): assert dataset.process_entry(item).to_dict() == { "comment": "no comment", "authors": ["mr blobby"], @@ -469,16 +453,22 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv): } -@pytest.mark.parametrize('url, expected', ( - ("http://bla.bla", "http://bla.bla"), - ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"), -)) +@pytest.mark.parametrize( + "url, expected", + ( + ("http://bla.bla", "http://bla.bla"), + ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), + ( + "https://arxiv.org/abs/math/2001.11038", + "https://arxiv.org/abs/math/2001.11038", + ), + ), +) def test_special_docs_not_processed_true(url, expected): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") dataset._outputted_items = [url, expected] @@ -486,13 +476,15 @@ def test_special_docs_not_processed_true(url, expected): assert not dataset.not_processed(Mock(url=None, source_url=url)) -@pytest.mark.parametrize('url', ( - "http://bla.bla" - "http://arxiv.org/abs/2001.11038", - "https://arxiv.org/abs/2001.11038", - "https://arxiv.org/abs/2001.11038/", - "https://arxiv.org/pdf/2001.11038", -)) +@pytest.mark.parametrize( + "url", + ( + "http://bla.bla" "http://arxiv.org/abs/2001.11038", + "https://arxiv.org/abs/2001.11038", + "https://arxiv.org/abs/2001.11038/", + "https://arxiv.org/pdf/2001.11038", + ), +) def test_special_docs_not_processed_false(url): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") dataset._outputted_items = [] diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py index 39cacce3..bc814fe1 100644 --- a/tests/align_data/articles/test_google_cloud.py +++ b/tests/align_data/articles/test_google_cloud.py @@ -1,7 +1,12 @@ from unittest.mock import Mock, patch import pytest -from align_data.sources.articles.google_cloud import 
extract_gdrive_contents, get_content_type, google_doc, parse_grobid +from align_data.sources.articles.google_cloud import ( + extract_gdrive_contents, + get_content_type, + google_doc, + parse_grobid, +) SAMPLE_XML = """ @@ -45,8 +50,12 @@ def test_google_doc(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' - return Mock(content=""" + assert ( + url + == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) + return Mock( + content="""
bla bla bla
@@ -54,35 +63,45 @@ def fetcher(url, *args, **kwargs): - """) + """ + ) - with patch('requests.get', fetcher): - url = 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit' + with patch("requests.get", fetcher): + url = "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" assert google_doc(url) == { - 'text': 'ble ble [a link](bla.com)', - 'source_url': url + "text": "ble ble [a link](bla.com)", + "source_url": url, } def test_google_doc_no_body(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' + assert ( + url + == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) return Mock(content="
bla bla bla
") - with patch('requests.get', fetcher): - assert google_doc('https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit') == {} + with patch("requests.get", fetcher): + assert ( + google_doc( + "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" + ) + == {} + ) def test_google_doc_bad_url(): - assert google_doc('https://docs.google.com/bla/bla') == {} + assert google_doc("https://docs.google.com/bla/bla") == {} + def test_parse_grobid(): assert parse_grobid(SAMPLE_XML) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "text": "This is the contents", + "title": "The title!!", + "source_type": "xml", } @@ -104,74 +123,94 @@ def test_parse_grobid_no_body(): """ - assert parse_grobid(xml) == {'error': 'No contents in XML file', 'source_type': 'xml'} - + assert parse_grobid(xml) == { + "error": "No contents in XML file", + "source_type": "xml", + } -@pytest.mark.parametrize('header, expected', ( - (None, set()), - ('', set()), - ('text/html', {'text/html'}), - ('text/html; bla=asdas; fewwe=fe', {'text/html', 'bla=asdas', 'fewwe=fe'}), -)) +@pytest.mark.parametrize( + "header, expected", + ( + (None, set()), + ("", set()), + ("text/html", {"text/html"}), + ("text/html; bla=asdas; fewwe=fe", {"text/html", "bla=asdas", "fewwe=fe"}), + ), +) def test_get_content_type(header, expected): - assert get_content_type(Mock(headers={'Content-Type': header})) == expected - - -@pytest.mark.parametrize('headers', ( - {}, - {'Content-Type': None}, - {'Content-Type': ''}, - {'Content-Type': ' '}, - {'Content-Type': ' ; ;; '}, -)) + assert get_content_type(Mock(headers={"Content-Type": header})) == expected + + +@pytest.mark.parametrize( + "headers", + ( + {}, + {"Content-Type": None}, + {"Content-Type": ""}, + {"Content-Type": " "}, + {"Content-Type": " ; ;; "}, + ), +) def test_extract_gdrive_contents_no_contents(headers): - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers=headers, status_code=200)): + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=Mock(headers=headers, status_code=200)): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': 'no content type' + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "no content type", } -@pytest.mark.parametrize('header', ( - 'application/octet-stream', - 'application/pdf', - 'application/pdf; filename=bla.pdf' -)) +@pytest.mark.parametrize( + "header", + ( + "application/octet-stream", + "application/pdf", + "application/pdf; filename=bla.pdf", + ), +) def test_extract_gdrive_contents_pdf(header): - res = Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): - with patch('align_data.sources.articles.google_cloud.fetch_pdf', return_value={'text': 'bla'}): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = 
"https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): + with patch( + "align_data.sources.articles.google_cloud.fetch_pdf", + return_value={"text": "bla"}, + ): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla", } -@pytest.mark.parametrize('header', ( - 'application/epub', - 'application/epub+zip', - 'application/epub; filename=bla.epub', -)) +@pytest.mark.parametrize( + "header", + ( + "application/epub", + "application/epub+zip", + "application/epub; filename=bla.epub", + ), +) def test_extract_gdrive_contents_ebook(header): - res = Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'source_type': 'ebook', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "source_type": "ebook", } def test_extract_gdrive_contents_html(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): html = """
bleee
@@ -179,45 +218,48 @@ def test_extract_gdrive_contents_html(): """ res = Mock( - headers={'Content-Type': 'text/html'}, + headers={"Content-Type": "text/html"}, status_code=200, content=html, text=html, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla bla', - 'source_type': 'html', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla bla", + "source_type": "html", } def test_extract_gdrive_contents_xml(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): res = Mock( - headers={'Content-Type': 'text/xml'}, + headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML, text=SAMPLE_XML, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + "title": "The title!!", + "source_type": "xml", } def test_extract_gdrive_contents_xml_with_confirm(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -228,27 +270,35 @@ def fetcher(link, *args, **kwargs):
""" - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents - return Mock(headers={'Content-Type': 'text/xml'}, status_code=200, content=SAMPLE_XML) + return Mock(headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + "title": "The title!!", + "source_type": "xml", } def test_extract_gdrive_contents_warning_with_unknown(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -259,26 +309,34 @@ def fetcher(link, *args, **kwargs):
""" - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents, with an unhandled content type - return Mock(headers={'Content-Type': 'text/bla bla'}, status_code=200) + return Mock(headers={"Content-Type": "text/bla bla"}, status_code=200) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'error': "unknown content type: {'text/bla bla'}", - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', + "downloaded_from": "google drive", + "error": "unknown content type: {'text/bla bla'}", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", } def test_extract_gdrive_contents_unknown_content_type(): - res = Mock(headers={'Content-Type': 'bla bla'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": "bla bla"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': "unknown content type: {'bla bla'}", + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "unknown content type: {'bla bla'}", } diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py index 8bac313f..5a174e3f 100644 --- a/tests/align_data/articles/test_parsers.py +++ b/tests/align_data/articles/test_parsers.py @@ -42,6 +42,7 @@ """ + def test_medium_blog(): html = """
@@ -60,14 +61,14 @@ def test_medium_blog():
""" with patch("requests.get", return_value=Mock(content=html)): - assert MediumParser('html', 'ble')("bla.com") == { - 'authors': [], - 'date_published': parse('Oct 7, 2023').replace(tzinfo=pytz.UTC), - 'source': 'html', - 'source_type': 'blog', - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'This is the title', - 'url': 'bla.com', + assert MediumParser("html", "ble")("bla.com") == { + "authors": [], + "date_published": parse("Oct 7, 2023").replace(tzinfo=pytz.UTC), + "source": "html", + "source_type": "blog", + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "This is the title", + "url": "bla.com", } @@ -83,14 +84,14 @@ def test_medium_blog_no_title(): """ with patch("requests.get", return_value=Mock(content=html)): - assert MediumParser(name='html', url='')("bla.com") == { - 'authors': [], - 'date_published': None, - 'source': 'html', - 'source_type': 'blog', - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': None, - 'url': 'bla.com', + assert MediumParser(name="html", url="")("bla.com") == { + "authors": [], + "date_published": None, + "source": "html", + "source_type": "blog", + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": None, + "url": "bla.com", } @@ -105,13 +106,13 @@ def test_medium_blog_no_contents(): """ - with patch('requests.get', return_value=Mock(content=html)): - assert MediumParser(name='html', url='')('bla.com') == { - 'authors': [], - 'date_published': None, - 'source': 'html', - 'source_type': 'blog', - 'text': None, - 'title': None, - 'url': 'bla.com', + with patch("requests.get", return_value=Mock(content=html)): + assert MediumParser(name="html", url="")("bla.com") == { + "authors": [], + "date_published": None, + "source": "html", + "source_type": "blog", + "text": None, + "title": None, + "url": "bla.com", } diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py index 7d11fbb7..f9e2aea2 100644 --- a/tests/align_data/articles/test_updater.py +++ b/tests/align_data/articles/test_updater.py @@ -9,39 +9,43 @@ SAMPLE_UPDATES = [ {}, - {'title': 'no id - should be ignored'}, - - {'id': '122', 'hash_id': 'deadbeef000'}, + {"title": "no id - should be ignored"}, + {"id": "122", "hash_id": "deadbeef000"}, + { + "id": "123", + "hash_id": "deadbeef001", + "title": "bla bla", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby, johnny", + }, + { + "id": "124", + "title": "no hash id", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby", + }, { - 'id': '123', 'hash_id': 'deadbeef001', - 'title': 'bla bla', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby, johnny', - }, { - 'id': '124', - 'title': 'no hash id', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby', - }, { - 'hash_id': 'deadbeef002', - 'title': 'no id', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby', - }, { - 'id': '125', - 'title': 'no hash id, url or title', - 'authors': 'mr. blobby', - } + "hash_id": "deadbeef002", + "title": "no id", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby", + }, + { + "id": "125", + "title": "no hash id, url or title", + "authors": "mr. 
blobby", + }, ] + @pytest.fixture def csv_file(tmp_path): - filename = tmp_path / 'data.csv' - with open(filename, 'w', newline='') as csvfile: - fieldnames = ['id', 'hash_id', 'title', 'url', 'source_url', 'authors'] + filename = tmp_path / "data.csv" + with open(filename, "w", newline="") as csvfile: + fieldnames = ["id", "hash_id", "title", "url", "source_url", "authors"] writer = DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() @@ -51,152 +55,195 @@ def csv_file(tmp_path): def test_items_list(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") def mock_entries(): return [ Mock( - _id=dataset.maybe(v, 'id'), - id=dataset.maybe(v, 'hash_id'), - title=dataset.maybe(v, 'title'), - url=dataset.maybe(v, 'url'), - authors=dataset.maybe(v, 'authors') + _id=dataset.maybe(v, "id"), + id=dataset.maybe(v, "hash_id"), + title=dataset.maybe(v, "title"), + url=dataset.maybe(v, "url"), + authors=dataset.maybe(v, "authors"), ) for v in dataset.csv_items ] - with patch.object(dataset, 'read_entries', mock_entries): + with patch.object(dataset, "read_entries", mock_entries): items = dataset.items_list - assert len(items) == 5, "items_list should only contain items with valid ids - something is wrong" + assert ( + len(items) == 5 + ), "items_list should only contain items with valid ids - something is wrong" for item in items: - assert dataset.maybe(item.updates, 'id') == item.article._id - assert dataset.maybe(item.updates, 'hash_id') == item.article.id - assert dataset.maybe(item.updates, 'title') == item.article.title - assert dataset.maybe(item.updates, 'url') == item.article.url - assert dataset.maybe(item.updates, 'authors') == item.article.authors - - -@pytest.mark.parametrize('updates', ( - Mock(url='http://some.other.url'), - Mock(source_url='http://some.other.url'), - Mock(url='http://some.other.url', source_url='http://another.url'), -)) + assert dataset.maybe(item.updates, "id") == item.article._id + assert dataset.maybe(item.updates, "hash_id") == item.article.id + assert dataset.maybe(item.updates, "title") == item.article.title + assert dataset.maybe(item.updates, "url") == item.article.url + assert dataset.maybe(item.updates, "authors") == item.article.authors + + +@pytest.mark.parametrize( + "updates", + ( + Mock(url="http://some.other.url"), + Mock(source_url="http://some.other.url"), + Mock(url="http://some.other.url", source_url="http://another.url"), + ), +) def test_update_text(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should be changed', status='as should this', url='http:/bla.bla.com') + article = Mock(text="this should be changed", status="as should this", url="http:/bla.bla.com") - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): dataset.update_text(updates, article) - assert article.text == 'bla bla bla' + assert article.text == "bla bla bla" assert article.status == None -@pytest.mark.parametrize('updates', ( - Mock(url='http://some.other.url'), - Mock(source_url='http://some.other.url'), - Mock(url='http://some.other.url', source_url='http://another.url'), -)) +@pytest.mark.parametrize( + "updates", + ( + Mock(url="http://some.other.url"), + 
Mock(source_url="http://some.other.url"), + Mock(url="http://some.other.url", source_url="http://another.url"), + ), +) def test_update_text_error(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should not be changed', status='but this should be', url='http:/bla.bla.com') + article = Mock( + text="this should not be changed", + status="but this should be", + url="http:/bla.bla.com", + ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'error': 'oh noes!'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"error": "oh noes!"}, + ): dataset.update_text(updates, article) - assert article.text == 'this should not be changed' - assert article.status == 'oh noes!' - - -@pytest.mark.parametrize('updates', ( - Mock(url='http://bla.bla.com', source_url=None, comment='Same url as article, no source_url'), - Mock(url='http://bla.bla.com', source_url='', comment='Same url as article, empty source_url'), - Mock(url=None, source_url=None, comment='no urls provided'), -)) + assert article.text == "this should not be changed" + assert article.status == "oh noes!" + + +@pytest.mark.parametrize( + "updates", + ( + Mock( + url="http://bla.bla.com", + source_url=None, + comment="Same url as article, no source_url", + ), + Mock( + url="http://bla.bla.com", + source_url="", + comment="Same url as article, empty source_url", + ), + Mock(url=None, source_url=None, comment="no urls provided"), + ), +) def test_update_text_no_update(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should not be changed', status='as should not this', url='http://bla.bla.com') + article = Mock( + text="this should not be changed", + status="as should not this", + url="http://bla.bla.com", + ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): dataset.update_text(updates, article) - assert article.text == 'this should not be changed' - assert article.status == 'as should not this' + assert article.text == "this should not be changed" + assert article.status == "as should not this" def test_process_entry(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") article = Article( - _id=123, id='deadbeef0123', - title='this should be changed', - url='this should be changed', - text='this should be changed', - authors='this should be changed', - date_published='this should be changed', + _id=123, + id="deadbeef0123", + title="this should be changed", + url="this should be changed", + text="this should be changed", + authors="this should be changed", + date_published="this should be changed", ) updates = Mock( - id='123', - hash_id='deadbeef001', - title='bla bla', - url='http://bla.com', - source_url='http://bla.bla.com', - source='tests', - authors='mr. blobby, johnny', - date_published='2000-12-23T10:32:43Z', + id="123", + hash_id="deadbeef001", + title="bla bla", + url="http://bla.com", + source_url="http://bla.bla.com", + source="tests", + authors="mr. 
blobby, johnny", + date_published="2000-12-23T10:32:43Z", ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): assert dataset.process_entry(Item(updates, article)).to_dict() == { - 'authors': ['mr. blobby', 'johnny'], - 'date_published': '2000-12-23T10:32:43Z', - 'id': 'd8d8cad8d28739a0862654a0e6e8ce6e', - 'source': 'tests', - 'source_type': None, - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'bla bla', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', + "authors": ["mr. blobby", "johnny"], + "date_published": "2000-12-23T10:32:43Z", + "id": "d8d8cad8d28739a0862654a0e6e8ce6e", + "source": "tests", + "source_type": None, + "summaries": [], + "text": "bla bla bla", + "title": "bla bla", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", } def test_process_entry_empty(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") article = Article( - _id=123, id='deadbeef0123', - title='this should not be changed', - url='this should not be changed', - source='this should not be changed', - authors='this should not be changed', - - text='this should be changed', - date_published='this should be changed', + _id=123, + id="deadbeef0123", + title="this should not be changed", + url="this should not be changed", + source="this should not be changed", + authors="this should not be changed", + text="this should be changed", + date_published="this should be changed", ) updates = Mock( - id='123', - hash_id='deadbeef001', + id="123", + hash_id="deadbeef001", title=None, - url='', - source_url='http://bla.bla.com', - source=' ', - authors=' \n \n \t \t ', - date_published='2000-12-23T10:32:43Z', + url="", + source_url="http://bla.bla.com", + source=" ", + authors=" \n \n \t \t ", + date_published="2000-12-23T10:32:43Z", ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): assert dataset.process_entry(Item(updates, article)).to_dict() == { - 'authors': ['this should not be changed'], - 'date_published': '2000-12-23T10:32:43Z', - 'id': '606e9224254f508d297bcb17bcc6d104', - 'source': 'this should not be changed', - 'source_type': None, - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'this should not be changed', - 'url': 'this should not be changed', - 'source_url': 'http://bla.bla.com', + "authors": ["this should not be changed"], + "date_published": "2000-12-23T10:32:43Z", + "id": "606e9224254f508d297bcb17bcc6d104", + "source": "this should not be changed", + "source_type": None, + "summaries": [], + "text": "bla bla bla", + "title": "this should not be changed", + "url": "this should not be changed", + "source_url": "http://bla.bla.com", } diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index e45c62eb..d18aaf78 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -75,41 +75,68 @@ def test_data_entry_id_from_urls_and_title(): ) -@pytest.mark.parametrize('item, error', ( - ( - {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines", "url": "asd"}, - 'missing fields: date_published, 
source, text' - ), - ( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "text": "asdasd", "title": "asdasd"}, - 'missing fields: date_published, source' - ), - ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", - "source": "dwe", "date_published": "dwe" - }, - 'missing fields: text' - ), - ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", - "text": "asdasd", "date_published": "dwe" - }, - 'missing fields: source' - ), +@pytest.mark.parametrize( + "item, error", ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", "text": "asdasd", "source": "dwe" - }, - 'missing fields: date_published' + ( + { + "key1": 12, + "key2": 312, + "title": "wikipedia goes to war on porcupines", + "url": "asd", + }, + "missing fields: date_published, source, text", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "text": "asdasd", + "title": "asdasd", + }, + "missing fields: date_published, source", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "source": "dwe", + "date_published": "dwe", + }, + "missing fields: text", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "text": "asdasd", + "date_published": "dwe", + }, + "missing fields: source", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "text": "asdasd", + "source": "dwe", + }, + "missing fields: date_published", + ), ), -)) +) def test_data_entry_missing(item, error): dataset = AlignmentDataset(name="blaa") entry = dataset.make_data_entry(item) Article.before_write(None, None, entry) - assert entry.status == 'Missing fields' + assert entry.status == "Missing fields" assert entry.comments == error @@ -136,7 +163,7 @@ def test_data_entry_verify_id_fails(): "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", } ) - expected = 'Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97' + expected = "Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97" with pytest.raises(AssertionError, match=expected): entry.verify_id() @@ -172,7 +199,9 @@ def test_data_entry_verify_fields_fails(data, error): def test_data_entry_id_fields(): dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time", 'title': 'bla'}) + entry = dataset.make_data_entry( + {"url": "https://www.google.ca/once_upon_a_time", "title": "bla"} + ) Article.before_write(None, None, entry) assert entry.id @@ -246,16 +275,11 @@ def test_unprocessed_items_some_done(numbers_dataset): def test_fetch_entries(numbers_dataset): - assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [ - i**2 for i in range(10) - ] + assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)] def test_format_datatime(dataset): - assert ( - dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) - == "2022-01-01T12:23:43Z" - ) + assert dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) == "2022-01-01T12:23:43Z" def test_format_datatime_ignore_timezone(dataset): diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py index 25e84b8b..3efbddb0 100644 --- a/tests/align_data/common/test_html_dataset.py +++ b/tests/align_data/common/test_html_dataset.py @@ -91,16 +91,12 @@ def test_html_dataset_items_list(html_dataset): def 
test_html_dataset_fetch_contents(html_dataset):
     with patch("requests.get", return_value=Mock(content=SAMPLE_HTML)):
-        assert html_dataset.fetch_contents("url") == BeautifulSoup(
-            SAMPLE_HTML, "html.parser"
-        )
+        assert html_dataset.fetch_contents("url") == BeautifulSoup(SAMPLE_HTML, "html.parser")


 def test_html_dataset_get_text(html_dataset):
     soup = BeautifulSoup(f"
{SAMPLE_CONTENTS}
", "html.parser") - assert ( - html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" def test_html_dataset_find_date(html_dataset): @@ -125,10 +121,7 @@ def test_html_dataset_find_date(html_dataset): ), ) def test_html_dataset_extract_metadata(html_dataset, text): - assert ( - html_dataset._extract_markdown(text) - == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert html_dataset._extract_markdown(text) == "bla bla bla [a link](http://ble.com) bla bla" def test_html_dataset_process_entry(html_dataset): @@ -176,9 +169,7 @@ def test_html_dataset_process_entry_no_text(html_dataset): ), ) def test_rss_dataset_extract_authors(item, authors): - dataset = RSSDataset( - name="bla", url="http://example.org", authors=["default author"] - ) + dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"]) assert dataset.extract_authors(item) == authors @@ -202,9 +193,7 @@ def test_rss_dataset_get_title(): ), ) def test_rss_dataset_get_published_date(item, date): - dataset = RSSDataset( - name="bla", url="http://example.org", authors=["default author"] - ) + dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"]) assert dataset._get_published_date(item) == date @@ -263,6 +252,4 @@ def test_rss_dataset_items_list(): } with patch("feedparser.parse", return_value=contents): - assert dataset.items_list == [ - f"http://example.org/article-{i}" for i in range(5) - ] + assert dataset.items_list == [f"http://example.org/article-{i}" for i in range(5)] diff --git a/align_data/common/utils.py b/tests/align_data/embeddings/test_embedding_utils.py similarity index 100% rename from align_data/common/utils.py rename to tests/align_data/embeddings/test_embedding_utils.py diff --git a/tests/align_data/embeddings/test_pinecone_db_handler.py b/tests/align_data/embeddings/test_pinecone_db_handler.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/sources/test_alignment_newsletter.py similarity index 98% rename from tests/align_data/test_alignment_newsletter.py rename to tests/align_data/sources/test_alignment_newsletter.py index ffd5e31b..da8fa043 100644 --- a/tests/align_data/test_alignment_newsletter.py +++ b/tests/align_data/sources/test_alignment_newsletter.py @@ -43,9 +43,7 @@ def test_process_entry_no_summary(dataset): def test_format_datatime(dataset): - assert dataset._get_published_date(2022) == datetime( - 2022, 1, 1, tzinfo=timezone.utc - ) + assert dataset._get_published_date(2022) == datetime(2022, 1, 1, tzinfo=timezone.utc) def test_process_entry(dataset): diff --git a/tests/align_data/test_arbital.py b/tests/align_data/sources/test_arbital.py similarity index 96% rename from tests/align_data/test_arbital.py rename to tests/align_data/sources/test_arbital.py index 304c5398..af65ed05 100644 --- a/tests/align_data/test_arbital.py +++ b/tests/align_data/sources/test_arbital.py @@ -127,9 +127,7 @@ def post(url, *args, **kwargs): page = json.loads(kwargs.get("data", "{}")).get("pageAlias") if "json/explore" in url: - response.json.return_value = { - "pages": {f"{page}-{i}": i for i in range(10)} - } + response.json.return_value = {"pages": {f"{page}-{i}": i for i in range(10)}} elif "json/primaryPage" in url: response.json.return_value = { "pages": { @@ -201,9 +199,7 @@ def test_extract_authors_ignore_missing(dataset): page = {"changeLogs": [{"userId": 
author} for author in authors]} with patch.object(dataset, "get_title", lambda author: author): - assert sorted(dataset.extract_authors(page)) == sorted( - ["John Snow", "mr. blobby"] - ) + assert sorted(dataset.extract_authors(page)) == sorted(["John Snow", "mr. blobby"]) @pytest.mark.parametrize( diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/sources/test_arxiv.py similarity index 100% rename from tests/align_data/test_arxiv.py rename to tests/align_data/sources/test_arxiv.py diff --git a/tests/align_data/test_blogs.py b/tests/align_data/sources/test_blogs.py similarity index 93% rename from tests/align_data/test_blogs.py rename to tests/align_data/sources/test_blogs.py index 23789b2e..27ffefa3 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/sources/test_blogs.py @@ -162,10 +162,7 @@ def test_caradomoe_text(): """ soup = BeautifulSoup(contents, "html.parser") - assert ( - dataset._get_text({"soup": soup}) - == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert dataset._get_text({"soup": soup}) == "bla bla bla [a link](http://ble.com) bla bla" def test_caradomoe_process_entry(): @@ -230,9 +227,7 @@ def test_caradomoe_process_entry(): def test_gwern_get_text(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) soup = BeautifulSoup(GWERN_CONTENTS, "html.parser") assert dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" @@ -252,17 +247,13 @@ def test_gwern_get_text(): ), ) def test_gwern_get_published_date(metadata, date): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) assert dataset._get_published_date(metadata) == date def test_gwern_get_article(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch("requests.get", return_value="article contents"): assert dataset._get_article("http://bla.com") == "article contents" @@ -303,13 +294,9 @@ def test_gwern_process_markdown(): ... {SAMPLE_HTML} """ - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) - assert dataset._process_markdown( - "http://article.url", Mock(text=text) - ).to_dict() == { + assert dataset._process_markdown("http://article.url", Mock(text=text)).to_dict() == { "authors": ["Gwern Branwen"], "date_published": "2020-05-28T00:00:00Z", "id": None, @@ -330,13 +317,9 @@ def test_gwern_process_entry_markdown(): ... 
{SAMPLE_HTML} """ - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) - with patch( - "requests.get", return_value=Mock(text=text, status_code=200, headers={}) - ): + with patch("requests.get", return_value=Mock(text=text, status_code=200, headers={})): assert dataset.process_entry("http://article.url").to_dict() == { "authors": ["Gwern Branwen"], "date_published": "2020-05-28T00:00:00Z", @@ -351,9 +334,7 @@ def test_gwern_process_entry_markdown(): def test_gwern_process_entry_html(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch( "requests.get", @@ -377,9 +358,7 @@ def test_gwern_process_entry_html(): def test_gwern_process_entry_erro(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch("requests.get", return_value=Mock(status_code=404)): assert dataset.process_entry("http://article.url") is None @@ -491,9 +470,7 @@ def test_substack_blog_process_entry(): "title": "Eliezer S. Yudkowsky", "link": "https://www.yudkowsky.net", }, - "headers": { - "link": '; rel="https://api.w.org/"' - }, + "headers": {"link": '; rel="https://api.w.org/"'}, } @@ -509,9 +486,7 @@ def test_wordpress_blog_setup(): @patch("feedparser.parse", return_value=WORDPRESS_FEED) def test_wordpress_blog_items_list(feedparser_parse): blog = WordpressBlog(name="blog", url="https://www.bla.yudkowsky.net") - assert blog.items_list == [ - "https://www.yudkowsky.net/other/fiction/prospiracy-theory" - ] + assert blog.items_list == ["https://www.yudkowsky.net/other/fiction/prospiracy-theory"] def test_wordpress_blog_get_item_key(): @@ -528,9 +503,7 @@ def test_wordpress_blog_get_published_date(): name="blog", url="https://www.bla.yudkowsky.net", ) - date_published = blog._get_published_date( - {"published": "Mon, 26 Jun 2023 13:40:01 +0000"} - ) + date_published = blog._get_published_date({"published": "Mon, 26 Jun 2023 13:40:01 +0000"}) assert date_published == parse("2023-06-26T13:40:01Z") @@ -541,9 +514,7 @@ def test_wordpress_blog_process_entry(feedparser_parse): url="https://www.bla.yudkowsky.net", ) blog.items = {i["link"]: i for i in WORDPRESS_FEED["entries"]} - entry = blog.process_entry( - "https://www.yudkowsky.net/other/fiction/prospiracy-theory" - ) + entry = blog.process_entry("https://www.yudkowsky.net/other/fiction/prospiracy-theory") assert entry.to_dict() == { "authors": ["Eliezer S. 
Yudkowsky"], "date_published": "2020-09-04T04:11:23Z", @@ -643,10 +614,8 @@ def test_openai_research_get_text(): dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}} - with patch( - "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) - ): + parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}} + with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})): with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers): assert dataset._get_text(soup) == "bla bla bla" @@ -697,10 +666,8 @@ def test_openai_research_process_entry(): dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}} - with patch( - "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) - ): + parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}} + with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})): with patch("requests.get", return_value=Mock(content=OPENAI_HTML)): with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers): assert dataset.process_entry(soup).to_dict() == { diff --git a/tests/align_data/test_distill.py b/tests/align_data/sources/test_distill.py similarity index 98% rename from tests/align_data/test_distill.py rename to tests/align_data/sources/test_distill.py index 6ced02df..ac1b7f34 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/sources/test_distill.py @@ -76,9 +76,7 @@ def test_extra_values(): """ soup = BeautifulSoup(contents, "html.parser") - assert dataset._extra_values( - {"soup": soup, "summary": "A wild summary has appeared!"} - ) == { + assert dataset._extra_values({"soup": soup, "summary": "A wild summary has appeared!"}) == { "bibliography": [ { "link": "https://doi.org/10.23915/distill.00033", diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/sources/test_greater_wrong.py similarity index 97% rename from tests/align_data/test_greater_wrong.py rename to tests/align_data/sources/test_greater_wrong.py index 29140794..b8a9e73d 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/sources/test_greater_wrong.py @@ -84,9 +84,7 @@ def test_greaterwrong_get_item_key(dataset): def test_greaterwrong_get_published_date(dataset): - assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse( - "2021-02-01T00:00:00Z" - ) + assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse("2021-02-01T00:00:00Z") def test_greaterwrong_get_published_date_missing(dataset): @@ -152,9 +150,7 @@ def fetcher(next_date): ] return {"results": results} - mock_items = ( - i for i in [Mock(date_published=datetime.fromisoformat("2014-12-12T01:23:45"))] - ) + mock_items = (i for i in [Mock(date_published=datetime.fromisoformat("2014-12-12T01:23:45"))]) with patch.object(dataset, "fetch_posts", fetcher): with patch.object(dataset, "make_query", lambda next_date: next_date): with patch.object(dataset, "read_entries", return_value=mock_items): diff --git a/tests/align_data/test_stampy.py b/tests/align_data/sources/test_stampy.py similarity index 83% rename from tests/align_data/test_stampy.py rename to tests/align_data/sources/test_stampy.py index 5d4500b5..9b40a2c3 100644 --- a/tests/align_data/test_stampy.py +++ b/tests/align_data/sources/test_stampy.py @@ -14,17 +14,14 @@ def 
test_validate_coda_token(): def test_get_item_key(): dataset = Stampy(name="bla") - assert ( - dataset.get_item_key({"Question": "Why not just?"}) - == "Why\nnot just?" - ) + assert dataset.get_item_key({"Question": "Why not just?"}) == "Why\nnot just?" def test_get_published_date(): dataset = Stampy(name="bla") - assert dataset._get_published_date( - {"Doc Last Edited": "2012/01/03 12:23:32"} - ) == parse("2012-01-03T12:23:32Z") + assert dataset._get_published_date({"Doc Last Edited": "2012/01/03 12:23:32"}) == parse( + "2012-01-03T12:23:32Z" + ) def test_get_published_date_missing(): diff --git a/tests/align_data/test_youtube.py b/tests/align_data/sources/test_youtube.py similarity index 93% rename from tests/align_data/test_youtube.py rename to tests/align_data/sources/test_youtube.py index bcb720e8..70bbed94 100644 --- a/tests/align_data/test_youtube.py +++ b/tests/align_data/sources/test_youtube.py @@ -46,9 +46,7 @@ def test_next_page_empty_by_default(): }, { "kind": "youtube#playlistItem", - "snippet": { - "resourceId": {"kind": "youtube#video", "videoId": "your_video_id"} - }, + "snippet": {"resourceId": {"kind": "youtube#video", "videoId": "your_video_id"}}, }, ), ) @@ -72,9 +70,7 @@ def test_get_id_with_id(item): }, { "kind": "youtube#playlistItem", - "snippet": { - "resourceId": {"kind": "invalid_kind", "videoId": "your_video_id"} - }, + "snippet": {"resourceId": {"kind": "invalid_kind", "videoId": "your_video_id"}}, }, ), ) @@ -187,8 +183,7 @@ def test_items_list(): def fetcher(collection_id): return [ - {"id": {"kind": "youtube#video", "videoId": f"{collection_id}_{i}"}} - for i in range(3) + {"id": {"kind": "youtube#video", "videoId": f"{collection_id}_{i}"}} for i in range(3) ] with patch.object(dataset, "fetch_videos", fetcher): @@ -208,9 +203,7 @@ def test_get_item_key(): "id": {"kind": "youtube#video", "videoId": "your_video_id"}, "kind": "youtube#searchResult", } - assert ( - dataset.get_item_key(video) == "https://www.youtube.com/watch?v=your_video_id" - ) + assert dataset.get_item_key(video) == "https://www.youtube.com/watch?v=your_video_id" @pytest.mark.parametrize( @@ -229,9 +222,7 @@ def test_get_contents_with_no_transcript_found(error): } transcriber = Mock() - transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = ( - error - ) + transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = error with patch("align_data.sources.youtube.youtube.YouTubeTranscriptApi", transcriber): assert dataset._get_contents(video) is None @@ -336,16 +327,12 @@ def test_channel_process_item(transcriber): def test_playlist_collection_ids(): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) assert dataset.collection_ids == ["a list id", "another id"] def test_playlist_published_date(): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) video = { "kind": "youtube#playlistItem", "snippet": { @@ -359,9 +346,7 @@ def test_playlist_published_date(): def test_channel_process_item(transcriber): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) video = { "kind": "youtube#playlistItem", "snippet": { diff --git 
a/tests/print_date_published.py b/tests/print_date_published.py index d3ba46fd..32bda616 100644 --- a/tests/print_date_published.py +++ b/tests/print_date_published.py @@ -20,9 +20,7 @@ def validate_date_format(file_path, keys_to_print): # Try to parse the date_published string into a datetime object parse(date_published) except ValueError: - print( - f"Row {i}: date_published is NOT in a valid format: {date_published}" - ) + print(f"Row {i}: date_published is NOT in a valid format: {date_published}") for key in keys_to_print: print(f" {key}: {entry.get(key)}") diff --git a/upload_to_huggingface.py b/upload_to_huggingface.py index 9ccc74ed..9e7481f8 100644 --- a/upload_to_huggingface.py +++ b/upload_to_huggingface.py @@ -10,41 +10,39 @@ from huggingface_hub import HfApi -GDOCS_FOLDER = ( - "https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8" -) +GDOCS_FOLDER = "https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8" DATASOURCES = [ - 'agentmodels', - 'aiimpacts', - 'aisafety.camp', - 'aisafety.info', - 'ai_alignment_playlist', - 'ai_explained', - 'ai_safety_talks', - 'ai_safety_reading_group', - 'ai_tech_tu_delft', - 'alignmentforum', - 'arbital', - 'arxiv', - 'carado.moe', - 'cold_takes', - 'deepmind_blog', - 'deepmind_technical_blog', - 'distill', - 'eaforum', - 'eleuther.ai', - 'generative.ink', - 'gwern_blog', - 'importai', - 'jsteinhardt_blog', - 'lesswrong', - 'miri', - 'ml_safety_newsletter', - 'openai.research', - 'rob_miles_ai_safety', - 'special_docs', - 'vkrakovna_blog', - 'yudkowsky_blog' + "agentmodels", + "aiimpacts", + "aisafety.camp", + "aisafety.info", + "ai_alignment_playlist", + "ai_explained", + "ai_safety_talks", + "ai_safety_reading_group", + "ai_tech_tu_delft", + "alignmentforum", + "arbital", + "arxiv", + "carado.moe", + "cold_takes", + "deepmind_blog", + "deepmind_technical_blog", + "distill", + "eaforum", + "eleuther.ai", + "generative.ink", + "gwern_blog", + "importai", + "jsteinhardt_blog", + "lesswrong", + "miri", + "ml_safety_newsletter", + "openai.research", + "rob_miles_ai_safety", + "special_docs", + "vkrakovna_blog", + "yudkowsky_blog", ] @@ -70,11 +68,7 @@ def get_gdoc_names(url): return None _, id_name_type_iter = _parse_google_drive_file(url=url, content=res.text) - return [ - (id, name) - for id, name, filetype in id_name_type_iter - if name.endswith(".jsonl") - ] + return [(id, name) for id, name, filetype in id_name_type_iter if name.endswith(".jsonl")] def upload_data_file(api, name, repo_name): @@ -84,7 +78,7 @@ def upload_data_file(api, name, repo_name): # Don't download it if it exists locally if not filename.exists(): - print(f'{filename} not found!') + print(f"{filename} not found!") return try: @@ -99,9 +93,7 @@ def upload_data_file(api, name, repo_name): def download_file(repo_name, filename, api): headers = {"Authorization": f"Bearer {api.token}"} - url = ( - f"https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}" - ) + url = f"https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}" response = requests.get(url, headers=headers) if response.status_code == 200: