Skip to content

Commit

Permalink
Move everything related to fetching tasks from server into its own class
Browse files Browse the repository at this point in the history
This is the first PR where I simplify agent.py by splitting it into
components with dedicated responsibilities and interfaces.
There's going to be more.

Testing:

make check
make test
  • Loading branch information
eabatalov committed Dec 16, 2024
1 parent d593058 commit 5eae3fa
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 89 deletions.
114 changes: 27 additions & 87 deletions python-sdk/indexify/executor/agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import asyncio
import json
from importlib.metadata import version
from pathlib import Path
from typing import Dict, List, Optional

import structlog
from httpx_sse import aconnect_sse

from indexify.common_util import get_httpx_client

from .api_objects import ExecutorMetadata, Task
from .downloader import DownloadedInputs, Downloader
from .downloader import Downloader
from .executor_tasks import DownloadGraphTask, DownloadInputsTask, RunTask
from .function_executor.process_function_executor_factory import (
ProcessFunctionExecutorFactory,
Expand All @@ -20,7 +14,7 @@
FunctionWorkerInput,
FunctionWorkerOutput,
)
from .runtime_probes import ProbeInfo, RuntimeProbes
from .task_fetcher import TaskFetcher
from .task_reporter import TaskReporter
from .task_store import CompletedTask, TaskStore

Expand All @@ -38,19 +32,13 @@ def __init__(
name_alias: Optional[str] = None,
image_hash: Optional[str] = None,
):
self.name_alias = name_alias
self.image_hash = image_hash
self._config_path = config_path
self._probe = RuntimeProbes()

protocol: str = "http"
if config_path:
logger.info("running the extractor with TLS enabled")
self._protocol = "https"
else:
self._protocol = "http"
protocol = "https"

