diff --git a/python-sdk/indexify/executor/agent.py b/python-sdk/indexify/executor/agent.py index 753a02d8f..6f16308a0 100644 --- a/python-sdk/indexify/executor/agent.py +++ b/python-sdk/indexify/executor/agent.py @@ -1,16 +1,10 @@ import asyncio -import json -from importlib.metadata import version from pathlib import Path from typing import Dict, List, Optional import structlog -from httpx_sse import aconnect_sse -from indexify.common_util import get_httpx_client - -from .api_objects import ExecutorMetadata, Task -from .downloader import DownloadedInputs, Downloader +from .downloader import Downloader from .executor_tasks import DownloadGraphTask, DownloadInputsTask, RunTask from .function_executor.process_function_executor_factory import ( ProcessFunctionExecutorFactory, @@ -20,7 +14,7 @@ FunctionWorkerInput, FunctionWorkerOutput, ) -from .runtime_probes import ProbeInfo, RuntimeProbes +from .task_fetcher import TaskFetcher from .task_reporter import TaskReporter from .task_store import CompletedTask, TaskStore @@ -38,19 +32,13 @@ def __init__( name_alias: Optional[str] = None, image_hash: Optional[str] = None, ): - self.name_alias = name_alias - self.image_hash = image_hash self._config_path = config_path - self._probe = RuntimeProbes() - + protocol: str = "http" if config_path: logger.info("running the extractor with TLS enabled") - self._protocol = "https" - else: - self._protocol = "http" + protocol = "https" self._task_store: TaskStore = TaskStore() - self._executor_id = executor_id self._function_worker = FunctionWorker( function_executor_factory=ProcessFunctionExecutorFactory( indexify_server_address=server_addr, @@ -60,14 +48,22 @@ def __init__( ) self._has_registered = False self._server_addr = server_addr - self._base_url = f"{self._protocol}://{self._server_addr}" + self._base_url = f"{protocol}://{self._server_addr}" self._code_path = code_path self._downloader = Downloader( code_path=code_path, base_url=self._base_url, config_path=config_path ) + 
async def _main_loop(self):
    """Pull tasks from the task fetcher and hand each one to the task store.

    The loop keeps running while ``self._should_run`` is truthy. Each pass
    opens a new task stream via ``self._task_fetcher.run()`` and forwards
    every yielded task to ``self._task_store`` as a single-element batch.
    If the stream raises, the error is logged and the loop pauses for
    5 seconds before the next attempt; if the stream simply ends, the next
    pass starts right away.
    """
    self._should_run = True
    while self._should_run:
        try:
            task_stream = self._task_fetcher.run()
            async for fetched_task in task_stream:
                self._task_store.add_tasks([fetched_task])
        except Exception as fetch_error:
            logger.error(
                "failed fetching tasks, retrying in 5 seconds",
                exc_info=fetch_error,
            )
            # Back off before the while-condition is evaluated again.
            await asyncio.sleep(5)
class TaskFetcher:
    """Registers with Indexify server and fetches tasks from it.

    The executor metadata sent on registration is built once, in the
    constructor, from the runtime probe results and the optional overrides.
    """

    def __init__(
        self,
        protocol: str,
        indexify_server_addr: str,
        executor_id: str,
        name_alias: Optional[str] = None,
        image_hash: Optional[str] = None,
        config_path: Optional[str] = None,
    ):
        """Prepare the registration payload for this executor.

        Args:
            protocol: URL scheme placed in front of the server address.
            indexify_server_addr: Address of the Indexify server.
            executor_id: Identifier reported to the server for this executor.
            name_alias: Image name to report instead of the probed one.
            image_hash: Image hash to report instead of the probed one.
            config_path: Passed through to ``get_httpx_client``.
        """
        self._protocol: str = protocol
        self._indexify_server_addr: str = indexify_server_addr
        self.config_path = config_path
        self._logger = structlog.get_logger(module=__name__)

        # Explicit overrides win over whatever the runtime probe detected.
        probe_info: ProbeInfo = RuntimeProbes().probe()
        self._executor_metadata: ExecutorMetadata = ExecutorMetadata(
            id=executor_id,
            executor_version=version("indexify"),
            addr="",
            image_name=probe_info.image_name if name_alias is None else name_alias,
            image_hash=(probe_info.image_hash if image_hash is None else image_hash),
            labels=probe_info.labels,
        )

    async def run(self) -> AsyncGenerator[Task, None]:
        """Fetches tasks that Indexify server assigned to the Executor.

        Registers by POSTing the executor metadata to the server's tasks
        endpoint, then yields one ``Task`` per entry of every server-sent
        event received on that connection.

        Raises:
            Exception: if the server answers the registration request with an
                error status. Any other error from the HTTP client, JSON
                decoding or task validation propagates unchanged.
        """
        url = f"{self._protocol}://{self._indexify_server_addr}/internal/executors/{self._executor_metadata.id}/tasks"

        self._logger.info(
            "registering_executor",
            executor_id=self._executor_metadata.id,
            url=url,
            executor_version=self._executor_metadata.executor_version,
            image=self._executor_metadata.image_name,
            image_hash=self._executor_metadata.image_hash,
            labels=self._executor_metadata.labels,
        )
        async with get_httpx_client(
            config_path=self.config_path, make_async=True
        ) as client:
            async with aconnect_sse(
                client,
                "POST",
                url,
                json=self._executor_metadata.model_dump(),
                headers={"Content-Type": "application/json"},
            ) as event_source:
                try:
                    event_source.response.raise_for_status()
                except httpx.HTTPStatusError as e:
                    # The body of a streamed response must be read before
                    # ``.text`` can be used in the error message.
                    await event_source.response.aread()
                    raise Exception(
                        "Failed to register at server. "
                        f"Response code: {event_source.response.status_code}. "
                        f"Response text: '{event_source.response.text}'."
                    ) from e

                self._logger.info(
                    "executor_registered", executor_id=self._executor_metadata.id
                )
                async for sse in event_source.aiter_sse():
                    # Each event carries a JSON array of task dictionaries.
                    task_dicts = json.loads(sse.data)
                    for task_dict in task_dicts:
                        yield Task.model_validate(task_dict, strict=False)