diff --git a/python-sdk/indexify/cli.py b/python-sdk/indexify/cli.py index 36db65be2..3f9b70b41 100644 --- a/python-sdk/indexify/cli.py +++ b/python-sdk/indexify/cli.py @@ -8,13 +8,11 @@ import signal import subprocess import sys -import tempfile import threading import time from importlib.metadata import version from typing import Annotated, List, Optional -import httpx import nanoid import structlog import typer @@ -24,12 +22,14 @@ from rich.theme import Theme from indexify.executor.executor import Executor +from indexify.executor.function_executor.server.subprocess_function_executor_server_factory import ( + SubprocessFunctionExecutorServerFactory, +) from indexify.function_executor.function_executor_service import ( FunctionExecutorService, ) from indexify.function_executor.server import Server as FunctionExecutorServer from indexify.functions_sdk.image import Build, GetDefaultPythonImage, Image -from indexify.http_client import IndexifyClient logger = structlog.get_logger(module=__name__) @@ -250,7 +250,9 @@ def executor( code_path=executor_cache, name_alias=name_alias, image_hash=image_hash, - development_mode=dev, + function_executor_server_factory=SubprocessFunctionExecutorServerFactory( + development_mode=dev + ), ) try: asyncio.get_event_loop().run_until_complete(executor.run()) diff --git a/python-sdk/indexify/executor/downloader.py b/python-sdk/indexify/executor/downloader.py index 7feb8a0ff..43ff94e96 100644 --- a/python-sdk/indexify/executor/downloader.py +++ b/python-sdk/indexify/executor/downloader.py @@ -13,12 +13,6 @@ from .api_objects import Task -class DownloadedInputs: - def __init__(self, input: SerializedObject, init_value: Optional[SerializedObject]): - self.input = input - self.init_value = init_value - - class Downloader: def __init__( self, code_path: str, base_url: str, config_path: Optional[str] = None @@ -78,22 +72,22 @@ def _write_cached_graph( # are atomic operations at filesystem level. os.replace(tmp_path, path) - async def download_inputs(self, task: Task) -> DownloadedInputs: + async def download_input(self, task: Task) -> SerializedObject: logger = self._task_logger(task) - input: SerializedObject first_function_in_graph = task.invocation_id == task.input_key.split("|")[-1] if first_function_in_graph: # The first function in Graph gets its input from graph invocation payload. 
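The branch above relies on the input_key layout: the last "|"-separated component of a task's input_key is the id of the object holding the input, which for the first function is the graph invocation payload itself. A minimal sketch of that predicate, assuming this key layout (helper name invented):

def _is_first_function_in_graph(task: Task) -> bool:
    # For the first function the input is the invocation payload, so the last
    # "|"-separated component of input_key equals the invocation id.
    return task.invocation_id == task.input_key.split("|")[-1]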
-            input = await self._fetch_graph_invocation_payload(task, logger)
+            return await self._fetch_graph_invocation_payload(task, logger)
         else:
-            input = await self._fetch_function_input(task, logger)
+            return await self._fetch_function_input(task, logger)
 
-        init_value: Optional[SerializedObject] = None
-        if task.reducer_output_id is not None:
-            init_value = await self._fetch_function_init_value(task, logger)
+    async def download_init_value(self, task: Task) -> Optional[SerializedObject]:
+        if task.reducer_output_id is None:
+            return None
 
-        return DownloadedInputs(input=input, init_value=init_value)
+        logger = self._task_logger(task)
+        return await self._fetch_function_init_value(task, logger)
 
     def _task_logger(self, task: Task) -> Any:
         return structlog.get_logger(
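With DownloadedInputs gone, the three fetches are separate awaitables. A usage sketch of the new Downloader surface, assuming the executor's wiring (the cache path and base URL here are invented):

downloader = Downloader(code_path="/tmp/indexify-cache", base_url="http://localhost:8900")

async def fetch_task_resources(task: Task):
    graph = await downloader.download_graph(task)
    input = await downloader.download_input(task)
    # None unless task.reducer_output_id is set.
    init_value = await downloader.download_init_value(task)
    return graph, input, init_value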
diff --git a/python-sdk/indexify/executor/executor.py b/python-sdk/indexify/executor/executor.py
index 7e4a935a5..053136843 100644
--- a/python-sdk/indexify/executor/executor.py
+++ b/python-sdk/indexify/executor/executor.py
@@ -10,17 +10,13 @@
 )
 
 from .api_objects import Task
-from .downloader import DownloadedInputs, Downloader
-from .function_executor.process_function_executor_factory import (
-    ProcessFunctionExecutorFactory,
-)
-from .function_worker import (
-    FunctionWorker,
-    FunctionWorkerInput,
-    FunctionWorkerOutput,
+from .downloader import Downloader
+from .function_executor.server.function_executor_server_factory import (
+    FunctionExecutorServerFactory,
 )
 from .task_fetcher import TaskFetcher
 from .task_reporter import TaskReporter
+from .task_runner import TaskInput, TaskOutput, TaskRunner
 
 
 class Executor:
@@ -28,8 +24,8 @@ def __init__(
         self,
         executor_id: str,
         code_path: Path,
+        function_executor_server_factory: FunctionExecutorServerFactory,
         server_addr: str = "localhost:8900",
-        development_mode: bool = False,
         config_path: Optional[str] = None,
         name_alias: Optional[str] = None,
         image_hash: Optional[str] = None,
@@ -45,10 +41,8 @@ def __init__(
         self._server_addr = server_addr
         self._base_url = f"{protocol}://{self._server_addr}"
         self._code_path = code_path
-        self._function_worker = FunctionWorker(
-            function_executor_factory=ProcessFunctionExecutorFactory(
-                development_mode=development_mode,
-            ),
+        self._task_runner = TaskRunner(
+            function_executor_server_factory=function_executor_server_factory,
             base_url=self._base_url,
             config_path=config_path,
         )
@@ -92,39 +86,39 @@ async def _run_task(self, task: Task) -> None:
         Doesn't raise any Exceptions.
         All errors are reported to the server."""
         logger = self._task_logger(task)
-        output: Optional[FunctionWorkerOutput] = None
+        output: Optional[TaskOutput] = None
 
         try:
             graph: SerializedObject = await self._downloader.download_graph(task)
-            input: DownloadedInputs = await self._downloader.download_inputs(task)
-            output = await self._function_worker.run(
-                input=FunctionWorkerInput(
+            input: SerializedObject = await self._downloader.download_input(task)
+            init_value: Optional[SerializedObject] = (
+                await self._downloader.download_init_value(task)
+            )
+            logger.info("task_execution_started")
+            output: TaskOutput = await self._task_runner.run(
+                TaskInput(
                     task=task,
                     graph=graph,
-                    function_input=input,
-                )
+                    input=input,
+                    init_value=init_value,
+                ),
+                logger=logger,
             )
             logger.info("task_execution_finished", success=output.success)
         except Exception as e:
-            logger.error("failed running the task", exc_info=e)
+            output = TaskOutput.internal_error(task)
+            logger.error("task_execution_failed", exc_info=e)
 
-        await self._report_task_outcome(task=task, output=output, logger=logger)
+        await self._report_task_outcome(output=output, logger=logger)
 
-    async def _report_task_outcome(
-        self, task: Task, output: Optional[FunctionWorkerOutput], logger: Any
-    ) -> None:
-        """Reports the task with the given output to the server.
-
-        None output means that the task execution didn't finish due to an internal error.
-        Doesn't raise any exceptions."""
+    async def _report_task_outcome(self, output: TaskOutput, logger: Any) -> None:
+        """Reports the task with the given output to the server."""
         reporting_retries: int = 0
 
         while True:
             logger = logger.bind(retries=reporting_retries)
             try:
-                await self._task_reporter.report(
-                    task=task, output=output, logger=logger
-                )
+                await self._task_reporter.report(output=output, logger=logger)
                 break
             except Exception as e:
                 logger.error(
@@ -137,7 +131,7 @@ async def _shutdown(self, loop):
         self._logger.info("shutting_down")
         self._should_run = False
-        await self._function_worker.shutdown()
+        await self._task_runner.shutdown()
 
         for task in asyncio.all_tasks(loop):
             task.cancel()
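Executor no longer builds its own factory; callers inject one. A sketch of the new wiring, mirroring the cli.py change above (the executor id and cache path are invented):

import asyncio
from pathlib import Path

from indexify.executor.executor import Executor
from indexify.executor.function_executor.server.subprocess_function_executor_server_factory import (
    SubprocessFunctionExecutorServerFactory,
)

executor = Executor(
    executor_id="executor-1",               # invented id
    code_path=Path("/tmp/indexify-cache"),  # invented cache dir
    function_executor_server_factory=SubprocessFunctionExecutorServerFactory(
        development_mode=True
    ),
)
asyncio.get_event_loop().run_until_complete(executor.run())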
diff --git a/python-sdk/indexify/executor/function_executor/function_executor.py b/python-sdk/indexify/executor/function_executor/function_executor.py
index ea9c764ee..c692bf449 100644
--- a/python-sdk/indexify/executor/function_executor/function_executor.py
+++ b/python-sdk/indexify/executor/function_executor/function_executor.py
@@ -1,32 +1,133 @@
+import asyncio
 from typing import Any, Optional
 
 import grpc
 
-# Timeout for Function Executor startup in seconds.
-# The timeout is counted from the moment when the Function Executor environment
-# is fully prepared and the Function Executor gets started.
-FUNCTION_EXECUTOR_READY_TIMEOUT_SEC = 5
+from indexify.common_util import get_httpx_client
+from indexify.function_executor.proto.function_executor_pb2 import (
+    InitializeRequest,
+    InitializeResponse,
+)
+from indexify.function_executor.proto.function_executor_pb2_grpc import (
+    FunctionExecutorStub,
+)
+
+from .invocation_state_client import InvocationStateClient
+from .server.function_executor_server import (
+    FUNCTION_EXECUTOR_SERVER_READY_TIMEOUT_SEC,
+    FunctionExecutorServer,
+)
+from .server.function_executor_server_factory import (
+    FunctionExecutorServerConfiguration,
+    FunctionExecutorServerFactory,
+)
 
 
 class FunctionExecutor:
-    """Abstract interface for a FunctionExecutor.
+    """Executor side class supporting a running FunctionExecutorServer.
+
+    FunctionExecutor's primary responsibility is creation and initialization
+    of all resources associated with a particular Function Executor Server,
+    including the Server itself. FunctionExecutor owns all these resources
+    and provides other Executor components with access to them.
 
-    FunctionExecutor is a class that executes tasks for a particular function.
-    FunctionExecutor implements the gRPC server that listens for incoming tasks.
+    Addition of any business logic besides resource management is discouraged.
+    Please add such logic to other classes managed by this class.
     """
 
-    async def channel(self) -> grpc.aio.Channel:
-        """Returns a async gRPC channel to the Function Executor.
+    def __init__(self, server_factory: FunctionExecutorServerFactory, logger: Any):
+        self._server_factory: FunctionExecutorServerFactory = server_factory
+        self._logger = logger.bind(module=__name__)
+        self._server: Optional[FunctionExecutorServer] = None
+        self._channel: Optional[grpc.aio.Channel] = None
+        self._invocation_state_client: Optional[InvocationStateClient] = None
+        self._initialized = False
+
+    async def initialize(
+        self,
+        config: FunctionExecutorServerConfiguration,
+        initialize_request: InitializeRequest,
+        base_url: str,
+        config_path: Optional[str],
+    ):
+        """Creates and initializes a FunctionExecutorServer and all resources associated with it."""
+        try:
+            self._server = await self._server_factory.create(
+                config=config, logger=self._logger
+            )
+            self._channel = await self._server.create_channel(self._logger)
+            await _channel_ready(self._channel)
+
+            stub: FunctionExecutorStub = FunctionExecutorStub(self._channel)
+            await _initialize_server(stub, initialize_request)
+
+            self._invocation_state_client = InvocationStateClient(
+                stub=stub,
+                base_url=base_url,
+                http_client=get_httpx_client(config_path=config_path, make_async=True),
+                graph=initialize_request.graph_name,
+                namespace=initialize_request.namespace,
+                logger=self._logger,
+            )
+            await self._invocation_state_client.start()
+
+            self._initialized = True
+        except Exception:
+            await self.destroy()
+            raise
+
+    def channel(self) -> grpc.aio.Channel:
+        self._check_initialized()
+        return self._channel
+
+    def invocation_state_client(self) -> InvocationStateClient:
+        self._check_initialized()
+        return self._invocation_state_client
+
+    async def destroy(self):
+        """Destroys all resources owned by this FunctionExecutor.
+
+        Never raises any exceptions but logs them."""
+        try:
+            if self._invocation_state_client is not None:
+                await self._invocation_state_client.destroy()
+                self._invocation_state_client = None
+        except Exception as e:
+            self._logger.error(
+                "failed to destroy FunctionExecutor invocation state client", exc_info=e
+            )
+
+        try:
+            if self._channel is not None:
+                await self._channel.close()
+                self._channel = None
+        except Exception as e:
+            self._logger.error(
+                "failed to close FunctionExecutorServer channel", exc_info=e
+            )
+
+        try:
+            if self._server is not None:
+                await self._server_factory.destroy(self._server, self._logger)
+                self._server = None
+        except Exception as e:
+            self._logger.error("failed to destroy FunctionExecutorServer", exc_info=e)
+
+    def _check_initialized(self):
+        if not self._initialized:
+            raise RuntimeError("FunctionExecutor is not initialized")
+
 
-        The channel is in ready state and can be used for all gRPC communication with the Function Executor
-        and can be shared among coroutines running in the same event loop in the same thread. Users should
-        not close the channel as it's reused for all requests.
- Raises Exception if an error occurred.""" - raise NotImplementedError +async def _channel_ready(channel: grpc.aio.Channel): + await asyncio.wait_for( + channel.channel_ready(), + timeout=FUNCTION_EXECUTOR_SERVER_READY_TIMEOUT_SEC, + ) - def state(self) -> Optional[Any]: - """Returns optional state object. - The state object can be used to associate any data with the Function Executor. - """ - raise NotImplementedError +async def _initialize_server( + stub: FunctionExecutorStub, initialize_request: InitializeRequest +): + initialize_response: InitializeResponse = await stub.initialize(initialize_request) + if not initialize_response.success: + raise Exception("initialize RPC failed at function executor server") diff --git a/python-sdk/indexify/executor/function_executor/function_executor_factory.py b/python-sdk/indexify/executor/function_executor/function_executor_factory.py deleted file mode 100644 index 9d54500be..000000000 --- a/python-sdk/indexify/executor/function_executor/function_executor_factory.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Any, Optional - -from .function_executor import FunctionExecutor - - -class FunctionExecutorFactory: - """Abstract class for creating function executors.""" - - async def create( - self, logger: Any, state: Optional[Any] = None - ) -> FunctionExecutor: - """Creates a new FunctionExecutor. - - Args: - logger: logger to be used during the function. - state: state to be stored in the FunctionExecutor.""" - raise NotImplementedError() - - async def destroy(self, executor: FunctionExecutor, logger: Any) -> None: - """Destroys the FunctionExecutor and release all its resources. - - Args: - logger: logger to be used during the function. - FunctionExecutor and customer code running inside of it are not notified about the destruction. - Never raises any Exceptions.""" - raise NotImplementedError diff --git a/python-sdk/indexify/executor/function_executor/function_executor_map.py b/python-sdk/indexify/executor/function_executor/function_executor_map.py deleted file mode 100644 index 5d7eed77a..000000000 --- a/python-sdk/indexify/executor/function_executor/function_executor_map.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -from typing import Any, Dict, Optional - -import grpc - -from indexify.common_util import get_httpx_client -from indexify.function_executor.proto.function_executor_pb2 import ( - InitializeRequest, - InitializeResponse, -) -from indexify.function_executor.proto.function_executor_pb2_grpc import ( - FunctionExecutorStub, -) - -from .function_executor import FunctionExecutor -from .function_executor_factory import FunctionExecutorFactory -from .invocation_state_client import InvocationStateClient - - -class FunctionExecutorMap: - """A map of ID => FunctionExecutor. - - The map is safe to use by multiple couroutines running in event loop on the same thread - but it's not thread safe (can't be used from different threads concurrently).""" - - def __init__( - self, - factory: FunctionExecutorFactory, - base_url: str, - config_path: Optional[str], - ): - self._factory = factory - self._base_url = base_url - self._config_path = config_path - # Map of initialized Function executors ready to run tasks. - # Function ID -> FunctionExecutor. - self._executors: Dict[str, FunctionExecutor] = {} - # We have to do all operations under this lock because we need to ensure - # that we don't create more Function Executors than required. This is important - # e.g. when a Function Executor is using the only available GPU on the machine. 
- # We can get rid of this locking in the future once we assing GPUs explicitly to Function Executors. - # Running the full test suite with all this locking removed doesn't make it run faster, - # so it looks like this full locking doesn't really result in any performance penalty so far. - self._executors_lock = asyncio.Lock() - - async def get_or_create( - self, - id: str, - initialize_request: InitializeRequest, - initial_state: Any, - logger: Any, - ) -> FunctionExecutor: - """Returns a FunctionExecutor for the given ID. - - If the FunctionExecutor for the given ID doesn't exist then it will be created and initialized. - Raises an exception if the FunctionExecutor creation or initialization failed. - """ - async with self._executors_lock: - # Use existing Function Executor if it's already initialized. - if id in self._executors: - return self._executors[id] - - executor: Optional[FunctionExecutor] = None - invocation_state_client: Optional[InvocationStateClient] = None - try: - executor = await self._factory.create(logger, state=initial_state) - channel: grpc.aio.Channel = await executor.channel() - stub: FunctionExecutorStub = FunctionExecutorStub(channel) - initialize_response: InitializeResponse = await stub.initialize( - initialize_request - ) - if not initialize_response.success: - raise Exception("initialize RPC failed at function executor") - invocation_state_client = InvocationStateClient( - stub=stub, - base_url=self._base_url, - http_client=get_httpx_client( - config_path=self._config_path, make_async=True - ), - graph=initialize_request.graph_name, - namespace=initialize_request.namespace, - logger=logger, - ) - await invocation_state_client.start() - # This is dirty but requires refactoring to implement properly. - initial_state.invocation_state_client = invocation_state_client - except Exception: - if invocation_state_client is not None: - await invocation_state_client.destroy() - if executor is not None: - await self._factory.destroy(executor=executor, logger=logger) - # Function Executor creation or initialization failed. - raise - - self._executors[id] = executor - return executor - - async def delete( - self, id: str, function_executor: FunctionExecutor, logger: Any - ) -> None: - """Deletes the FunctionExecutor for the given ID. - - Does nothing if the FunctionExecutor for the given ID doesn't exist or was already deleted. - """ - async with self._executors_lock: - if self._executors[id] != function_executor: - # Function Executor was already deleted or replaced and the caller is not aware of this. 
- return - del self._executors[id] - if function_executor.state().invocation_state_client is not None: - await function_executor.state().invocation_state_client.destroy() - await self._factory.destroy(executor=function_executor, logger=logger) - - async def clear(self, logger): - async with self._executors_lock: - while self._executors: - id, function_executor = self._executors.popitem() - if function_executor.state().invocation_state_client is not None: - await function_executor.state().invocation_state_client.destroy() - await self._factory.destroy(function_executor, logger) diff --git a/python-sdk/indexify/executor/function_executor/function_executor_state.py b/python-sdk/indexify/executor/function_executor/function_executor_state.py new file mode 100644 index 000000000..c2ba29b7c --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/function_executor_state.py @@ -0,0 +1,75 @@ +import asyncio +from typing import Optional + +from .function_executor import FunctionExecutor + + +class FunctionExecutorState: + """State of a Function Executor with a particular ID. + + The Function Executor might not exist, i.e. not yet created or destroyed. + This object represents all such states. Any state modification must be done + under the lock. + """ + + def __init__(self, function_id_with_version: str, function_id_without_version: str): + self.function_id_with_version: str = function_id_with_version + self.function_id_without_version: str = function_id_without_version + self.function_executor: Optional[FunctionExecutor] = None + self.running_tasks: int = 0 + self.lock: asyncio.Lock = asyncio.Lock() + self.running_tasks_change_notifier: asyncio.Condition = asyncio.Condition( + lock=self.lock + ) + + def increment_running_tasks(self) -> None: + """Increments the number of running tasks. + + The caller must hold the lock. + """ + self.check_locked() + self.running_tasks += 1 + self.running_tasks_change_notifier.notify_all() + + def decrement_running_tasks(self) -> None: + """Decrements the number of running tasks. + + The caller must hold the lock. + """ + self.check_locked() + self.running_tasks -= 1 + self.running_tasks_change_notifier.notify_all() + + async def wait_running_tasks_less(self, value: int) -> None: + """Waits until the number of running tasks is less than the supplied value. + + The caller must hold the lock. + """ + self.check_locked() + while self.running_tasks >= value: + await self.running_tasks_change_notifier.wait() + + async def destroy_function_executor(self) -> None: + """Destroys the Function Executor if it exists. + + The caller must hold the lock.""" + self.check_locked() + if self.function_executor is not None: + await self.function_executor.destroy() + self.function_executor = None + + async def destroy_function_executor_not_locked(self) -> None: + """Destroys the Function Executor if it exists. + + The caller doesn't need to hold the lock but this call + might make the state inconsistent.""" + if self.function_executor is not None: + # Atomically hide the destroyed Function Executor from other asyncio tasks. 
+            ref = self.function_executor
+            self.function_executor = None
+            await ref.destroy()
+
+    def check_locked(self) -> None:
+        """Raises an exception if the lock is not held."""
+        if not self.lock.locked():
+            raise RuntimeError("The FunctionExecutorState lock must be held.")
diff --git a/python-sdk/indexify/executor/function_executor/process_function_executor.py b/python-sdk/indexify/executor/function_executor/process_function_executor.py
deleted file mode 100644
index 6a394f1e0..000000000
--- a/python-sdk/indexify/executor/function_executor/process_function_executor.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import asyncio
-from typing import Any, Optional
-
-import grpc
-
-from indexify.function_executor.proto.configuration import GRPC_CHANNEL_OPTIONS
-
-from .function_executor import (
-    FUNCTION_EXECUTOR_READY_TIMEOUT_SEC,
-    FunctionExecutor,
-)
-
-
-class ProcessFunctionExecutor(FunctionExecutor):
-    """A FunctionExecutor that runs in a separate host process."""
-
-    def __init__(
-        self,
-        process: asyncio.subprocess.Process,
-        port: int,
-        address: str,
-        logger: Any,
-        state: Optional[Any] = None,
-    ):
-        self._proc = process
-        self._port = port
-        self._address = address
-        self._logger = logger.bind(module=__name__)
-        self._channel: Optional[grpc.aio.Channel] = None
-        self._state: Optional[Any] = state
-
-    async def channel(self) -> grpc.aio.Channel:
-        # Not thread safe but async safe because we don't await.
-        if self._channel is not None:
-            return self._channel
-
-        channel: Optional[grpc.aio.Channel] = None
-        try:
-            channel = grpc.aio.insecure_channel(
-                self._address, options=GRPC_CHANNEL_OPTIONS
-            )
-            await asyncio.wait_for(
-                channel.channel_ready(),
-                timeout=FUNCTION_EXECUTOR_READY_TIMEOUT_SEC,
-            )
-            # Check if another channel was created by a concurrent coroutine.
-            # Not thread safe but async safe because we never overwrite non-None self._channel.
-            if self._channel is not None:
-                # Don't close and overwrite existing channel because it might be used for RPCs already.
-                await channel.close()
-                return self._channel
-            else:
-                self._channel = channel
-                return channel
-        except Exception:
-            if channel is not None:
-                await channel.close()
-            self._logger.error(
-                f"failed to connect to the gRPC server at {self._address} within {FUNCTION_EXECUTOR_READY_TIMEOUT_SEC} seconds"
-            )
-            raise
-
-    def state(self) -> Optional[Any]:
-        return self._state
diff --git a/python-sdk/indexify/executor/function_executor/server/function_executor_server.py b/python-sdk/indexify/executor/function_executor/server/function_executor_server.py
new file mode 100644
index 000000000..1f1aa8b9a
--- /dev/null
+++ b/python-sdk/indexify/executor/function_executor/server/function_executor_server.py
@@ -0,0 +1,24 @@
+from typing import Any
+
+import grpc
+
+# Timeout for Function Executor Server startup in seconds. The timeout is counted
+# from the moment the server is started.
+FUNCTION_EXECUTOR_SERVER_READY_TIMEOUT_SEC = 5
+
+
+class FunctionExecutorServer:
+    """Abstract interface for a Function Executor Server.
+
+    FunctionExecutorServer is a class that executes tasks for a particular function.
+    The communication with FunctionExecutorServer is typically done via gRPC.
+    """
+
+    async def create_channel(self, logger: Any) -> grpc.aio.Channel:
+        """Creates a new async gRPC channel to the Function Executor Server.
+
+        The channel is in ready state. It can only be used in the same thread where the
+        function was called. Caller should close the channel when it's no longer needed.
+
+        Raises Exception if an error occurred."""
+        raise NotImplementedError
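This interface is the PR's main extension point: anything that can hand back a gRPC channel can serve as a Function Executor Server. A minimal sketch of an alternative implementation for a server already listening at a known address (the class name and scenario are invented; the subprocess variant below is the one this PR ships):

from typing import Any

import grpc

class RemoteFunctionExecutorServer(FunctionExecutorServer):
    """Hypothetical server started out-of-band at a fixed address."""

    def __init__(self, address: str):
        self._address = address

    async def create_channel(self, logger: Any) -> grpc.aio.Channel:
        # FunctionExecutor awaits readiness via _channel_ready() and closes
        # the channel itself; this method only needs to construct it.
        return grpc.aio.insecure_channel(self._address)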
diff --git a/python-sdk/indexify/executor/function_executor/server/function_executor_server_factory.py b/python-sdk/indexify/executor/function_executor/server/function_executor_server_factory.py
new file mode 100644
index 000000000..31c4759ea
--- /dev/null
+++ b/python-sdk/indexify/executor/function_executor/server/function_executor_server_factory.py
@@ -0,0 +1,43 @@
+from typing import Any, Optional
+
+from .function_executor_server import FunctionExecutorServer
+
+
+class FunctionExecutorServerConfiguration:
+    """Configuration for creating a FunctionExecutorServer.
+
+    This configuration only includes data that must be known
+    during creation of the FunctionExecutorServer. If some data
+    is not required during the creation then it shouldn't be here.
+
+    A particular factory implementation might ignore certain
+    configuration parameters or raise an exception if it can't implement
+    them."""
+
+    def __init__(self, image_uri: Optional[str]):
+        # Container image URI of the Function Executor Server.
+        self.image_uri: Optional[str] = image_uri
+
+
+class FunctionExecutorServerFactory:
+    """Abstract class for creating FunctionExecutorServers."""
+
+    async def create(
+        self, config: FunctionExecutorServerConfiguration, logger: Any
+    ) -> FunctionExecutorServer:
+        """Creates a new FunctionExecutorServer.
+
+        Raises an exception if the creation failed or the configuration is not supported.
+        Args:
+            config: configuration of the FunctionExecutorServer.
+            logger: logger to be used during the function call."""
+        raise NotImplementedError()
+
+    async def destroy(self, server: FunctionExecutorServer, logger: Any) -> None:
+        """Destroys the FunctionExecutorServer and releases all its resources.
+
+        Args:
+            logger: logger to be used during the function call.
+        FunctionExecutorServer and customer code that it's running are not notified about the destruction.
+ Never raises any Exceptions.""" + raise NotImplementedError diff --git a/python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server.py b/python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server.py new file mode 100644 index 000000000..91cf8b74d --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server.py @@ -0,0 +1,25 @@ +import asyncio +from typing import Any + +import grpc + +from indexify.function_executor.proto.configuration import GRPC_CHANNEL_OPTIONS + +from .function_executor_server import FunctionExecutorServer + + +class SubprocessFunctionExecutorServer(FunctionExecutorServer): + """A FunctionExecutorServer that runs in a child process.""" + + def __init__( + self, + process: asyncio.subprocess.Process, + port: int, + address: str, + ): + self._proc = process + self._port = port + self._address = address + + async def create_channel(self, logger: Any) -> grpc.aio.Channel: + return grpc.aio.insecure_channel(self._address, options=GRPC_CHANNEL_OPTIONS) diff --git a/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py b/python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server_factory.py similarity index 77% rename from python-sdk/indexify/executor/function_executor/process_function_executor_factory.py rename to python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server_factory.py index c753ea4e1..fc7fdfa34 100644 --- a/python-sdk/indexify/executor/function_executor/process_function_executor_factory.py +++ b/python-sdk/indexify/executor/function_executor/server/subprocess_function_executor_server_factory.py @@ -1,11 +1,16 @@ import asyncio from typing import Any, Optional -from .function_executor_factory import FunctionExecutorFactory -from .process_function_executor import ProcessFunctionExecutor +from .function_executor_server_factory import ( + FunctionExecutorServerConfiguration, + FunctionExecutorServerFactory, +) +from .subprocess_function_executor_server import ( + SubprocessFunctionExecutorServer, +) -class ProcessFunctionExecutorFactory(FunctionExecutorFactory): +class SubprocessFunctionExecutorServerFactory(FunctionExecutorServerFactory): def __init__( self, development_mode: bool, @@ -15,8 +20,13 @@ def __init__( self._free_ports = set(range(50000, 51000)) async def create( - self, logger: Any, state: Optional[Any] = None - ) -> ProcessFunctionExecutor: + self, config: FunctionExecutorServerConfiguration, logger: Any + ) -> SubprocessFunctionExecutorServer: + if config.image_uri is not None: + raise ValueError( + "SubprocessFunctionExecutorServerFactory doesn't support container images" + ) + logger = logger.bind(module=__name__) port: Optional[int] = None @@ -37,12 +47,10 @@ async def create( "indexify-cli", *args, ) - return ProcessFunctionExecutor( + return SubprocessFunctionExecutorServer( process=proc, port=port, address=_server_address(port), - logger=logger, - state=state, ) except Exception as e: if port is not None: @@ -53,9 +61,11 @@ async def create( ) raise - async def destroy(self, executor: ProcessFunctionExecutor, logger: Any) -> None: - proc: asyncio.subprocess.Process = executor._proc - port: int = executor._port + async def destroy( + self, server: SubprocessFunctionExecutorServer, logger: Any + ) -> None: + proc: asyncio.subprocess.Process = server._proc + port: int = server._port logger = logger.bind( module=__name__, pid=proc.pid, @@ -76,8 
+86,6 @@ async def destroy(self, executor: ProcessFunctionExecutor, logger: Any) -> None: ) finally: self._release_port(port) - if executor._channel is not None: - await executor._channel.close() def _allocate_port(self) -> int: # No asyncio.Lock is required here because this operation never awaits diff --git a/python-sdk/indexify/executor/function_executor/single_task_runner.py b/python-sdk/indexify/executor/function_executor/single_task_runner.py new file mode 100644 index 000000000..3ce859aea --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/single_task_runner.py @@ -0,0 +1,160 @@ +from typing import Any, Optional + +import grpc + +from indexify.function_executor.proto.function_executor_pb2 import ( + InitializeRequest, + RunTaskRequest, + RunTaskResponse, +) +from indexify.function_executor.proto.function_executor_pb2_grpc import ( + FunctionExecutorStub, +) + +from ..api_objects import Task +from .function_executor import FunctionExecutor +from .function_executor_state import FunctionExecutorState +from .server.function_executor_server_factory import ( + FunctionExecutorServerConfiguration, + FunctionExecutorServerFactory, +) +from .task_input import TaskInput +from .task_output import TaskOutput + + +class SingleTaskRunner: + def __init__( + self, + function_executor_state: FunctionExecutorState, + task_input: TaskInput, + function_executor_server_factory: FunctionExecutorServerFactory, + base_url: str, + config_path: Optional[str], + logger: Any, + ): + self._state: FunctionExecutorState = function_executor_state + self._task_input: TaskInput = task_input + self._factory: FunctionExecutorServerFactory = function_executor_server_factory + self._base_url: str = base_url + self._config_path: Optional[str] = config_path + self._logger = logger.bind(module=__name__) + + async def run(self) -> TaskOutput: + """Runs the task in the Function Executor. + + The FunctionExecutorState must be locked by the caller. + The lock is released during actual task run in the server. + The lock is relocked on return. 
+
+        Raises an exception if an error occurred."""
+        self._state.check_locked()
+
+        if self._state.function_executor is None:
+            self._state.function_executor = await self._create_function_executor()
+
+        return await self._run()
+
+    async def _create_function_executor(self) -> FunctionExecutor:
+        function_executor: FunctionExecutor = FunctionExecutor(
+            server_factory=self._factory, logger=self._logger
+        )
+        try:
+            config: FunctionExecutorServerConfiguration = (
+                FunctionExecutorServerConfiguration(
+                    image_uri=self._task_input.task.image_uri,
+                )
+            )
+            initialize_request: InitializeRequest = InitializeRequest(
+                namespace=self._task_input.task.namespace,
+                graph_name=self._task_input.task.compute_graph,
+                graph_version=self._task_input.task.graph_version,
+                function_name=self._task_input.task.compute_fn,
+                graph=self._task_input.graph,
+            )
+            await function_executor.initialize(
+                config=config,
+                initialize_request=initialize_request,
+                base_url=self._base_url,
+                config_path=self._config_path,
+            )
+            return function_executor
+        except Exception as e:
+            self._logger.error(
+                "failed to initialize function executor",
+                exc_info=e,
+            )
+            await function_executor.destroy()
+            raise
+
+    async def _run(self) -> TaskOutput:
+        request: RunTaskRequest = RunTaskRequest(
+            graph_invocation_id=self._task_input.task.invocation_id,
+            task_id=self._task_input.task.id,
+            function_input=self._task_input.input,
+        )
+        if self._task_input.init_value is not None:
+            request.function_init_value.CopyFrom(self._task_input.init_value)
+        channel: grpc.aio.Channel = self._state.function_executor.channel()
+
+        async with _RunningTaskContextManager(
+            task_input=self._task_input, function_executor_state=self._state
+        ):
+            response: RunTaskResponse = await FunctionExecutorStub(channel).run_task(
+                request
+            )
+            return _task_output(task=self._task_input.task, response=response)
+
+
+class _RunningTaskContextManager:
+    """Performs all the actions required before and after running a task."""
+
+    def __init__(
+        self, task_input: TaskInput, function_executor_state: FunctionExecutorState
+    ):
+        self._task_input: TaskInput = task_input
+        self._state: FunctionExecutorState = function_executor_state
+
+    async def __aenter__(self):
+        self._state.increment_running_tasks()
+        self._state.function_executor.invocation_state_client().add_task_to_invocation_id_entry(
+            task_id=self._task_input.task.id,
+            invocation_id=self._task_input.task.invocation_id,
+        )
+        # Unlock the state so other tasks can act depending on it.
+ self._state.lock.release() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._state.lock.acquire() + self._state.decrement_running_tasks() + self._state.function_executor.invocation_state_client().remove_task_to_invocation_id_entry( + task_id=self._task_input.task.id + ) + + +def _task_output(task: Task, response: RunTaskResponse) -> TaskOutput: + required_fields = [ + "stdout", + "stderr", + "is_reducer", + "success", + ] + + for field in required_fields: + if not response.HasField(field): + raise ValueError(f"Response is missing required field: {field}") + + output = TaskOutput( + task=task, + stdout=response.stdout, + stderr=response.stderr, + reducer=response.is_reducer, + success=response.success, + ) + + if response.HasField("function_output"): + output.function_output = response.function_output + if response.HasField("router_output"): + output.router_output = response.router_output + + return output diff --git a/python-sdk/indexify/executor/function_executor/task_input.py b/python-sdk/indexify/executor/function_executor/task_input.py new file mode 100644 index 000000000..2980ec282 --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/task_input.py @@ -0,0 +1,23 @@ +from typing import Optional + +from indexify.function_executor.proto.function_executor_pb2 import ( + SerializedObject, +) + +from ..api_objects import Task + + +class TaskInput: + """Task with all the resources required to run it.""" + + def __init__( + self, + task: Task, + graph: SerializedObject, + input: SerializedObject, + init_value: Optional[SerializedObject], + ): + self.task: Task = task + self.graph: SerializedObject = graph + self.input: SerializedObject = input + self.init_value: Optional[SerializedObject] = init_value diff --git a/python-sdk/indexify/executor/function_executor/task_output.py b/python-sdk/indexify/executor/function_executor/task_output.py new file mode 100644 index 000000000..fdcad5c43 --- /dev/null +++ b/python-sdk/indexify/executor/function_executor/task_output.py @@ -0,0 +1,36 @@ +from typing import Optional + +from indexify.function_executor.proto.function_executor_pb2 import ( + FunctionOutput, + RouterOutput, +) + +from ..api_objects import Task + + +class TaskOutput: + """Result of running a task.""" + + def __init__( + self, + task: Task, + function_output: Optional[FunctionOutput] = None, + router_output: Optional[RouterOutput] = None, + stdout: Optional[str] = None, + stderr: Optional[str] = None, + reducer: bool = False, + success: bool = False, + ): + self.task = task + self.function_output = function_output + self.router_output = router_output + self.stdout = stdout + self.stderr = stderr + self.reducer = reducer + self.success = success + + @classmethod + def internal_error(cls, task: Task) -> "TaskOutput": + """Creates a TaskOutput for an internal error.""" + # We are not sharing internal error messages with the customer. 
+ return TaskOutput(task=task, stderr="Platform failed to execute the function.") diff --git a/python-sdk/indexify/executor/function_worker.py b/python-sdk/indexify/executor/function_worker.py deleted file mode 100644 index 8cb79f9ff..000000000 --- a/python-sdk/indexify/executor/function_worker.py +++ /dev/null @@ -1,273 +0,0 @@ -import asyncio -from typing import Any, Dict, Optional - -import grpc -import structlog - -from indexify.function_executor.proto.function_executor_pb2 import ( - FunctionOutput, - InitializeRequest, - RouterOutput, - RunTaskRequest, - RunTaskResponse, - SerializedObject, -) -from indexify.function_executor.proto.function_executor_pb2_grpc import ( - FunctionExecutorStub, -) - -from .api_objects import Task -from .downloader import DownloadedInputs -from .function_executor.function_executor import FunctionExecutor -from .function_executor.function_executor_factory import ( - FunctionExecutorFactory, -) -from .function_executor.function_executor_map import FunctionExecutorMap -from .function_executor.invocation_state_client import InvocationStateClient - - -class FunctionWorkerInput: - """Task with all the resources required to run it.""" - - def __init__( - self, - task: Task, - graph: SerializedObject, - function_input: DownloadedInputs, - ): - self.task = task - self.graph = graph - self.function_input = function_input - - -class FunctionWorkerOutput: - def __init__( - self, - function_output: Optional[FunctionOutput] = None, - router_output: Optional[RouterOutput] = None, - stdout: Optional[str] = None, - stderr: Optional[str] = None, - reducer: bool = False, - success: bool = False, - ): - self.function_output = function_output - self.router_output = router_output - self.stdout = stdout - self.stderr = stderr - self.reducer = reducer - self.success = success - - -class FunctionExecutorState: - def __init__( - self, - function_id_with_version: str, - function_id_without_version: str, - ongoing_tasks_count: int, - invocation_state_client: Optional[InvocationStateClient] = None, - ): - self.function_id_with_version: str = function_id_with_version - self.function_id_without_version: str = function_id_without_version - self.ongoing_tasks_count: int = ongoing_tasks_count - self.invocation_state_client: Optional[InvocationStateClient] = ( - invocation_state_client - ) - - -class FunctionWorker: - def __init__( - self, - function_executor_factory: FunctionExecutorFactory, - base_url: str, - config_path: Optional[str], - ): - self._function_executors = FunctionExecutorMap( - factory=function_executor_factory, - base_url=base_url, - config_path=config_path, - ) - - async def run(self, input: FunctionWorkerInput) -> FunctionWorkerOutput: - logger = _logger(input.task) - function_executor: Optional[FunctionExecutor] = None - try: - function_executor = await self._obtain_function_executor(input, logger) - return await self._run_in_executor( - function_executor=function_executor, input=input - ) - except Exception as e: - logger.error( - "failed running the task", - exc_info=e, - ) - if function_executor is not None: - # This will fail all the tasks concurrently running in this Function Executor. Not great. 
- await self._function_executors.delete( - id=_function_id_without_version(input.task), - function_executor=function_executor, - logger=logger, - ) - return _internal_error_output() - - async def _obtain_function_executor( - self, input: FunctionWorkerInput, logger: Any - ) -> FunctionExecutor: - # Temporary policy for Function Executors lifecycle: - # There can only be a single Function Executor per function. - # If a Function Executor already exists for a different function version then wait until - # all the tasks finish in the existing Function Executor and then destroy it first. - initialize_request: InitializeRequest = InitializeRequest( - namespace=input.task.namespace, - graph_name=input.task.compute_graph, - graph_version=input.task.graph_version, - function_name=input.task.compute_fn, - graph=input.graph, - ) - initial_function_executor_state: FunctionExecutorState = FunctionExecutorState( - function_id_with_version=_function_id_with_version(input.task), - function_id_without_version=_function_id_without_version(input.task), - ongoing_tasks_count=0, - ) - - while True: - function_executor = await self._function_executors.get_or_create( - id=_function_id_without_version(input.task), - initialize_request=initialize_request, - initial_state=initial_function_executor_state, - logger=logger, - ) - - # No need to lock Function Executor state as we are not awaiting. - function_executor_state: FunctionExecutorState = function_executor.state() - if ( - function_executor_state.function_id_with_version - == _function_id_with_version(input.task) - ): - # The existing Function Executor is for the same function version so we can run the task in it. - # Increment the ongoing tasks count before awaiting to prevent the Function Executor from being destroyed - # by another coroutine. - function_executor_state.ongoing_tasks_count += 1 - return function_executor - - # This loop implements the temporary policy so it's implemented using polling instead of a lower - # latency event based mechanism with a higher complexity. - if function_executor_state.ongoing_tasks_count == 0: - logger.info( - "destroying existing Function Executor for different function version", - function_id=_function_id_with_version(input.task), - executor_function_id=function_executor_state.function_id_with_version, - ) - await self._function_executors.delete( - id=_function_id_without_version(input.task), - function_executor=function_executor, - logger=logger, - ) - else: - logger.info( - "waiting for existing Function Executor to finish", - function_id=_function_id_with_version(input.task), - executor_function_id=function_executor_state.function_id_with_version, - ) - await asyncio.sleep( - 5 - ) # Wait for 5 secs before checking if all tasks for the existing Function Executor finished. - - async def _run_in_executor( - self, function_executor: FunctionExecutor, input: FunctionWorkerInput - ) -> FunctionWorkerOutput: - """Runs the task in the Function Executor. - - The Function Executor's ongoing_tasks_count must be incremented before calling this function. 
- """ - try: - run_task_request: RunTaskRequest = RunTaskRequest( - graph_invocation_id=input.task.invocation_id, - task_id=input.task.id, - function_input=input.function_input.input, - ) - if input.function_input.init_value is not None: - run_task_request.function_init_value.CopyFrom( - input.function_input.init_value - ) - channel: grpc.aio.Channel = await function_executor.channel() - function_executor.state().invocation_state_client.add_task_to_invocation_id_entry( - task_id=input.task.id, invocation_id=input.task.invocation_id - ) - run_task_response: RunTaskResponse = await FunctionExecutorStub( - channel - ).run_task(run_task_request) - return _to_output(run_task_response) - finally: - # If this Function Executor was destroyed then it's not - # visible in the map but we still have a reference to it. - function_executor.state().ongoing_tasks_count -= 1 - function_executor.state().invocation_state_client.remove_task_to_invocation_id_entry( - input.task.id - ) - - async def shutdown(self) -> None: - await self._function_executors.clear( - logger=structlog.get_logger(module=__name__, event="shutdown") - ) - - -def _to_output(response: RunTaskResponse) -> FunctionWorkerOutput: - required_fields = [ - "stdout", - "stderr", - "is_reducer", - "success", - ] - - for field in required_fields: - if not response.HasField(field): - raise ValueError(f"Response is missing required field: {field}") - - output = FunctionWorkerOutput( - stdout=response.stdout, - stderr=response.stderr, - reducer=response.is_reducer, - success=response.success, - ) - - if response.HasField("function_output"): - output.function_output = response.function_output - if response.HasField("router_output"): - output.router_output = response.router_output - - return output - - -def _internal_error_output() -> FunctionWorkerOutput: - return FunctionWorkerOutput( - function_output=None, - router_output=None, - stdout=None, - # We are not sharing internal error messages with the customer. - # If customer code failed then we won't get any exceptions here. - # All customer code errors are returned in the gRPC response. - stderr="Platform failed to execute the function.", - reducer=False, - success=False, - ) - - -def _logger(task: Task) -> Any: - return structlog.get_logger( - module=__name__, - namespace=task.namespace, - graph_name=task.compute_graph, - graph_version=task.graph_version, - function_name=task.compute_fn, - graph_invocation_id=task.invocation_id, - task_id=task.id, - function_id=_function_id_with_version(task), - ) - - -def _function_id_with_version(task: Task) -> str: - return f"versioned/{task.namespace}/{task.compute_graph}/{task.graph_version}/{task.compute_fn}" - - -def _function_id_without_version(task: Task) -> str: - return f"not_versioned/{task.namespace}/{task.compute_graph}/{task.compute_fn}" diff --git a/python-sdk/indexify/executor/task_reporter.py b/python-sdk/indexify/executor/task_reporter.py index a851a1c13..b75e2737b 100644 --- a/python-sdk/indexify/executor/task_reporter.py +++ b/python-sdk/indexify/executor/task_reporter.py @@ -10,7 +10,7 @@ FunctionOutput, ) -from .function_worker import FunctionWorkerOutput +from .task_runner import TaskOutput # https://github.com/psf/requests/issues/1081#issuecomment-428504128 @@ -48,17 +48,10 @@ def __init__( # results in not reusing established TCP connections to server. 
         self._client = get_httpx_client(config_path, make_async=False)
 
-    async def report(
-        self, task: Task, output: Optional[FunctionWorkerOutput], logger: Any
-    ):
-        """Reports result of the supplied task.
-
-        If FunctionWorkerOutput is None this means that the task didn't finish and failed with internal error.
-        """
+    async def report(self, output: TaskOutput, logger: Any):
+        """Reports result of the supplied task."""
         logger = logger.bind(module=__name__)
-        task_result, output_files, output_summary = self._process_task_output(
-            task, output
-        )
+        task_result, output_files, output_summary = self._process_task_output(output)
         task_result_data = task_result.model_dump_json(exclude_none=True)
 
         logger.info(
@@ -100,16 +93,16 @@ async def report(
         ) from e
 
     def _process_task_output(
-        self, task: Task, output: Optional[FunctionWorkerOutput]
+        self, output: TaskOutput
     ) -> Tuple[TaskResult, List[Any], TaskOutputSummary]:
         task_result = TaskResult(
             outcome="failure",
-            namespace=task.namespace,
-            compute_graph=task.compute_graph,
-            compute_fn=task.compute_fn,
-            invocation_id=task.invocation_id,
+            namespace=output.task.namespace,
+            compute_graph=output.task.compute_graph,
+            compute_fn=output.task.compute_fn,
+            invocation_id=output.task.invocation_id,
             executor_id=self._executor_id,
-            task_id=task.id,
+            task_id=output.task.id,
         )
         output_files: List[Any] = []
         summary: TaskOutputSummary = TaskOutputSummary()
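Since TaskOutput carries its Task, reporting needs only the output object. A sketch of the new call (reporter construction elided; internal_error is the path Executor takes when execution itself fails):

output = TaskOutput.internal_error(task)  # task: a fetched api_objects.Task
await task_reporter.report(output=output, logger=logger)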
+ # -- If a Function Executor already exists for a different function version then wait until + # all the tasks finish in the existing Function Executor and then destroy it. + # -- This prevents failed tasks for different versions of the same function continiously + # destroying each other's Function Executors. + # - Each Function Executor rans at most 1 task concurrently. + await state.wait_running_tasks_less(1) + + if state.function_id_with_version != _function_id_with_version(task): + await state.destroy_function_executor() + state.function_id_with_version = _function_id_with_version(task) + # At this point the state belongs to the version of the function from the task + # and there are no running tasks in the Function Executor. + + def _get_or_create_state(self, task: Task) -> FunctionExecutorState: + id = _function_id_without_version(task) + if id not in self._function_executor_states: + state = FunctionExecutorState( + function_id_with_version=_function_id_with_version(task), + function_id_without_version=id, + ) + self._function_executor_states[id] = state + return self._function_executor_states[id] + + async def _run_task( + self, state: FunctionExecutorState, task_input: TaskInput, logger: Any + ) -> TaskOutput: + runner: SingleTaskRunner = SingleTaskRunner( + function_executor_state=state, + task_input=task_input, + function_executor_server_factory=self._factory, + base_url=self._base_url, + config_path=self._config_path, + logger=logger, + ) + return await runner.run() + + async def shutdown(self) -> None: + # When shutting down there's no need to wait for completion of the running + # FunctionExecutor tasks. + while self._function_executor_states: + id, state = self._function_executor_states.popitem() + # At this point the state is not visible to new tasks. + # Only ongoing tasks who read it already have a reference to it. + await state.destroy_function_executor_not_locked() + # The task running inside the Function Executor will fail because it's destroyed. + # asyncio tasks waiting to run inside the Function Executor will get cancelled by + # the caller's shutdown code. 
+ + +def _function_id_with_version(task: Task) -> str: + return f"versioned/{task.namespace}/{task.compute_graph}/{task.graph_version}/{task.compute_fn}" + + +def _function_id_without_version(task: Task) -> str: + return f"not_versioned/{task.namespace}/{task.compute_graph}/{task.compute_fn}" diff --git a/python-sdk/tests/test_executor_behaviour.py b/python-sdk/tests/test_executor_behaviour.py index b83ad4104..f35db4aea 100644 --- a/python-sdk/tests/test_executor_behaviour.py +++ b/python-sdk/tests/test_executor_behaviour.py @@ -13,7 +13,7 @@ from indexify.executor.executor import Executor -class TestExtractorAgent(unittest.TestCase): +class TestExecutor(unittest.TestCase): @patch( "builtins.open", new_callable=mock_open, @@ -32,6 +32,7 @@ def test_tls_configuration(self, mock_async_client, mock_sync_client, mock_file) executor = Executor( executor_id="unit-test", code_path=Path("test"), + function_executor_server_factory=None, server_addr=service_url, config_path=config_path, ) @@ -62,6 +63,7 @@ def test_no_tls_configuration(self): executor = Executor( executor_id="unit-test", code_path=Path("test"), + function_executor_server_factory=None, server_addr="localhost:8900", ) diff --git a/python-sdk/tests/test_function_concurrency.py b/python-sdk/tests/test_function_concurrency.py new file mode 100644 index 000000000..ad2ae461a --- /dev/null +++ b/python-sdk/tests/test_function_concurrency.py @@ -0,0 +1,121 @@ +import threading +import time +import unittest +from typing import Optional + +from parameterized import parameterized + +from indexify import Graph, indexify_function +from tests.testing import remote_or_local_graph, test_graph_name + + +@indexify_function() +def sleep_a(secs: int) -> str: + time.sleep(secs) + return "success" + + +@indexify_function() +def sleep_b(secs: int) -> str: + time.sleep(secs) + return "success" + + +class TestRemoteGraphFunctionConcurrency(unittest.TestCase): + def test_two_same_functions_run_with_concurrency_of_one(self): + is_remote = True + graph = Graph( + name=test_graph_name(self), + description="test", + start_node=sleep_a, + ) + graph = remote_or_local_graph(graph, is_remote) + + def invoke_sleep_a(secs: int): + invocation_id = graph.run( + block_until_done=True, + secs=secs, + ) + output = graph.output(invocation_id, "sleep_a") + self.assertEqual(output, ["success"]) + + # Pre-warm Executor so Executor delays in the next invokes are very low. 
+ invoke_sleep_a(0.01) + + threads = [ + threading.Thread(target=invoke_sleep_a, args=(0.51,)), + threading.Thread(target=invoke_sleep_a, args=(0.51,)), + ] + + start_time = time.time() + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + end_time = time.time() + duration = end_time - start_time + self.assertGreaterEqual( + duration, + 1.0, + "The two invocations of the same function should run sequentially", + ) + + def test_two_different_functions_run_with_concurrency_of_two(self): + is_remote = True + graph_a = Graph( + name=test_graph_name(self) + "_a", + description="test", + start_node=sleep_a, + ) + graph_a = remote_or_local_graph(graph_a, is_remote) + + def invoke_sleep_a(secs: int): + invocation_id = graph_a.run( + block_until_done=True, + secs=secs, + ) + output = graph_a.output(invocation_id, "sleep_a") + self.assertEqual(output, ["success"]) + + graph_b = Graph( + name=test_graph_name(self) + "_b", + description="test", + start_node=sleep_b, + ) + graph_b = remote_or_local_graph(graph_b, is_remote) + + def invoke_sleep_b(secs: int): + invocation_id = graph_b.run( + block_until_done=True, + secs=secs, + ) + output = graph_b.output(invocation_id, "sleep_b") + self.assertEqual(output, ["success"]) + + # Pre-warm Executor so Executor delays in the next invokes are very low. + invoke_sleep_a(0.01) + invoke_sleep_b(0.01) + + threads = [ + threading.Thread(target=invoke_sleep_a, args=(0.51,)), + threading.Thread(target=invoke_sleep_b, args=(0.51,)), + ] + + start_time = time.time() + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + end_time = time.time() + duration = end_time - start_time + self.assertLessEqual( + duration, + 1.0, + "The two invocations of different functions should run concurrently", + ) + + +if __name__ == "__main__": + unittest.main()
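The timing thresholds in both tests follow from simple arithmetic (a sketch; overhead is assumed small thanks to the pre-warm invocations):

sequential_floor = 0.51 + 0.51        # 1.02 s: serialized runs must exceed the 1.0 s bound
concurrent_ceiling = max(0.51, 0.51)  # ~0.51 s plus overhead: parallel runs stay under 1.0 s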