self._task_store: TaskStore = TaskStore()
self._executor_id = executor_id
self._function_worker = FunctionWorker(
function_executor_factory=ProcessFunctionExecutorFactory(
indexify_server_address=server_addr,
Expand All @@ -60,14 +48,22 @@ def __init__(
)
self._has_registered = False
self._server_addr = server_addr
self._base_url = f"{self._protocol}://{self._server_addr}"
self._base_url = f"{protocol}://{self._server_addr}"
self._code_path = code_path
self._downloader = Downloader(
code_path=code_path, base_url=self._base_url, config_path=config_path
)
self._task_fetcher = TaskFetcher(
protocol=protocol,
indexify_server_addr=self._server_addr,
executor_id=executor_id,
name_alias=name_alias,
image_hash=image_hash,
config_path=config_path,
)
self._task_reporter = TaskReporter(
base_url=self._base_url,
executor_id=self._executor_id,
executor_id=executor_id,
config_path=self._config_path,
)

Expand Down Expand Up @@ -230,6 +226,18 @@ async def task_launcher(self):
self._task_store.complete(outcome=completed_task)
continue

async def _main_loop(self):
"""Fetches incoming tasks from the server and starts their processing."""
self._should_run = True
while self._should_run:
try:
async for task in self._task_fetcher.run():
self._task_store.add_tasks([task])
except Exception as e:
logger.error("failed fetching tasks, retrying in 5 seconds", exc_info=e)
await asyncio.sleep(5)
continue

async def run(self):
import signal

Expand All @@ -241,75 +249,7 @@ async def run(self):
)
asyncio.create_task(self.task_launcher())
asyncio.create_task(self.task_completion_reporter())
self._should_run = True
while self._should_run:
url = f"{self._protocol}://{self._server_addr}/internal/executors/{self._executor_id}/tasks"
runtime_probe: ProbeInfo = self._probe.probe()

executor_version = version("indexify")

image_name = (
self.name_alias
if self.name_alias is not None
else runtime_probe.image_name
)

image_hash: str = (
self.image_hash
if self.image_hash is not None
else runtime_probe.image_hash
)

data = ExecutorMetadata(
id=self._executor_id,
executor_version=executor_version,
addr="",
image_name=image_name,
image_hash=image_hash,
labels=runtime_probe.labels,
).model_dump()
logger.info(
"registering_executor",
executor_id=self._executor_id,
url=url,
executor_version=executor_version,
image=image_name,
image_hash=image_hash,
labels=runtime_probe.labels,
)
try:
async with get_httpx_client(self._config_path, True) as client:
async with aconnect_sse(
client,
"POST",
url,
json=data,
headers={"Content-Type": "application/json"},
) as event_source:
if not event_source.response.is_success:
resp = await event_source.response.aread()
logger.error(
f"failed to register",
resp=str(resp),
status_code=event_source.response.status_code,
)
await asyncio.sleep(5)
continue
logger.info(
"executor_registered", executor_id=self._executor_id
)
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
tasks = []
for task_dict in data:
tasks.append(
Task.model_validate(task_dict, strict=False)
)
self._task_store.add_tasks(tasks)
except Exception as e:
logger.error("failed to register", exc_info=e)
await asyncio.sleep(5)
continue
await self._main_loop()

async def _shutdown(self, loop):
logger.info("shutting_down")
Expand Down
80 changes: 80 additions & 0 deletions python-sdk/indexify/executor/task_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import json
from importlib.metadata import version
from typing import AsyncGenerator, Optional

import httpx
import structlog
from httpx_sse import aconnect_sse

from indexify.common_util import get_httpx_client

from .api_objects import ExecutorMetadata, Task
from .runtime_probes import ProbeInfo, RuntimeProbes


class TaskFetcher:
"""Registers with Indexify server and fetches tasks from it."""

def __init__(
self,
protocol: str,
indexify_server_addr: str,
executor_id: str,
name_alias: Optional[str] = None,
image_hash: Optional[int] = None,
config_path: Optional[str] = None,
):
self._protocol: str = protocol
self._indexify_server_addr: str = indexify_server_addr
self.config_path = config_path
self._logger = structlog.get_logger(module=__name__)

probe_info: ProbeInfo = RuntimeProbes().probe()
self._executor_metadata: ExecutorMetadata = ExecutorMetadata(
id=executor_id,
executor_version=version("indexify"),
addr="",
image_name=probe_info.image_name if name_alias is None else name_alias,
image_hash=(probe_info.image_hash if image_hash is None else image_hash),
labels=probe_info.labels,
)

async def run(self) -> AsyncGenerator[Task, None]:
"""Fetches tasks that Indexify server assigned to the Executor.
Raises an exception if error occurred."""
url = f"{self._protocol}://{self._indexify_server_addr}/internal/executors/{self._executor_metadata.id}/tasks"

self._logger.info(
"registering_executor",
executor_id=self._executor_metadata.id,
url=url,
executor_version=self._executor_metadata.executor_version,
)
async with get_httpx_client(
config_path=self.config_path, make_async=True
) as client:
async with aconnect_sse(
client,
"POST",
url,
json=self._executor_metadata.model_dump(),
headers={"Content-Type": "application/json"},
) as event_source:
try:
event_source.response.raise_for_status()
except Exception as e:
await event_source.response.aread()
raise Exception(
"Failed to register at server. "
f"Response code: {event_source.response.status_code}. "
f"Response text: '{event_source.response.text}'."
) from e

self._logger.info(
"executor_registered", executor_id=self._executor_metadata.id
)
async for sse in event_source.aiter_sse():
task_dicts = json.loads(sse.data)
for task_dict in task_dicts:
yield Task.model_validate(task_dict, strict=False)
4 changes: 2 additions & 2 deletions python-sdk/tests/test_extractor_agent_behaviour.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_tls_configuration(self, mock_async_client, mock_sync_client, mock_file)

# Verify TLS config in Agent
self.assertEqual(agent._server_addr, service_url)
self.assertEqual(agent._protocol, "https")
self.assertTrue(agent._base_url.startswith("https://"))

def test_no_tls_configuration(self):
# Create an instance of ExtractorAgent without TLS
Expand All @@ -65,7 +65,7 @@ def test_no_tls_configuration(self):
)

# Verify the protocol is set to "http"
self.assertEqual(agent._protocol, "http")
self.assertTrue(agent._base_url.startswith("http://"))


if __name__ == "__main__":
Expand Down

0 comments on commit 5eae3fa

Please sign in to comment.