Skip to content

Commit

Permalink
Merge pull request #106 from mik3y/mikey/s3-cache
Browse files Browse the repository at this point in the history
feature: add optional cache layer to `S3Dao`
  • Loading branch information
hynky1999 authored Apr 7, 2024
2 parents 1fd7709 + fa5415d commit c00b3cd
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: mixed-line-ending

- repo: https://github.com/myint/autoflake
rev: v1.4
rev: v2.3.1
hooks:
- id: autoflake

Expand Down
59 changes: 59 additions & 0 deletions cmoncrawl/common/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import logging
import hashlib
from pathlib import Path
from cmoncrawl.common.types import DomainRecord

logger = logging.getLogger(__name__)


def cache_key(record: DomainRecord):
"""Returns an opaque key / filename for caching a `DomainRecord`."""
h = hashlib.sha256()
h.update(record.filename.encode())
h.update("|".encode())
h.update(str(record.offset).encode())
h.update("|".encode())
h.update(str(record.length).encode())
return f"{h.hexdigest()}.bin"


class AbstractDomainRecordCache:
"""Cache interface for DomainRecords."""

def get(self, record: DomainRecord) -> bytes | None:
raise NotImplementedError

def set(self, record: DomainRecord, data: bytes) -> None:
raise NotImplementedError


class DomainRecordFilesystemCache(AbstractDomainRecordCache):
"""A local filesystem cache.
If `cache_dir` does not exist, the implementation will attempt
to create it upon first `set()` using `os.makedirs`.
Entries are never pruned (no TTL support currently).
"""

def __init__(self, cache_dir: Path):
super().__init__()
self.cache_dir = cache_dir

def get(self, record: DomainRecord) -> bytes | None:
cache_path = self.cache_dir / Path(cache_key(record))
if cache_path.exists():
with open(cache_path, "rb") as fp:
logger.debug(f"reading data for {record.url} from filesystem cache")
return fp.read()
return None

def set(self, record: DomainRecord, data: bytes) -> None:
if not self.cache_dir.exists():
logger.info(f"Creating cache dir {self.cache_dir}")
os.makedirs(str(self.cache_dir))
cache_path = self.cache_dir / Path(cache_key(record))
with open(cache_path, "wb") as fp:
logger.debug(f"writing data for {record.url} to filesystem cache")
fp.write(data)
16 changes: 15 additions & 1 deletion cmoncrawl/processor/dao/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from botocore.config import Config
from botocore.exceptions import ClientError

from cmoncrawl.common.caching import AbstractDomainRecordCache
from cmoncrawl.common.types import DomainRecord
from cmoncrawl.processor.dao.base import DownloadError, ICC_Dao

Expand All @@ -15,6 +16,7 @@ class S3Dao(ICC_Dao):
Args:
aws_profile (str, optional): The AWS profile to use for the download. Defaults to None.
bucket_name (str, optional): The name of the S3 bucket. Defaults to "commoncrawl".
cache (AbstractDomainRecordCache, optional): Cache to use for downloading from s3.
Attributes:
bucket_name (str): The name of the S3 bucket.
Expand All @@ -32,11 +34,15 @@ class S3Dao(ICC_Dao):
"""

def __init__(
self, aws_profile: str | None = None, bucket_name: str = "commoncrawl"
self,
aws_profile: str | None = None,
bucket_name: str = "commoncrawl",
cache: AbstractDomainRecordCache | None = None,
) -> None:
self.bucket_name = bucket_name
self.aws_profile = aws_profile
self.client = None
self.cache = cache

async def __aenter__(self) -> "S3Dao":
# We handle the retries ourselves, so we disable the botocore retries
Expand Down Expand Up @@ -73,6 +79,11 @@ async def fetch(self, domain_record: DomainRecord) -> bytes:
"S3Dao client is not initialized, did you forget to use async with?"
)

if self.cache:
cached_bytes = self.cache.get(domain_record)
if cached_bytes is not None:
return cached_bytes

file_name = domain_record.filename
byte_range = f"bytes={domain_record.offset}-{domain_record.offset+domain_record.length-1}"

Expand All @@ -84,4 +95,7 @@ async def fetch(self, domain_record: DomainRecord) -> bytes:
except ClientError as e:
raise DownloadError(f"AWS: {e.response['Error']['Message']}", 500)

if self.cache:
self.cache.set(domain_record, file_bytes)

return file_bytes

0 comments on commit c00b3cd

Please sign in to comment.