Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update batchmap and mapstream to use Map proto #200

Merged
merged 5 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pynumaflow/batchmapper/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from datetime import datetime
from typing import TypeVar, Callable, Union, Optional
from collections.abc import AsyncIterable
from collections.abc import Awaitable

from pynumaflow._constants import DROP

Expand Down Expand Up @@ -222,5 +221,9 @@ async def handler(self, datums: AsyncIterable[Datum]) -> BatchResponses:
pass


BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], Awaitable[BatchResponses]]
BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], BatchResponses]
BatchMapCallable = Union[BatchMapper, BatchMapAsyncCallable]


class BatchMapError(Exception):
"""To Raise an error while executing a BatchMap call"""
4 changes: 2 additions & 2 deletions pynumaflow/batchmapper/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MINIMUM_NUMAFLOW_VERSION,
ContainerType,
)
from pynumaflow.proto.batchmapper import batchmap_pb2_grpc
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import NumaflowServer, start_async_server


Expand Down Expand Up @@ -103,7 +103,7 @@ async def aexec(self):
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
batchmap_pb2_grpc.add_BatchMapServicer_to_server(
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
Expand Down
157 changes: 69 additions & 88 deletions pynumaflow/batchmapper/servicer/async_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,19 @@
from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow.batchmapper import Datum
from pynumaflow.batchmapper._dtypes import BatchMapCallable
from pynumaflow.proto.batchmapper import batchmap_pb2, batchmap_pb2_grpc
from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow.shared.server import exit_on_error
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER, STREAM_EOF


async def datum_generator(
request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
) -> AsyncIterable[Datum]:
"""
This function is used to create an async generator
from the gRPC request iterator.
It yields a Datum instance for each request received which is then
forwarded to the UDF.
"""
async for d in request_iterator:
request = Datum(
keys=d.keys,
value=d.value,
event_time=d.event_time.ToDatetime(),
watermark=d.watermark.ToDatetime(),
headers=dict(d.headers),
id=d.id,
)
yield request


class AsyncBatchMapServicer(batchmap_pb2_grpc.BatchMapServicer):
class AsyncBatchMapServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Batch Map Servicer instance.
It implements the BatchMapServicer interface from the proto
batchmap_pb2_grpc.py file.
It implements the MapServicer interface from the proto
map_pb2_grpc.py file.
Provides the functionality for the required rpc methods.
"""

Expand All @@ -49,41 +28,74 @@ def __init__(
self.background_tasks = set()
self.__batch_map_handler: BatchMapCallable = handler

async def BatchMapFn(
async def MapFn(
self,
request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
request_iterator: AsyncIterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> batchmap_pb2.BatchMapResponse:
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a batch map function to a BatchMapRequest stream in a batching mode.
The pascal case function name comes from the proto batchmap_pb2_grpc.py file.
Applies a batch map function to a MapRequest stream in a batching mode.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
# Create an async iterator from the request iterator
datum_iterator = datum_generator(request_iterator=request_iterator)

try:
# invoke the UDF call for batch map
responses, request_counter = await self.invoke_batch_map(datum_iterator)

# If the number of responses received does not align with the request batch size,
# we will not be able to process the data correctly.
# This should be marked as an error and raised to the user.
if len(responses) != request_counter:
err_msg = "batchMapFn: mismatch between length of batch requests and responses"
raise Exception(err_msg)

# iterate over the responses received and convert to the required proto format
for batch_response in responses:
single_req_resp = []
for msg in batch_response.messages:
single_req_resp.append(
batchmap_pb2.BatchMapResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise BatchMapError("BatchMapFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# cur_task is used to track the task (coroutine) processing
# the current batch of messages.
cur_task = None
# iterate over the incoming messages from the request stream
async for d in request_iterator:
# if we do not have any active task currently processing the batch
# we need to create one and call the User function for processing the same.
if cur_task is None:
req_queue = NonBlockingIterator()
cur_task = asyncio.create_task(
self.__batch_map_handler(req_queue.read_iterator())
)

# send the response for a given ID back to the stream
yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp)
self.background_tasks.add(cur_task)
cur_task.add_done_callback(self.background_tasks.discard)
# when we receive an end-of-transmission message, we need to stop processing the
# current batch and wait for the next batch of messages.
# We will also wait for the current task to finish processing the current batch.
# We mark the current task as None to indicate that we are
# ready to process the next batch.
if d.status and d.status.eot:
await req_queue.put(STREAM_EOF)
await cur_task
ret = cur_task.result()

# iterate over the responses received and convert to the required proto format
for batch_response in ret:
single_req_resp = []
for msg in batch_response.messages:
single_req_resp.append(
map_pb2.MapResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
)
# send the response for a given ID back to the stream
yield map_pb2.MapResponse(id=batch_response.id, results=single_req_resp)

# send EOT after finishing each batch of responses
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True))
cur_task = None
continue

# if we have a valid message, we will add it to the request queue for processing.
datum = Datum(
keys=list(d.request.keys),
value=d.request.value,
event_time=d.request.event_time.ToDatetime(),
watermark=d.request.watermark.ToDatetime(),
headers=dict(d.request.headers),
id=d.id,
)
await req_queue.put(datum)

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
Expand All @@ -93,42 +105,11 @@ async def BatchMapFn(
exit_on_error(context, repr(err))
return

async def invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]):
"""
# iterate over the incoming requests, and keep sending to the user code
# once all messages have been sent, we wait for the responses
"""
# create a message queue to send to the user code
niter = NonBlockingIterator()
riter = niter.read_iterator()
# create a task for invoking the UDF handler
task = asyncio.create_task(self.__batch_map_handler(riter))
# Save a reference to the result of this function, to avoid a
# task disappearing mid-execution.
self.background_tasks.add(task)
task.add_done_callback(lambda t: self.background_tasks.remove(t))

req_count = 0
# start streaming the messages to the UDF code, and increment the request counter
async for datum in datum_iterator:
await niter.put(datum)
req_count += 1

# once all messages have been exhausted, send an EOF to indicate end of messages
# to the UDF
await niter.put(STREAM_EOF)

# wait for all the responses
await task

# return the result from the UDF, along with the request_counter
return task.result(), req_count

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> batchmap_pb2.ReadyResponse:
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return batchmap_pb2.ReadyResponse(ready=True)
return map_pb2.ReadyResponse(ready=True)
4 changes: 4 additions & 0 deletions pynumaflow/mapstreamer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,7 @@ async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]

MapStreamAsyncCallable = Callable[[list[str], Datum], AsyncIterable[Message]]
MapStreamCallable = Union[MapStreamer, MapStreamAsyncCallable]


class MapStreamError(Exception):
"""To Raise an error while executing a MapStream call"""
4 changes: 2 additions & 2 deletions pynumaflow/mapstreamer/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ContainerType,
)
from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer
from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc
from pynumaflow.proto.mapper import map_pb2_grpc

from pynumaflow._constants import (
MAP_STREAM_SOCK_PATH,
Expand Down Expand Up @@ -122,7 +122,7 @@ async def aexec(self):
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
mapstream_pb2_grpc.add_MapStreamServicer_to_server(
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
Expand Down
68 changes: 37 additions & 31 deletions pynumaflow/mapstreamer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow.mapstreamer import Datum
from pynumaflow.mapstreamer._dtypes import MapStreamCallable
from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc, mapstream_pb2
from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
from pynumaflow.shared.server import exit_on_error
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER


class AsyncMapStreamServicer(mapstream_pb2_grpc.MapStreamServicer):
class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Map Stream Servicer instance.
It implements the MapServicer interface from the proto
mapstream_pb2_grpc.py file.
map_pb2_grpc.py file.
Provides the functionality for the required rpc methods.
"""

Expand All @@ -24,52 +24,58 @@ def __init__(
):
self.__map_stream_handler: MapStreamCallable = handler

async def MapStreamFn(
async def MapFn(
self,
request: mapstream_pb2.MapStreamRequest,
request_iterator: AsyncIterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> AsyncIterable[mapstream_pb2.MapStreamResponse]:
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a map function to a datum stream in streaming mode.
The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""

try:
async for res in self.__invoke_map_stream(
list(request.keys),
Datum(
keys=list(request.keys),
value=request.value,
event_time=request.event_time.ToDatetime(),
watermark=request.watermark.ToDatetime(),
headers=dict(request.headers),
),
context,
):
yield mapstream_pb2.MapStreamResponse(result=res)
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise MapStreamError("MapStreamFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# read for each input request
async for req in request_iterator:
# yield messages as received from the UDF
async for res in self.__invoke_map_stream(
list(req.request.keys),
Datum(
keys=list(req.request.keys),
value=req.request.value,
event_time=req.request.event_time.ToDatetime(),
watermark=req.request.watermark.ToDatetime(),
headers=dict(req.request.headers),
),
):
yield map_pb2.MapResponse(results=[res], id=req.id)
# send EOT to indicate end of transmission for a given message
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
exit_on_error(context, repr(err))
return

async def __invoke_map_stream(
self, keys: list[str], req: Datum, context: NumaflowServicerContext
):
async def __invoke_map_stream(self, keys: list[str], req: Datum):
try:
# Invoke the user handler for map stream
async for msg in self.__map_stream_handler(keys, req):
yield mapstream_pb2.MapStreamResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
exit_on_error(context, repr(err))
raise err

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> mapstream_pb2.ReadyResponse:
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return mapstream_pb2.ReadyResponse(ready=True)
return map_pb2.ReadyResponse(ready=True)
Empty file.
Loading
Loading