Merge pull request #92 from hynky1999/ci/cd_update
Ruff Linting and easier development cycle with Makefile
hynky1999 authored Nov 18, 2023
2 parents d8bc441 + f433243 commit cbbc3ee
Showing 44 changed files with 585 additions and 680 deletions.
4 changes: 0 additions & 4 deletions .env.sample
@@ -1,7 +1,3 @@
# ENV VARIABLES FOR LOCAL TESTING
# AWS
AWS_PROFILE="dev"

# MYSQL
MYSQL_HOST=127.0.0.1
MYSQL_PORT=3306
16 changes: 11 additions & 5 deletions .github/workflows/test_and_types.yml
@@ -30,6 +30,12 @@ jobs:
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: "pip"
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: us-east-1

- name: Install dependencies
run: pip install -r requirements.test.txt # Replace with your dependencies installation command
@@ -38,7 +44,7 @@ jobs:
run: cp .env.sample .env

- name: Run tests
run: python -m unittest discover -s tests -p "*_tests.py" # Replace with your test command
run: make test

lint_and_types:
runs-on: ubuntu-latest
@@ -56,8 +62,8 @@ jobs:
- name: Install dependencies
run: pip install -r requirements.dev.txt

- name: Lint with pyright
run: pyright
- name: Format
run: make format

- name: Lint with black
run: black -t py310 --check .
- name: Lint
run: make lint
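
The new configure-aws-credentials step exports the repository secrets as environment variables for the rest of the job, so the Athena/S3 tests can build AWS sessions through the default credential chain instead of a named profile, which lines up with AWS_PROFILE being dropped from .env.sample. A minimal sketch of how a session would pick those variables up; the bucket name and the listing call are illustrative, not taken from this repository:

```python
import asyncio
import os

import aioboto3  # the same async AWS client library the aggregator uses


async def main() -> None:
    # No explicit keys or profile: aioboto3/boto3 falls back to the default
    # credential chain, which reads AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY
    # (and the region) from the environment the workflow step just configured.
    session = aioboto3.Session(region_name=os.environ.get("AWS_REGION", "us-east-1"))
    async with session.client("s3") as s3:
        # Illustrative call only; "my-test-bucket" is a placeholder bucket name.
        response = await s3.list_objects_v2(Bucket="my-test-bucket", MaxKeys=5)
        for obj in response.get("Contents", []):
            print(obj["Key"])


if __name__ == "__main__":
    asyncio.run(main())
```
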
27 changes: 10 additions & 17 deletions .pre-commit-config.yaml
@@ -11,21 +11,14 @@ repos:
- id: check-merge-conflict
- id: mixed-line-ending

- repo: https://github.com/PyCQA/autoflake
rev: v1.4
hooks:
- id: autoflake
args:
- "--in-place"
- "--expand-star-imports"
- "--remove-duplicate-keys"
- "--recursive"
- "--remove-unused-variables"
language_version: python3.11
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
args: ["--line-length", "79"]
language_version: python3.11
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.5
hooks:
# Run the formatter.
- id: ruff-format
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -7,7 +7,7 @@
"request": "launch",
"module": "cmoncrawl.integrations.commands",
"console": "integratedTerminal",
"args": ["download", "idnes.cz", "out", "html"]
"args": ["download", "--limit=4" ,"idnes.cz", "out", "record"]
},
{
"name": "Extract",
12 changes: 12 additions & 0 deletions Makefile
@@ -0,0 +1,12 @@
.PHONY: test lint check format

test:
python -m unittest discover -s tests -p '*_test.py'

lint:
@ruff --fix cmoncrawl tests || ( echo ">>> ruff failed"; exit 1; )

format:
@pre-commit run --all-files

check: format lint
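
The Makefile centralises the commands the workflow now calls: make test wraps unittest discovery (note the pattern is '*_test.py', while the old CI command looked for '*_tests.py'), make lint runs ruff with autofix, make format delegates to pre-commit, and make check chains both. A sketch of the equivalent programmatic test invocation, useful when the suite has to run without make:

```python
import sys
import unittest

# Programmatic equivalent of the `make test` target: discover every module
# under tests/ whose filename matches *_test.py and run the resulting suite.
suite = unittest.defaultTestLoader.discover(start_dir="tests", pattern="*_test.py")
result = unittest.TextTestRunner(verbosity=2).run(suite)
sys.exit(0 if result.wasSuccessful() else 1)
```
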
69 changes: 16 additions & 53 deletions cmoncrawl/aggregator/athena_query.py
@@ -19,7 +19,6 @@

import aioboto3
import aiofiles
import tenacity
from aiocsv.readers import AsyncDictReader
from aiohttp import ClientSession

@@ -31,19 +30,15 @@
from cmoncrawl.aggregator.utils.helpers import get_all_CC_indexes
from cmoncrawl.common.loggers import all_purpose_logger
from cmoncrawl.common.types import (
DomainCrawl,
DomainRecord,
MatchType,
RetrieveResponse,
)

QUERIES_SUBFOLDER = "queries"
QUERIES_TMP_SUBFOLDER = "queries_tmp"


async def remove_bucket_prefix(
session: aioboto3.Session, prefix: str, folder: str
):
async def remove_bucket_prefix(session: aioboto3.Session, prefix: str, folder: str):
# remove all query results
async with session.client("s3") as s3:
paginator = s3.get_paginator("list_objects_v2")
@@ -175,9 +170,7 @@ def validate_args(self):
async def __aenter__(self) -> AthenaAggregator:
return await self.aopen()

async def __aexit__(
self, exc_type, exc_value, traceback
) -> AthenaAggregator:
async def __aexit__(self, exc_type, exc_value, traceback) -> AthenaAggregator:
return await self.aclose()

async def aopen(self) -> AthenaAggregator:
@@ -264,9 +257,7 @@ def __init__(
self.__database_name = database_name
self.__table_name = table_name
self.__extra_sql_where_clause = extra_sql_where_clause
self.__prefetch_queue: Set[
asyncio.Task[List[DomainRecord]]
] = set()
self.__prefetch_queue: Set[asyncio.Task[List[DomainRecord]]] = set()
self.__opt_prefetch_size = 5

def init_crawls_queue(
@@ -275,13 +266,8 @@ def init_crawls_queue(
allowed_crawls = [
crawl_url_to_name(crawl)
for crawl in CC_files
if (
self.__since is None
or crawl_to_year(crawl) >= self.__since.year
)
and (
self.__to is None or crawl_to_year(crawl) <= self.__to.year
)
if (self.__since is None or crawl_to_year(crawl) >= self.__since.year)
and (self.__to is None or crawl_to_year(crawl) <= self.__to.year)
]
if batch_size <= 0:
return [allowed_crawls]
@@ -291,34 +277,24 @@
]

async def download_results(self, key: str) -> str:
with tempfile.NamedTemporaryFile(
suffix=".csv", delete=False
) as temp_file:
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
async with self.__aws_client.client("s3") as s3:
await s3.download_file(
self.__bucket_name, key, temp_file.name
)
await s3.download_file(self.__bucket_name, key, temp_file.name)

return temp_file.name

async def await_athena_query(
self, query: str, result_name: str
) -> str:
async def await_athena_query(self, query: str, result_name: str) -> str:
s3_location = f"s3://{self.__bucket_name}/{QUERIES_TMP_SUBFOLDER}"
query_execution_id = await run_athena_query(
self.__aws_client,
{
"QueryString": query,
"QueryExecutionContext": {
"Database": self.__database_name
},
"QueryExecutionContext": {"Database": self.__database_name},
"ResultConfiguration": {"OutputLocation": s3_location},
},
)
# Move file to bucket/result_name
query_result_key = (
f"{QUERIES_TMP_SUBFOLDER}/{query_execution_id}.csv"
)
query_result_key = f"{QUERIES_TMP_SUBFOLDER}/{query_execution_id}.csv"
expected_result_key = f"{QUERIES_SUBFOLDER}/{result_name}.csv"
async with self.__aws_client.client("s3") as s3:
await s3.copy_object(
@@ -383,19 +359,13 @@ async def __fetch_next_crawl_batch(self, crawl_batch: List[str]):
query_id = f"{crawl_batch_id}-{query_hash}"
crawl_s3_key = await self.is_crawl_cached(query_id)
if crawl_s3_key is not None:
all_purpose_logger.info(
f"Using cached crawl batch {crawl_batch}"
)
all_purpose_logger.info(f"Using cached crawl batch {crawl_batch}")
else:
all_purpose_logger.info(
f"Querying for crawl batch {crawl_batch}"
)
all_purpose_logger.info(f"Querying for crawl batch {crawl_batch}")
crawl_s3_key = await self.await_athena_query(query, query_id)

domain_records: List[DomainRecord] = []
async for domain_record in self.domain_records_from_s3(
crawl_s3_key
):
async for domain_record in self.domain_records_from_s3(crawl_s3_key):
domain_records.append(domain_record)
return domain_records

@@ -414,15 +384,10 @@ async def __await_next_prefetch(self):
):
next_crawl = self.__crawls_remaining.pop(0)
self.__prefetch_queue.add(
asyncio.create_task(
self.__fetch_next_crawl_batch(next_crawl)
)
asyncio.create_task(self.__fetch_next_crawl_batch(next_crawl))
)

while (
len(self.__prefetch_queue) > 0
and len(self.__domain_records) == 0
):
while len(self.__prefetch_queue) > 0 and len(self.__domain_records) == 0:
done, self.__prefetch_queue = await asyncio.wait(
self.__prefetch_queue, return_when="FIRST_COMPLETED"
)
@@ -431,9 +396,7 @@ async def __await_next_prefetch(self):
domain_records = task.result()
self.__domain_records.extend(domain_records)
except Exception as e:
all_purpose_logger.error(
f"Error during a crawl query", e
)
all_purpose_logger.error("Error during a crawl query", e)

async def __anext__(self) -> DomainRecord:
# Stop if we fetched everything or reached limit
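
Most of the churn in this file is ruff-format joining previously wrapped call signatures onto single lines, but the hunks above also show the aggregator's prefetch loop: a set of asyncio tasks is drained with asyncio.wait and FIRST_COMPLETED, collecting results as batches finish. A self-contained sketch of that pattern, with placeholder fetch logic standing in for the real Athena query:

```python
import asyncio
import random
from typing import List, Set


async def fetch_batch(batch_id: int) -> List[str]:
    # Stand-in for __fetch_next_crawl_batch: pretend each batch takes a while.
    await asyncio.sleep(random.uniform(0.1, 0.5))
    return [f"record-{batch_id}-{i}" for i in range(3)]


async def main() -> None:
    prefetch_queue: Set[asyncio.Task[List[str]]] = set()
    results: List[str] = []

    # Kick off a handful of prefetch tasks, mirroring __opt_prefetch_size.
    for batch_id in range(5):
        prefetch_queue.add(asyncio.create_task(fetch_batch(batch_id)))

    # Drain tasks as they finish, keeping the rest pending: the same
    # FIRST_COMPLETED pattern used by __await_next_prefetch above.
    while prefetch_queue and not results:
        done, prefetch_queue = await asyncio.wait(
            prefetch_queue, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            try:
                results.extend(task.result())
            except Exception:
                pass  # the aggregator logs the error and keeps going

    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```
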
50 changes: 14 additions & 36 deletions cmoncrawl/aggregator/index_query.py
@@ -1,31 +1,32 @@
from __future__ import annotations

import asyncio
import re
from collections import deque
from datetime import datetime
import re
from cmoncrawl.aggregator.utils.helpers import get_all_CC_indexes, retrieve

from types import TracebackType
from typing import (
AsyncIterable,
AsyncIterator,
Deque,
List,
Dict,
List,
Set,
Type,
)

from aiohttp import (
ClientSession,
)

from cmoncrawl.aggregator.utils.helpers import get_all_CC_indexes, retrieve
from cmoncrawl.common.loggers import all_purpose_logger
from cmoncrawl.common.types import (
DomainRecord,
DomainCrawl,
DomainRecord,
MatchType,
)

from aiohttp import (
ClientSession,
)
import asyncio


class IndexAggregator(AsyncIterable[DomainRecord]):
"""
@@ -126,9 +127,7 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None = None,
) -> IndexAggregator:
return await self.aclose(
exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb
)
return await self.aclose(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

@staticmethod
async def get_number_of_pages(
@@ -297,9 +296,7 @@ async def __prefetch_next_crawl(self) -> int:
)

for i in range(num_pages):
dc = DomainCrawl(
next_crawl.domain, next_crawl.cdx_server, i
)
dc = DomainCrawl(next_crawl.domain, next_crawl.cdx_server, i)
self.prefetch_queue.add(
asyncio.create_task(
IndexAggregator.get_captured_responses(
@@ -333,10 +330,7 @@ async def __await_next_prefetch(self):
):
await self.__prefetch_next_crawl()

while (
len(self.prefetch_queue) > 0
and len(self.__domain_records) == 0
):
while len(self.prefetch_queue) > 0 and len(self.__domain_records) == 0:
done, self.prefetch_queue = await asyncio.wait(
self.prefetch_queue, return_when="FIRST_COMPLETED"
)
@@ -371,22 +365,6 @@ def clean(self):
for task in self.prefetch_queue:
task.cancel()

async def __fetch_next_dc(self, dc: DomainCrawl):
return (
await IndexAggregator.get_captured_responses(
self.__client,
dc.cdx_server,
dc.domain,
match_type=self.__match_type,
page=dc.page,
since=self.__since,
to=self.__to,
max_retry=self.__max_retry,
sleep_step=self.__sleep_step,
),
dc,
)


def to_timestamp_format(date: datetime):
return date.strftime("%Y%m%d%H%M%S")
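
Besides import reordering and the same ruff-format line joins, the unused __fetch_next_dc helper is deleted. The to_timestamp_format helper kept at the bottom turns a datetime into a 14-digit timestamp string, presumably for the since/to parameters of the CDX index queries; a quick worked example of the format it produces:

```python
from datetime import datetime


def to_timestamp_format(date: datetime) -> str:
    # Same formatting as the helper shown above: year, month, day, hour,
    # minute, second packed into one 14-character string.
    return date.strftime("%Y%m%d%H%M%S")


# For example, a datetime on the merge date of this PR (time chosen arbitrarily)
# renders as a plain, lexicographically sortable string:
print(to_timestamp_format(datetime(2023, 11, 18, 12, 30, 0)))  # -> 20231118123000
```
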
9 changes: 5 additions & 4 deletions cmoncrawl/aggregator/utils/athena_query_maker.py
@@ -1,7 +1,8 @@
from datetime import datetime
import textwrap
from datetime import datetime
from typing import List, Optional
from urllib.parse import urlparse

from cmoncrawl.aggregator.index_query import crawl_to_year
from cmoncrawl.common.types import MatchType

@@ -80,14 +81,14 @@ def prepare_athena_where_conditions(
urls_with_type_query = [
f"({url_query_based_on_match_type(match_type, url)})" for url in urls
]
url_query = f" OR ".join(urls_with_type_query)
url_query = " OR ".join(urls_with_type_query)
allowed_crawls_query = crawl_query(crawl_urls, since, to)
date_query = url_query_date_range(since, to)
where_conditions = [
date_query,
allowed_crawls_query,
f"cc.fetch_status = 200",
f"cc.subset = 'warc'",
"cc.fetch_status = 200",
"cc.subset = 'warc'",
url_query,
]
where_conditions = [condition for condition in where_conditions if condition]
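
The remaining edits here drop redundant f-string prefixes from constant strings. The surrounding function builds its WHERE clause from optional pieces: per-URL predicates are OR-ed together, fixed filters such as cc.fetch_status = 200 and cc.subset = 'warc' are appended, and empty conditions are filtered out before assembly. A simplified sketch of that filter-then-join shape, with illustrative column values and without the module's exact SQL assembly:

```python
from typing import List, Optional


def compose_where_clause(conditions: List[Optional[str]]) -> str:
    # Keep only non-empty conditions, then AND them together, mirroring the
    # filtering step in prepare_athena_where_conditions above.
    kept = [condition for condition in conditions if condition]
    return " AND ".join(f"({condition})" for condition in kept)


# Hypothetical per-URL predicates; the real query builder derives these from
# the requested match type and URLs ("seznam.cz" is made up for this sketch).
url_query = " OR ".join(
    ["(cc.url_host_name = 'idnes.cz')", "(cc.url_host_name = 'seznam.cz')"]
)
print(
    compose_where_clause(
        [None, "cc.fetch_status = 200", "cc.subset = 'warc'", url_query]
    )
)
# (cc.fetch_status = 200) AND (cc.subset = 'warc') AND ((cc.url_host_name = 'idnes.cz') OR (cc.url_host_name = 'seznam.cz'))
```
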