diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 34e8e88..6025094 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -12,8 +12,9 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10"]
+        poc_id: ["poc1", "poc2"]
     concurrency:
-      group: ci-tests-${{ github.ref }}
+      group: ci-tests-${{ matrix.poc_id }}-${{ github.ref }}
       cancel-in-progress: true
 
     defaults:
@@ -45,13 +46,10 @@
         run: docker run --name redis -d redis redis-server --save 60 1 --loglevel warning
 
       - name: Start Celery worker
-        run: celery -A poc_celery.celery_app worker --loglevel=DEBUG &
+        run: celery -A poc_celery.${{ matrix.poc_id }}.celery_app worker --loglevel=DEBUG &
 
-      - name: Run pytest for Collectors
-        run: pytest -vvv tests/test_tasks_collectors.py
-
-      - name: Run pytest for Async tasks
-        run: pytest -vvv tests/test_tasks_async.py
+      - name: Run pytest
+        run: pytest -vvv tests/${{ matrix.poc_id }}
 
       - name: Run linter
         run: |
diff --git a/.gitignore b/.gitignore
index 9754518..6f3e3ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ cython_debug/
 #.idea/
 
 /data
+*.db
diff --git a/README.md b/README.md
index b949aee..b098d69 100644
--- a/README.md
+++ b/README.md
@@ -165,8 +165,8 @@ bash scripts/setup.sh
 This command executes the script that:
 
 1. **Starts a Celery Worker**: Launches a Celery worker instance using
-   `poc_celery.celery_app` as the application module. This worker listens for
-   tasks dispatched to the queues and executes them as they arrive.
+   `poc_celery.poc1.celery_app` as the application module. This worker listens
+   for tasks dispatched to the queues and executes them as they arrive.
 
 2. **Launches Flower**: Initiates Flower on the default port (5555), allowing
    you to access a web-based user interface to monitor and manage the Celery
diff --git a/conda/dev.yaml b/conda/dev.yaml
index 045e1d0..5fcf08b 100644
--- a/conda/dev.yaml
+++ b/conda/dev.yaml
@@ -19,3 +19,4 @@ dependencies:
   - vulture
   - bandit
   - mccabe
