Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: bi-directional streaming map #197

Merged
merged 8 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pynumaflow/mapper/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages:
# MapAsyncCallable is a callable which can be used as a handler for the Asynchronous Map UDF
MapAsyncHandlerCallable = Callable[[list[str], Datum], Awaitable[Messages]]
MapAsyncCallable = Union[Mapper, MapAsyncHandlerCallable]


class MapError(Exception):
"""To Raise an error while executing a Map call"""
129 changes: 129 additions & 0 deletions pynumaflow/mapper/_servicer/_async_servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
from collections.abc import AsyncIterable

from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.asynciter import NonBlockingIterator

from pynumaflow._constants import _LOGGER, STREAM_EOF
from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.server import exit_on_error, handle_async_error
from pynumaflow.types import NumaflowServicerContext


class AsyncMapServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Async Map Servicer instance.
It implements the SyncMapServicer interface from the proto map.proto file.
Provides the functionality for the required rpc methods.
"""

def __init__(
self,
handler: MapAsyncCallable,
):
self.background_tasks = set()
self.__map_handler: MapAsyncCallable = handler

async def MapFn(
self,
request_iterator: AsyncIterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a function to each datum element.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
# proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer
# we need to explicitly convert it to list
try:
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise MapError("MapFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

global_result_queue = NonBlockingIterator()

# reader task to process the input task and invoke the required tasks
producer = asyncio.create_task(
self._process_inputs(request_iterator, global_result_queue)
)

# keep reading on result queue and send messages back
consumer = global_result_queue.read_iterator()
async for msg in consumer:
# If the message is an exception, we raise the exception
if isinstance(msg, BaseException):
await handle_async_error(context, msg)
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not really returning anything right? But it should ideally be AsyncIterable[map_pb2.MapResponse] correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, but the error handler would have sent a context error back, hence empty return for exit here.

# Send window response back to the client
else:
yield msg
# wait for the producer task to complete
await producer
except BaseException as e:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
exit_on_error(context, repr(e))
return

async def _process_inputs(
self,
request_iterator: AsyncIterable[map_pb2.MapRequest],
result_queue: NonBlockingIterator,
):
"""
Utility function for processing incoming MapRequests
"""
try:
# for each incoming request, create a background task to execute the
# UDF code
async for req in request_iterator:
msg_task = asyncio.create_task(self._invoke_map(req, result_queue))
# save a reference to a set to store active tasks
self.background_tasks.add(msg_task)
msg_task.add_done_callback(self.background_tasks.discard)

# wait for all tasks to complete
for task in self.background_tasks:
await task

# send an EOF to result queue to indicate that all tasks have completed
await result_queue.put(STREAM_EOF)

except BaseException as e:
await result_queue.put(e)
return

async def _invoke_map(self, req: map_pb2.MapRequest, result_queue: NonBlockingIterator):
"""
Invokes the user defined function.
"""
try:
datum = Datum(
keys=list(req.request.keys),
value=req.request.value,
event_time=req.request.event_time.ToDatetime(),
watermark=req.request.watermark.ToDatetime(),
headers=dict(req.request.headers),
)
msgs = await self.__map_handler(list(req.request.keys), datum)
datums = []
for msg in msgs:
datums.append(
map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
)
await result_queue.put(map_pb2.MapResponse(results=datums, id=req.id))
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await result_queue.put(err)

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return map_pb2.ReadyResponse(ready=True)
133 changes: 133 additions & 0 deletions pynumaflow/mapper/_servicer/_sync_servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from collections.abc import Iterable

from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.server import exit_on_error

from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER
from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.synciter import SyncIterator
from pynumaflow.types import NumaflowServicerContext


class SyncMapServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Map Servicer instance.
It implements the SyncMapServicer interface from the proto map.proto file.
Provides the functionality for the required rpc methods.
"""

def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
self.__map_handler: MapSyncCallable = handler
# This indicates whether the grpc server attached is multiproc or not
self.multiproc = multiproc
# create a thread pool for executing UDF code
self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)

def MapFn(
self,
request_iterator: Iterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> Iterable[map_pb2.MapResponse]:
"""
Applies a function to each datum element.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
try:
# The first message to be received should be a valid handshake
req = next(request_iterator)
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise MapError("MapFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# result queue to stream messages from the user code back to the client
result_queue = SyncIterator()

# Reader thread to keep reading from the request iterator and schedule
# execution for each of them
reader_thread = threading.Thread(
target=self._process_requests, args=(context, request_iterator, result_queue)
)
reader_thread.start()
# Read the result queue and keep forwarding them upstream
for res in result_queue.read_iterator():
# if error handler accordingly
if isinstance(res, BaseException):
# Terminate the current server process due to exception
exit_on_error(context, repr(res), parent=self.multiproc)
return
# return the result
yield res

# wait for the threads to clean-up
reader_thread.join()
self.executor.shutdown(cancel_futures=True)

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
# Terminate the current server process due to exception
exit_on_error(context, repr(err), parent=self.multiproc)
return

def _process_requests(
self,
context: NumaflowServicerContext,
request_iterator: Iterable[map_pb2.MapRequest],
result_queue: SyncIterator,
):
try:
# read through all incoming requests and submit to the
# threadpool for invocation
for request in request_iterator:
_ = self.executor.submit(self._invoke_map, context, request, result_queue)
# wait for all tasks to finish after all requests exhausted
self.executor.shutdown(wait=True)
# Indicate to the result queue that no more messages left to process
result_queue.put(STREAM_EOF)
except BaseException as e:
_LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
result_queue.put(e)

Check warning on line 92 in pynumaflow/mapper/_servicer/_sync_servicer.py

View check run for this annotation

Codecov / codecov/patch

pynumaflow/mapper/_servicer/_sync_servicer.py#L90-L92

Added lines #L90 - L92 were not covered by tests

def _invoke_map(
self,
context: NumaflowServicerContext,
request: map_pb2.MapRequest,
result_queue: SyncIterator,
):
try:
d = Datum(
keys=list(request.request.keys),
value=request.request.value,
event_time=request.request.event_time.ToDatetime(),
watermark=request.request.watermark.ToDatetime(),
headers=dict(request.request.headers),
)
Comment on lines +101 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to have this part inside the try except?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a type issue, good to catch that mostly


responses = self.__map_handler(list(request.request.keys), d)
results = []
for resp in responses:
results.append(
map_pb2.MapResponse.Result(
keys=list(resp.keys),
value=resp.value,
tags=resp.tags,
)
)
result_queue.put(map_pb2.MapResponse(results=results, id=request.id))

except BaseException as e:
_LOGGER.critical("MapFn handler error", exc_info=True)
result_queue.put(e)
return

def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return map_pb2.ReadyResponse(ready=True)
2 changes: 1 addition & 1 deletion pynumaflow/mapper/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ContainerType,
)
from pynumaflow.mapper._dtypes import MapAsyncCallable
from pynumaflow.mapper.servicer.async_servicer import AsyncMapServicer
from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import (
NumaflowServer,
Expand Down
2 changes: 1 addition & 1 deletion pynumaflow/mapper/multiproc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ContainerType,
)
from pynumaflow.mapper._dtypes import MapSyncCallable
from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer
from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer
from pynumaflow.shared.server import (
NumaflowServer,
start_multiproc_server,
Expand Down
75 changes: 0 additions & 75 deletions pynumaflow/mapper/servicer/async_servicer.py

This file was deleted.

Loading
Loading