implementing the UNIMPLEMENTED_PARSERS (#97)

* to start the pr to add comments

* removed spaces

* Merge remote-tracking branch 'origin/main' into implement_more_parsers
meaningless merge

* create logger_config and reorder the imports

* main's logger

* ignore the log files

* postprocess notes

* fix test with new download order for pdfarticles

* Handle special docs

* Fetch new items from indices

* fixed domain getter from network location

* logger and minor fixes

* comment: add www2. and www6. handling

* removed logger_config

* merge with main and minor changes

* rm logger_config.py

* minor fixes

* minor fixes 2

* parsers type signature

* test_arxiv_process_entry_retracted fixed

* Refactor of special_indices

* 1239283019481293043902

* alignmentdataset class removed some init fields

* removed the wrong arxivpapers file

* minor changes

* pdf date_published is a datetime

* revert some useless changes

* revert type annotation change

* nits

* nits 2

* nits 2

---------

Co-authored-by: Daniel O'Connell <[email protected]>
Co-authored-by: Henri Lemoine <[email protected]>
3 people authored Sep 13, 2023
1 parent 65e2a42 commit 16e4c84
Showing 30 changed files with 496 additions and 355 deletions.
align_data/analysis/analyse_jsonl_data.py (3 changes: 2 additions & 1 deletion)
@@ -1,8 +1,9 @@
from datetime import datetime
from pathlib import Path
from collections import defaultdict

import jsonlines

from collections import defaultdict


def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
align_data/analysis/count_tokens.py (5 changes: 3 additions & 2 deletions)
@@ -1,7 +1,8 @@
from typing import Tuple
import logging

from transformers import AutoTokenizer
import jsonlines
import logging
from typing import Tuple

logger = logging.getLogger(__name__)

align_data/common/alignment_dataset.py (99 changes: 53 additions & 46 deletions)
@@ -3,23 +3,23 @@
from itertools import islice
import logging
import time
from dataclasses import dataclass, KW_ONLY
from dataclasses import dataclass, field, KW_ONLY
from pathlib import Path
from typing import Iterable, List, Optional, Set
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from typing import List, Optional, Set, Iterable, Tuple, Generator

import jsonlines
import pytz
from sqlalchemy import select, Select, JSON
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload, Session
import jsonlines
from dateutil.parser import parse, ParserError
from tqdm import tqdm

from align_data.db.models import Article, Summary
from align_data.db.session import make_session
from align_data.settings import ARTICLE_MAIN_KEYS
from align_data.sources.utils import merge_dicts


logger = logging.getLogger(__name__)


@@ -28,40 +28,42 @@ class AlignmentDataset:
"""The base dataset class."""

name: str
"""The name of the dataset"""
"""The name of the dataset."""

_: KW_ONLY

files_path = Path("")
"""The path where data can be found. Usually a folder"""
data_path: Path = Path(__file__).parent / "../../data/"
"""The path where data can be found. Usually a folder."""

# Derived paths
raw_data_path: Path = field(init=False)
files_path: Path = field(init=False)

# Internal housekeeping variables
_outputted_items: Set[str] = field(default_factory=set, init=False)
"""A set of the ids of all previously processed items."""

done_key = "id"
"""The key of the entry to use as the id when checking if already processed."""

COOLDOWN = 0
"""An optional cool down between processing entries"""
"""An optional cool down between processing entries."""

lazy_eval = False
"""Whether to lazy fetch items. This is nice in that it will start processing, but messes up the progress bar."""

batch_size = 20
"""The number of items to collect before flushing to the database."""

# Internal housekeeping variables
_entry_idx = 0
"""Used internally for writing debugging info - each file write will increment it"""
_outputted_items = set()
"""A set of the ids of all previously processed items"""
def __post_init__(self):
self.data_path = self.data_path.resolve()

def __str__(self) -> str:
return self.name

def __post_init__(self, data_path=Path(__file__).parent / "../../data/"):
self.data_path = data_path
self.raw_data_path = self.data_path / "raw"

# set the default place to look for data
self.files_path = self.raw_data_path / self.name

def __str__(self) -> str:
return self.name

def _add_authors(self, article: Article, authors: List[str]) -> Article:
# TODO: Don't keep adding the same authors - come up with some way to reuse them
article.authors = ",".join(authors)
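As a reading aid for the hunk above: the commit replaces the hard-coded paths with a keyword-only data_path and two derived fields populated in __post_init__. The sketch below is illustrative only; the DummyDataset name and the "data" default are placeholders, not part of the commit.

from dataclasses import dataclass, field, KW_ONLY
from pathlib import Path
from typing import Set

@dataclass
class DummyDataset:
    # Field layout mirrors the diff; the default data_path here is a placeholder.
    name: str
    _: KW_ONLY
    data_path: Path = Path("data")
    raw_data_path: Path = field(init=False)   # derived: <data_path>/raw
    files_path: Path = field(init=False)      # derived: <data_path>/raw/<name>
    _outputted_items: Set[str] = field(default_factory=set, init=False)

    def __post_init__(self):
        self.data_path = self.data_path.resolve()
        self.raw_data_path = self.data_path / "raw"
        self.files_path = self.raw_data_path / self.name

ds = DummyDataset("example_source")
print(ds.files_path)  # <cwd>/data/raw/example_source

Because raw_data_path and files_path are init=False, callers only ever pass data_path and the derived folders stay consistent with it.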
@@ -87,42 +89,42 @@ def make_data_entry(self, data, **kwargs) -> Article:
article.summaries += [Summary(text=summary, source=self.name) for summary in summaries]
return article

def to_jsonl(self, out_path=None, filename=None) -> Path:
if not out_path:
out_path = Path(__file__).parent / "../../data/"

if not filename:
filename = f"{self.name}.jsonl"
filename = Path(out_path) / filename
def to_jsonl(self, out_path: Path | None = None, filename: str | None = None) -> Path:
out_path = out_path or self.data_path
filename = filename or f"{self.name}.jsonl"
filepath = out_path / filename

with jsonlines.open(filename, "w") as jsonl_writer:
with jsonlines.open(filepath, "w") as jsonl_writer:
for article in self.read_entries():
jsonl_writer.write(article.to_dict())
return filename.resolve()
return filepath.resolve()

@property
def _query_items(self):
def _query_items(self) -> Select[Tuple[Article]]:
return select(Article).where(Article.source == self.name)

def read_entries(self, sort_by=None):
def read_entries(self, sort_by=None) -> Iterable[Article]:
"""Iterate through all the saved entries."""
with make_session() as session:
query = self._query_items.options(joinedload(Article.summaries))
if sort_by is not None:
query = query.order_by(sort_by)
for item in session.scalars(query).unique():
yield item

result = session.scalars(query)
for article in result.unique(): # removes duplicates
yield article

def _add_batch(self, session, batch):
def _add_batch(self, session: Session, batch: tuple):
session.add_all(batch)

def add_entries(self, entries):
def commit():
def commit() -> bool:
try:
session.commit()
return True
except IntegrityError:
session.rollback()
return False

with make_session() as session:
items = iter(entries)
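The reworked to_jsonl above reduces to the jsonlines pattern sketched here (assuming the jsonlines package is installed, as it is in this module's imports). The records, folder, and dataset name are invented for the example; the real method streams Article.to_dict() from read_entries() rather than a literal list.

from pathlib import Path

import jsonlines

out_path = Path("data")            # stand-in for self.data_path
name = "example_source"            # stand-in for self.name
records = [
    {"title": "A post", "url": "https://example.com/a"},
    {"title": "Another post", "url": "https://example.com/b"},
]

out_path.mkdir(parents=True, exist_ok=True)
filepath = out_path / f"{name}.jsonl"

with jsonlines.open(filepath, "w") as jsonl_writer:
    for record in records:             # real code: self.read_entries()
        jsonl_writer.write(record)     # real code: article.to_dict()

print(filepath.resolve())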
@@ -183,7 +185,11 @@ def _normalize_urls(self, urls: Iterable[str]) -> Set[str]:


def _load_outputted_items(self) -> Set[str]:
"""Load the output file (if it exists) in order to know which items have already been output."""
"""
Loads the outputted items from the database and returns them as a set.
If the done_key is not an attribute of Article, it will try to load it from the meta field.
"""
with make_session() as session:
items = set()
if hasattr(Article, self.done_key):
@@ -203,23 +209,24 @@ def not_processed(self, item) -> bool:
# If it gets to that level, consider batching it somehow
return self._normalize_url(self.get_item_key(item)) not in self._outputted_items

def unprocessed_items(self, items=None) -> Iterable:
def unprocessed_items(self, items=None) -> list | filter:
"""Return a list of all items to be processed.
This will automatically remove any items that have already been processed,
based on the contents of the output file.
"""
self.setup()
items = items or self.items_list

filtered = filter(self.not_processed, items or self.items_list)
items_to_process = filter(self.not_processed, items)

# greedily fetch all items if not lazy eval. This makes the progress bar look nice
if not self.lazy_eval:
filtered = list(filtered)
return list(items_to_process)

return filtered
return items_to_process

def fetch_entries(self):
def fetch_entries(self) -> Generator[Article, None, None]:
"""Get all entries to be written to the file."""
for item in tqdm(self.unprocessed_items(), desc=f"Processing {self.name}"):
entry = self.process_entry(item)
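For the deduplication logic in this hunk, a rough standalone equivalent is sketched below. The normalize_url helper is a made-up simplification of the dataset's _normalize_url (not shown in this diff); the real code filters on get_item_key(item) against ids loaded from the database.

from typing import Iterable, List

def normalize_url(url: str) -> str:
    # Hypothetical stand-in for AlignmentDataset._normalize_url.
    return url.replace("http://", "https://").rstrip("/")

def unprocessed_items(items: List[str], outputted: Iterable[str], lazy_eval: bool = False):
    # Items whose normalised key is already known are skipped.
    seen = {normalize_url(url) for url in outputted}
    filtered = filter(lambda item: normalize_url(item) not in seen, items)
    # Greedily materialise unless lazy evaluation is wanted, so tqdm can
    # show a total in the progress bar (mirroring the comment in the diff).
    return filtered if lazy_eval else list(filtered)

items = ["https://example.com/a/", "http://example.com/b", "https://example.com/c"]
print(unprocessed_items(items, outputted=["https://example.com/b"]))
# ['https://example.com/a/', 'https://example.com/c']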
@@ -242,7 +249,7 @@ def process_entry(self, entry) -> Article | None:
raise NotImplementedError

@staticmethod
def _format_datetime(date) -> str:
def _format_datetime(date: datetime) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")

@staticmethod
@@ -280,7 +287,7 @@ def _load_outputted_items(self) -> Set[str]:
)
)

def _add_batch(self, session, batch):
def _add_batch(self, session: Session, batch: tuple):
def merge(item):
if prev := self.articles.get(item.url):
return session.merge(item.update(prev))
align_data/common/html_dataset.py (35 changes: 18 additions & 17 deletions)
@@ -1,16 +1,18 @@
import pytz
import regex as re
import logging
from datetime import datetime
from dataclasses import dataclass, field, KW_ONLY
from dataclasses import dataclass, field
from urllib.parse import urljoin
from typing import List
from typing import List, Dict, Any
import re

import pytz
import requests
import feedparser
from bs4 import BeautifulSoup
from bs4.element import ResultSet, Tag
from markdownify import markdownify

from align_data.db.models import Article
from align_data.common.alignment_dataset import AlignmentDataset

logger = logging.getLogger(__name__)
@@ -26,9 +28,6 @@ class HTMLDataset(AlignmentDataset):
done_key = "url"

authors: List[str] = field(default_factory=list)
_: KW_ONLY
source_key: str = None
summary_key: str = None

item_selector = "article"
title_selector = "article h1"
@@ -39,23 +38,25 @@ class HTMLDataset(AlignmentDataset):
def extract_authors(self, article):
return self.authors

def get_item_key(self, item) -> str:
article_url = item.find_all("a")[0]["href"].split("?")[0]
return urljoin(self.url, article_url)

def get_item_key(self, item: Tag) -> str:
first_href = item.find("a")["href"]
href_base, *_ = first_href.split("?")
return urljoin(self.url, href_base)

@property
def items_list(self):
def items_list(self) -> ResultSet[Tag]:
logger.info(f"Fetching entries from {self.url}")
response = requests.get(self.url, allow_redirects=True)
soup = BeautifulSoup(response.content, "html.parser")
articles = soup.select(self.item_selector)
logger.info(f"Found {len(articles)} articles")
return articles

def _extra_values(self, contents):
def _extra_values(self, contents: BeautifulSoup):
return {}

def get_contents(self, article_url: str):
def get_contents(self, article_url: str) -> Dict[str, Any]:
contents = self.fetch_contents(article_url)

title = self._get_title(contents)
@@ -72,16 +73,16 @@ def get_contents(self, article_url: str):
**self._extra_values(contents),
}

def process_entry(self, article):
def process_entry(self, article: Tag) -> Article:
article_url = self.get_item_key(article)
contents = self.get_contents(article_url)
if not contents.get("text"):
return None

return self.make_data_entry(contents)

def fetch_contents(self, url):
logger.info("Fetching {}".format(url))
def fetch_contents(self, url: str):
logger.info(f"Fetching {url}")
resp = requests.get(url, allow_redirects=True)
return BeautifulSoup(resp.content, "html.parser")

@@ -136,7 +137,7 @@ def _get_text(self, item):
text = item.get("content") and item["content"][0].get("value")
return self._extract_markdown(text)

def fetch_contents(self, url):
def fetch_contents(self, url: str):
item = self.items[url]
if "content" in item:
return item
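To make the items_list/get_item_key pair in HTMLDataset above concrete, here is a self-contained sketch. The HTML snippet and base URL are invented, and bs4 is assumed to be installed, as it is in this module's imports.

from urllib.parse import urljoin

from bs4 import BeautifulSoup

base_url = "https://blog.example.com"
html = """
<article>
  <h1>Example post</h1>
  <a href="/posts/example?utm_source=rss">Read more</a>
</article>
"""

soup = BeautifulSoup(html, "html.parser")
articles = soup.select("article")        # item_selector
item = articles[0]

first_href = item.find("a")["href"]      # "/posts/example?utm_source=rss"
href_base, *_ = first_href.split("?")    # drop any query string
print(urljoin(base_url, href_base))      # https://blog.example.com/posts/example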
align_data/db/models.py (1 change: 1 addition & 0 deletions)
@@ -221,6 +221,7 @@ def to_dict(self) -> Dict[str, Any]:
}



event.listen(Article, "before_insert", Article.before_write)
event.listen(Article, "before_update", Article.before_write)
event.listen(Article, "before_insert", Article.check_for_changes)
align_data/db/session.py (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@
from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE
from align_data.db.models import Article, PineconeStatus


logger = logging.getLogger(__name__)

# We create a single engine for the entire application
align_data/embeddings/pinecone/pinecone_db_handler.py (1 change: 0 additions & 1 deletion)
@@ -19,7 +19,6 @@
PINECONE_NAMESPACE,
)


logger = logging.getLogger(__name__)


align_data/embeddings/pinecone/update_pinecone.py (2 changes: 0 additions & 2 deletions)
@@ -22,10 +22,8 @@
)
from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter


logger = logging.getLogger(__name__)


# Define type aliases for the Callables
LengthFunctionType = Callable[[str], int]
TruncateFunctionType = Callable[[str, int], str]