+  - sqlalchemy
diff --git a/pyproject.toml b/pyproject.toml
index a40d80e..fc20e47 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,7 @@ exclude = [
 fix = true
 
 [tool.ruff.lint]
-ignore = ["PLR0913"]
+ignore = ["PLR0913", "RUF012"]
 select = [
   "E",   # pycodestyle
   "F",   # pyflakes
diff --git a/scripts/start_celery_and_flower.sh b/scripts/start_celery_and_flower.sh
index 816d7ba..52252fd 100755
--- a/scripts/start_celery_and_flower.sh
+++ b/scripts/start_celery_and_flower.sh
@@ -1,5 +1,7 @@
 #!/bin/bash
 
+POC_ID=${1:-poc1}
+
 # Fetch the Rabbitmq IP address by directly invoking the get_amqp_ip function
 AMQP_IP=$(python -c 'from poc_celery.get_container_ip import get_amqp_ip; print(get_amqp_ip())')
 
@@ -13,10 +15,10 @@ echo "Rabbitmq IP: $AMQP_IP"
 
 # Start the Celery worker
 echo "Starting Celery worker..."
-celery -A poc_celery.celery_app worker --loglevel=INFO &
+celery -A poc_celery.${POC_ID}.celery_app worker --loglevel=INFO &
 
 # Start Flower
 echo "Starting Flower with Rabbitmq at $AMQP_IP..."
-celery -A poc_celery.celery_app flower --broker=amqp://guest:guest@{AMQP_IP}:5672 &
+celery -A poc_celery.${POC_ID}.celery_app flower --broker=amqp://guest:guest@${AMQP_IP}:5672 &
 
 echo "Celery and Flower have been started."
diff --git a/src/poc_celery/__init__.py b/src/poc_celery/__init__.py
index f2d66a9..e69de29 100644
--- a/src/poc_celery/__init__.py
+++ b/src/poc_celery/__init__.py
@@ -1,3 +0,0 @@
-from poc_celery.celery_app import app as celery_app
-
-__all__ = ("celery_app",)
diff --git a/src/poc_celery/db.py b/src/poc_celery/db.py
new file mode 100644
index 0000000..db42b8e
--- /dev/null
+++ b/src/poc_celery/db.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+from sqlalchemy import Column, Integer, String, create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+# Base class for declarative class definitions
+Base = declarative_base()
+
+
+class SearchModel(Base):
+    __tablename__ = "search"
+    id = Column(Integer, primary_key=True)
+    query = Column(String, nullable=False)
+
+
+class ArticleModel(Base):
+    __tablename__ = "article"
+    id = Column(Integer, primary_key=True)
+    search_id = Column(Integer, nullable=False)
+    meta = Column(String, nullable=True)
+
+
+class SimpleORM:
+    context = {"session": None}
+
+    @classmethod
+    def create(cls, **kwargs) -> SimpleORM:
+        """
+        Create a new record of ``cls.model`` and commit it.
+
+        Parameters:
+            **kwargs: Column values for the new record.
+        """
+        # Create the record and commit it in the shared session
+        new_object = cls.model(**kwargs)
+        cls.context["session"].add(new_object)
+        cls.context["session"].commit()
+        return new_object
+
+    @classmethod
+    def filter(cls, **kwargs) -> list:
+        # Filtering data based on a condition
+        query = cls.context["session"].query(cls.model)
+
+        # Apply filters based on kwargs
+        for key, value in kwargs.items():
+            if not hasattr(cls.model, key):
+                print(f"Warning: '{key}' is not a valid attribute of Article")
+                continue
+
+            # Construct a filter using the 'like' operator if the value
+            # contains a wildcard character
+            if isinstance(value, str) and "%" in value:
+                query = query.filter(getattr(cls.model, key).like(value))
+            else:
+                query = query.filter(getattr(cls.model, key) == value)
+
+        return query.all()
+
+    @classmethod
+    def setup(cls, url: str = "sqlite:///example.db"):
+        """
+        Set up the database by creating tables and initializing the session.
+
+        Parameters:
+            url (str): The database URL.
+
+        Returns:
+            session (Session): A SQLAlchemy Session object.
+        """
+        engine = create_engine(
+            url, echo=False
+        )  # Set echo=False to turn off verbose logging
+        Base.metadata.create_all(engine)  # Create all tables
+        Session = sessionmaker(bind=engine)
+        cls.context["session"] = Session()
+        cls.reset()
+        return cls.context["session"]
+
+    @classmethod
+    def reset(cls):
+        """
+        Resets the database by dropping all tables and recreating them.
+ """ + # Get the engine from the current session + engine = cls.context["session"].get_bind() + # Drop all tables + Base.metadata.drop_all(engine) + # Create all tables + Base.metadata.create_all(engine) + print("Database has been reset.") + + +class Search(SimpleORM): + model = SearchModel + + +class Article(SimpleORM): + model = ArticleModel diff --git a/src/poc_celery/poc1/__init__.py b/src/poc_celery/poc1/__init__.py new file mode 100644 index 0000000..6bc74b2 --- /dev/null +++ b/src/poc_celery/poc1/__init__.py @@ -0,0 +1,3 @@ +from poc_celery.poc1.celery_app import app as celery_app + +__all__ = ("celery_app",) diff --git a/src/poc_celery/poc1/celery_app.py b/src/poc_celery/poc1/celery_app.py new file mode 100644 index 0000000..ddbc700 --- /dev/null +++ b/src/poc_celery/poc1/celery_app.py @@ -0,0 +1,23 @@ +from celery import Celery + +from poc_celery.get_container_ip import get_amqp_ip, get_redis_ip + +# Get the Rabbitmq container IP address +AMQP_IP = get_amqp_ip() +REDIS_IP = get_redis_ip() + +# Create a Celery instance with Rabbitmq as the broker and result backend +app = Celery( + "poc-celery", + broker=f"amqp://guest:guest@{AMQP_IP}:5672", + backend=f"redis://{REDIS_IP}:6379/0", + include=[ + "poc_celery.poc1.tasks_async", + "poc_celery.poc1.tasks_collectors", + ], +) + +# Set broker_connection_retry_on_startup to True to suppress the warning +app.conf.broker_connection_retry_on_startup = True + +app.autodiscover_tasks() diff --git a/src/poc_celery/tasks_async.py b/src/poc_celery/poc1/tasks_async.py similarity index 93% rename from src/poc_celery/tasks_async.py rename to src/poc_celery/poc1/tasks_async.py index 7526bdf..f2af836 100644 --- a/src/poc_celery/tasks_async.py +++ b/src/poc_celery/poc1/tasks_async.py @@ -1,6 +1,6 @@ from pathlib import Path -from poc_celery.celery_app import app +from poc_celery.poc1.celery_app import app # app = Celery('tasks', broker='your_broker_url', backend='your_backend_url') DATA_DIR = Path(__file__).parent.parent / "data" diff --git a/src/poc_celery/tasks_collectors.py b/src/poc_celery/poc1/tasks_collectors.py similarity index 99% rename from src/poc_celery/tasks_collectors.py rename to src/poc_celery/poc1/tasks_collectors.py index 78ad3c9..ea766f9 100644 --- a/src/poc_celery/tasks_collectors.py +++ b/src/poc_celery/poc1/tasks_collectors.py @@ -3,7 +3,7 @@ from celery import chord, group -from poc_celery.celery_app import app +from poc_celery.poc1.celery_app import app def generate_collector_request(topic: str) -> str: diff --git a/src/poc_celery/poc2/__init__.py b/src/poc_celery/poc2/__init__.py new file mode 100644 index 0000000..1203398 --- /dev/null +++ b/src/poc_celery/poc2/__init__.py @@ -0,0 +1,3 @@ +from poc_celery.poc2.celery_app import app as celery_app + +__all__ = ("celery_app",) diff --git a/src/poc_celery/celery_app.py b/src/poc_celery/poc2/celery_app.py similarity index 88% rename from src/poc_celery/celery_app.py rename to src/poc_celery/poc2/celery_app.py index a5b0067..13938e3 100644 --- a/src/poc_celery/celery_app.py +++ b/src/poc_celery/poc2/celery_app.py @@ -12,8 +12,7 @@ broker=f"amqp://guest:guest@{AMQP_IP}:5672", backend=f"redis://{REDIS_IP}:6379/0", include=[ - "poc_celery.tasks_async", - "poc_celery.tasks_collectors", + "poc_celery.poc2.tasks", ], ) diff --git a/src/poc_celery/poc2/tasks.py b/src/poc_celery/poc2/tasks.py new file mode 100644 index 0000000..3b38a36 --- /dev/null +++ b/src/poc_celery/poc2/tasks.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from celery import group, shared_task + +from 
poc_celery.db import Article, Search, SimpleORM + +SimpleORM.setup() + + +@shared_task +def search_task(query: str): + """ + Start the pipeline. + + Initial task that receives a user's request and triggers collector tasks. + """ + # with transaction.atomic(): + search_obj = Search.create(query=query) + search_id = search_obj.id + + collectors = [ + collector_1.s(search_id), + collector_2.s(search_id), + collector_3.s(search_id), + ] + callback = clean_up.s(search_id=search_id).set( + link_error=clean_up.s(search_id=search_id) + ) + group(collectors) | callback.delay() + + +@shared_task(bind=True, max_retries=0) +def collector_1(self, search_id: int): + """Collect data for collector 1.""" + return execute_collector_tasks(search_id, "collector_1") + + +@shared_task(bind=True, max_retries=0) +def collector_2(self, search_id: int): + """Collect data for collector 2.""" + return execute_collector_tasks(search_id, "collector_2") + + +@shared_task(bind=True, max_retries=0) +def collector_3(self, search_id: int): + """Collect data for collector 3.""" + return execute_collector_tasks(search_id, "collector_3") + + +def execute_collector_tasks(search_id: int, collector_name: str): + """ + Execute collector tasks. + + Helper function to execute get_list and get_article tasks for a collector. + """ + # Assuming `get_list` generates a list of article IDs for simplicity + article_ids = get_list(search_id, collector_name) + for article_id in article_ids: + get_article.delay(search_id, article_id, collector_name) + return {"status": "Completed", "collector": collector_name} + + +@shared_task +def get_list(search_id: int, collector_name: str): + """Simulated task to get a list of articles.""" + # Simulate getting a list of article IDs + return [1, 2, 3] # Example article IDs + + +@shared_task +def get_article(search_id: int, article_id: int, collector_name: str): + """Task to fetch and save article metadata.""" + # Simulate fetching article metadata + metadata = f"Metadata for article {article_id} from {collector_name}" + # with transaction.atomic(): + Article.objects.create(search_id=search_id, meta=metadata) + + +@shared_task +def clean_up(search_id: int): + """ + Clean up temporary storage. + + Cleanup task to be triggered when all articles from all collectors + for a specific search are done. + """ + # Implement cleanup logic here, e.g., removing duplicate articles + pass diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8c9c486 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +import pytest + + +@pytest.fixture(scope="module") +def celery_config(): + """ + Provide Celery app configuration for testing. + + This fixture is responsible for setting up the Celery app with a specific + configuration suitable for test runs. It defines the broker and result backend + to use Rabbitmq and sets the task execution mode to always eager, which means + tasks will be executed locally and synchronously. + + Yields + ------ + dict + A dictionary containing configuration settings for the Celery application. + """ + return { + "broker_url": "amqp://guest:guest@rabbitmq3:5672", + "result_backend": "redis://localhost:6379/0", + "task_always_eager": True, + } + + +@pytest.fixture(scope="module") +def celery_enable_logging(): + """ + Activate logging for Celery tasks during testing. + + This fixture ensures that Celery task logs are visible during test execution, + aiding in debugging and verifying task behavior. 
+
+    Returns
+    -------
+    bool
+        True to enable Celery task logging, False otherwise.
+    """
+    return True
diff --git a/tests/poc1/__init__.py b/tests/poc1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_tasks_async.py b/tests/poc1/test_tasks_async.py
similarity index 93%
rename from tests/test_tasks_async.py
rename to tests/poc1/test_tasks_async.py
index 91144f4..7972967 100644
--- a/tests/test_tasks_async.py
+++ b/tests/poc1/test_tasks_async.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from poc_celery.tasks_async import DATA_DIR, clean_data, create_project
+from poc_celery.poc1.tasks_async import DATA_DIR, clean_data, create_project
 
 
 @pytest.fixture
@@ -56,7 +56,7 @@ def test_create_project(mock_file_io):
 async def test_create_project_stress(mock_file_io):
     file_path = str(DATA_DIR / "collectors.txt")
 
-    num_calls = 100000
+    num_calls = 10
 
     calls = [
         [1, 1, 3],
diff --git a/tests/test_tasks_collectors.py b/tests/poc1/test_tasks_collectors.py
similarity index 60%
rename from tests/test_tasks_collectors.py
rename to tests/poc1/test_tasks_collectors.py
index a877161..313ceda 100644
--- a/tests/test_tasks_collectors.py
+++ b/tests/poc1/test_tasks_collectors.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from poc_celery.tasks_collectors import (
+from poc_celery.poc1.tasks_collectors import (
     collector_request,
     generate_collector_request,
 )
@@ -13,44 +13,6 @@
 logging.basicConfig(level=logging.INFO)
 
 
-@pytest.fixture(scope="module")
-def celery_config():
-    """
-    Provide Celery app configuration for testing.
-
-    This fixture is responsible for setting up the Celery app with a specific
-    configuration suitable for test runs. It defines the broker and result backend
-    to use Rabbitmq and sets the task execution mode to always eager, which means
-    tasks will be executed locally and synchronously.
-
-    Yields
-    ------
-    dict
-        A dictionary containing configuration settings for the Celery application.
-    """
-    return {
-        "broker_url": "amqp://guest:guest@rabbitmq3:5672",
-        "result_backend": "redis://localhost:6379/0",
-        "task_always_eager": True,
-    }
-
-
-@pytest.fixture(scope="module")
-def celery_enable_logging():
-    """
-    Activate logging for Celery tasks during testing.
-
-    This fixture ensures that Celery task logs are visible during test execution,
-    aiding in debugging and verifying task behavior.
-
-    Returns
-    -------
-    bool
-        True to enable Celery task logging, False otherwise.
-    """
-    return True
-
-
 def test_generate_collector_request():
     """
     Validate that `generate_collector_request` produces a valid UUID string.
@@ -67,8 +29,8 @@
     assert isinstance(request_id, str), "The request_id should be a string."
-@patch("poc_celery.tasks_collectors.collector_gathering.s") -@patch("poc_celery.tasks_collectors.group") +@patch("poc_celery.poc1.tasks_collectors.collector_gathering.s") +@patch("poc_celery.poc1.tasks_collectors.group") def test_collector_request_triggers_sub_collectors( mock_group, mock_collector_gathering_s ): diff --git a/tests/poc2/__init__.py b/tests/poc2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/poc2/test_tasks_collectors.py b/tests/poc2/test_tasks_collectors.py new file mode 100644 index 0000000..39e38b1 --- /dev/null +++ b/tests/poc2/test_tasks_collectors.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import patch, MagicMock +from poc_celery.poc2.tasks import ( + search_task, + get_article, + collector_1, +) # Adjust the import path + + +class TestCeleryTasks(unittest.TestCase): + @patch("poc_celery.poc2.tasks.Search.create") + @patch("poc_celery.poc2.tasks.group") + @patch("poc_celery.poc2.tasks.collector_1.s") + @patch("poc_celery.poc2.tasks.collector_2.s") + @patch("poc_celery.poc2.tasks.collector_3.s") + @patch("poc_celery.poc2.tasks.clean_up.s") + def test_search_task( + self, + mock_cleanup, + mock_collector_3, + mock_collector_2, + mock_collector_1, + mock_group, + mock_create, + ): + # Mocking the Search.create() to return an object with an id attribute + mock_search_obj = MagicMock() + mock_search_obj.id = 1 + mock_create.return_value = mock_search_obj + + # Testing search_task + search_task("test query") + + # Asserting Search.create was called with the right query + mock_create.assert_called_once_with(query="test query") + + # Assert that collector tasks and cleanup tasks are setup correctly + mock_collector_1.assert_called_once_with(1) + mock_collector_2.assert_called_once_with(1) + mock_collector_3.assert_called_once_with(1) + mock_cleanup.assert_called_once_with(search_id=1) + mock_group.assert_called_once() + + @patch("poc_celery.poc2.tasks.Article.objects.create") + def test_get_article(self, mock_create): + # Test get_article task + get_article(1, 101, "collector_1") + + # Asserting Article.create was called with correct parameters + mock_create.assert_called_once_with( + search_id=1, meta="Metadata for article 101 from collector_1" + ) + + +if __name__ == "__main__": + unittest.main()