Article checker
mruwnik committed Oct 16, 2023
1 parent 57db939 commit 36ad8cb
Showing 8 changed files with 390 additions and 59 deletions.
65 changes: 65 additions & 0 deletions .github/workflows/check-articles.yml
@@ -0,0 +1,65 @@
name: Check articles are valid

on:
  workflow_call:
    inputs:
      datasource:
        type: string
        required: true
  workflow_dispatch: # allow manual triggering
    inputs:
      datasource:
        description: 'The datasource to process'
        type: choice
        options:
          - all
          - agentmodels
          - agisf
          - aisafety.info
          - alignment_newsletter
          - alignmentforum
          - arbital
          - arxiv
          - blogs
          - distill
          - eaforum
          - indices
          - lesswrong
          - special_docs
          - youtube
  schedule:
    - cron: "0 */4 * * *" # Every 4 hours

jobs:
  build-dataset:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Setup Python environment
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'

      - name: Install Pandoc
        run: |
          if [ "${{ inputs.datasource }}" = "gdocs" ]; then
            sudo apt-get update
            sudo apt-get -y install pandoc
          fi
      - name: Install dependencies
        run: pip install -r requirements.txt

      - name: Process dataset
        env:
          CODA_TOKEN: ${{ secrets.CODA_TOKEN }}
          AIRTABLE_API_KEY: ${{ secrets.AIRTABLE_API_KEY }}
          YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY }}
          ARD_DB_USER: ${{ secrets.ARD_DB_USER }}
          ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD }}
          ARD_DB_HOST: ${{ secrets.ARD_DB_HOST }}
          ARD_DB_NAME: alignment_research_dataset
        run: python main.py fetch ${{ inputs.datasource }}
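Besides the four-hourly schedule, the workflow_dispatch trigger lets the check be started by hand from the Actions tab with one of the listed datasources (or "all"), while workflow_call lets other workflows invoke it with an arbitrary datasource string.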
122 changes: 68 additions & 54 deletions align_data/common/alignment_dataset.py
@@ -5,10 +5,13 @@
import time
from dataclasses import dataclass, field, KW_ONLY
from pathlib import Path
from typing import List, Optional, Set, Iterable, Tuple, Generator
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Generator
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload

import pytz
from sqlalchemy import select, Select, JSON
from sqlalchemy import select, Select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload, Session
import jsonlines
@@ -23,6 +26,62 @@
logger = logging.getLogger(__name__)



def normalize_url(url: str | None) -> str | None:
    if not url:
        return url

    # ending '/'
    url = url.rstrip("/")

    # Remove http and use https consistently
    url = url.replace("http://", "https://")

    # Remove www
    url = url.replace("https://www.", "https://")

    # Remove index.html or index.htm
    url = re.sub(r'/index\.html?$', '', url)

    # Convert youtu.be links to youtube.com
    url = url.replace("https://youtu.be/", "https://youtube.com/watch?v=")

    # Additional rules for mirror domains can be added here

    # agisafetyfundamentals.com -> aisafetyfundamentals.com
    url = url.replace("https://agisafetyfundamentals.com", "https://aisafetyfundamentals.com")

    return url


def normalize_text(text: str | None) -> str | None:
    return (text or '').replace('\n', ' ').replace('\r', '').strip() or None


def format_authors(authors: List[str]) -> str:
    # TODO: Don't keep adding the same authors - come up with some way to reuse them
    authors_str = ",".join(authors)
    if len(authors_str) > 1024:
        authors_str = ",".join(authors_str[:1024].split(",")[:-1])
    return authors_str


def article_dict(data, **kwargs) -> Dict[str, Any]:
    data = merge_dicts(data, kwargs)

    summaries = data.pop("summaries", [])
    summary = data.pop("summary", None)

    data['summaries'] = summaries + ([summary] if summary else [])
    data['authors'] = format_authors(data.pop("authors", []))
    data['title'] = normalize_text(data.get('title'))

    return dict(
        meta={k: v for k, v in data.items() if k not in ARTICLE_MAIN_KEYS and v is not None},
        **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS},
    )


@dataclass
class AlignmentDataset:
"""The base dataset class."""
@@ -64,28 +123,10 @@ def __post_init__(self):
    def __str__(self) -> str:
        return self.name

    def _add_authors(self, article: Article, authors: List[str]) -> Article:
        # TODO: Don't keep adding the same authors - come up with some way to reuse them
        article.authors = ",".join(authors)
        if len(article.authors) > 1024:
            article.authors = ",".join(article.authors[:1024].split(",")[:-1])
        return article

    def make_data_entry(self, data, **kwargs) -> Article:
        data = merge_dicts(data, kwargs)

        summaries = data.pop("summaries", [])
        summary = data.pop("summary", None)
        summaries += [summary] if summary else []

        authors = data.pop("authors", [])
        data['title'] = (data.get('title') or '').replace('\n', ' ').replace('\r', '') or None

        article = Article(
            meta={k: v for k, v in data.items() if k not in ARTICLE_MAIN_KEYS and v is not None},
            **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS},
        )
        self._add_authors(article, authors)
        data = article_dict(data, **kwargs)
        summaries = data.pop('summaries', [])
        article = Article(**data)
        article.summaries += [Summary(text=summary, source=self.name) for summary in summaries]
        return article

@@ -109,7 +150,7 @@ def read_entries(self, sort_by=None) -> Iterable[Article]:
            query = self._query_items.options(joinedload(Article.summaries))
            if sort_by is not None:
                query = query.order_by(sort_by)

            result = session.scalars(query)
            for article in result.unique():  # removes duplicates
                yield article
@@ -153,41 +194,14 @@ def get_item_key(self, item) -> str:
"""
return item.name

    @staticmethod
    def _normalize_url(url: str | None) -> str | None:
        if not url:
            return url

        # ending '/'
        url = url.rstrip("/")

        # Remove http and use https consistently
        url = url.replace("http://", "https://")

        # Remove www
        url = url.replace("https://www.", "https://")

        # Remove index.html or index.htm
        url = re.sub(r'/index\.html?$', '', url)

        # Convert youtu.be links to youtube.com
        url = url.replace("https://youtu.be/", "https://youtube.com/watch?v=")

        # Additional rules for mirror domains can be added here

        # agisafetyfundamentals.com -> aisafetyfundamentals.com
        url = url.replace("https://agisafetyfundamentals.com", "https://aisafetyfundamentals.com")

        return url

    def _normalize_urls(self, urls: Iterable[str]) -> Set[str]:
        return {self._normalize_url(url) for url in urls}
        return {normalize_url(url) for url in urls}


    def _load_outputted_items(self) -> Set[str]:
        """
        Loads the outputted items from the database and returns them as a set.
        If the done_key is not an attribute of Article, it will try to load it from the meta field.
        """
        with make_session() as session:
@@ -207,7 +221,7 @@ def not_processed(self, item) -> bool:
        # cause problems (e.g. massive RAM usage, big slowdowns) then it will have to be switched around, so that
        # this function runs a query to check if the item is in the database rather than first getting all done_keys.
        # If it gets to that level, consider batching it somehow
        return self._normalize_url(self.get_item_key(item)) not in self._outputted_items
        return normalize_url(self.get_item_key(item)) not in self._outputted_items

    def unprocessed_items(self, items=None) -> list | filter:
        """Return a list of all items to be processed.
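As a rough illustration of what these new module-level helpers do (hypothetical inputs; the expected outputs are read off the code above, not taken from the commit):

from align_data.common.alignment_dataset import normalize_url, normalize_text, format_authors

normalize_url("http://www.example.com/index.html")  # -> "https://example.com"
normalize_url("https://youtu.be/abc123")  # -> "https://youtube.com/watch?v=abc123"
normalize_text("Some title\r\nacross lines\n")  # -> "Some title across lines"
format_authors(["Jane Doe", "John Smith"])  # -> "Jane Doe,John Smith"
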
1 change: 1 addition & 0 deletions align_data/db/models.py
@@ -71,6 +71,7 @@ class Article(Base):
    date_updated: Mapped[Optional[datetime]] = mapped_column(
        DateTime, onupdate=func.current_timestamp()
    )
    date_checked: Mapped[datetime] = mapped_column(DateTime, default=func.now())  # When this article was last checked to see whether it is still valid
    status: Mapped[Optional[str]] = mapped_column(String(256))
    comments: Mapped[Optional[str]] = mapped_column(LONGTEXT)  # Editor comments. Can be anything

10 changes: 5 additions & 5 deletions align_data/sources/articles/datasets.py
@@ -11,7 +11,7 @@
from pypandoc import convert_file
from sqlalchemy import select, Select

from align_data.common.alignment_dataset import AlignmentDataset
from align_data.common.alignment_dataset import AlignmentDataset, normalize_url
from align_data.db.models import Article
from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown
from align_data.sources.articles.parsers import (
@@ -127,16 +127,16 @@ def not_processed(self, item: tuple) -> bool:
        url = self.maybe(item, "url")
        source_url = self.maybe(item, "source_url")

        if item_key and self._normalize_url(item_key) in self._outputted_items:
        if item_key and normalize_url(item_key) in self._outputted_items:
            return False

        for given_url in [url, source_url]:
            if given_url:
                norm_url = self._normalize_url(given_url)
                norm_url = normalize_url(given_url)
                if norm_url in self._outputted_items:
                    return False

                norm_canonical_url = self._normalize_url(arxiv_canonical_url(given_url))
                norm_canonical_url = normalize_url(arxiv_canonical_url(given_url))
                if norm_canonical_url in self._outputted_items:
                    return False

84 changes: 84 additions & 0 deletions align_data/sources/validate.py
@@ -0,0 +1,84 @@
import logging
from datetime import datetime, timedelta
from typing import Any, List

from tqdm import tqdm
from sqlalchemy.exc import IntegrityError
from align_data.db.session import make_session
from align_data.db.models import Article
from align_data.common.alignment_dataset import normalize_url, normalize_text, article_dict
from align_data.sources.articles.parsers import item_metadata
from align_data.sources.articles.html import fetch


logger = logging.getLogger(__name__)


def update_article_field(article: Article, field: str, value: Any):
    if not value:
        return

    if field == 'url' and normalize_url(article.url) == normalize_url(value):
        # This is pretty much the same url, so don't modify it
        return
    if field == 'title' and normalize_text(article.title) == normalize_text(value):
        # If there are slight differences in the titles (e.g. punctuation), assume the
        # database version is more correct
        return
    if field == 'meta':
        article.meta = article.meta or {}
        for k, v in value.items():
            meta_val = article.meta.get(k)
            if not meta_val or v > meta_val:
                article.meta[k] = v
        return

    article_val = getattr(article, field, None)
    # Assume that if the provided value is larger (or later, in the case of dates), then it's
    # better. This might very well not hold, but it seems like a decent heuristic?
    if not article_val:
        setattr(article, field, value)
    elif isinstance(value, datetime) and value > article_val:
        setattr(article, field, value)
    elif isinstance(value, str) and len(normalize_text(value) or '') > len(normalize_text(article_val) or ''):
        setattr(article, field, normalize_text(value))


def check_article(article: Article) -> Article:
    source_url = article.meta.get('source_url') or article.url
    contents = {}
    if source_url:
        contents = item_metadata(source_url)

    if 'error' not in contents:
        for field, value in article_dict(contents).items():
            update_article_field(article, field, value)
    else:
        logger.info('Error getting contents for %s: %s', article, contents.get('error'))

    if 400 <= fetch(article.url).status_code < 500:
        logger.info('Could not get url for %s', article)
        article.status = 'Unreachable url'

    article.date_checked = datetime.utcnow()

    return article


def check_articles(sources: List[str], batch_size=100):
    logger.info('Checking %s articles for %s', batch_size, ', '.join(sources))
    with make_session() as session:
        for article in tqdm(
            session.query(Article)
            .filter(Article.date_checked < datetime.now() - timedelta(weeks=4))
            .filter(Article.source.in_(sources))
            .limit(batch_size)
            .all()
        ):
            check_article(article)
            session.add(article)
        logger.debug('committing')
        try:
            session.commit()
        except IntegrityError as e:
            logger.error(e)
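
A minimal sketch of driving the new checker directly from Python rather than through main.py (it assumes the same database environment variables as the rest of align_data; the source names are illustrative):

from align_data.sources.validate import check_articles

# Re-check up to 50 articles from these sources whose date_checked is more than 4 weeks old.
check_articles(["arxiv", "lesswrong"], batch_size=50)
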
9 changes: 9 additions & 0 deletions main.py
@@ -14,6 +14,7 @@
)
from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater
from align_data.embeddings.finetuning.training import finetune_embeddings
from align_data.sources.validate import check_articles
from align_data.settings import (
METADATA_OUTPUT_SPREADSHEET,
METADATA_SOURCE_SHEET,
@@ -150,6 +151,14 @@ def train_finetuning_layer(self) -> None:
"""
finetune_embeddings()

    def validate_articles(self, *names, n=100) -> None:
        """Check n articles to see whether their data is correct and that their urls point to valid addresses."""
        if names == ("all",):
            names = ALL_DATASETS
        missing = {name for name in names if name not in ALL_DATASETS}
        assert not missing, f"{missing} are not valid dataset names"
        check_articles(names, n)


if __name__ == "__main__":
    fire.Fire(AlignmentDataset)
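
With the existing fire-based CLI, the new check should be reachable as, for example, python main.py validate_articles aisafety.info --n=50, or python main.py validate_articles all to sample across every registered dataset (a usage sketch inferred from fire's argument mapping, not shown in the commit).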