From 09748a3a317dec30be8c49c93129067d08265f7c Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Fri, 1 Nov 2024 11:43:55 -0400 Subject: [PATCH 1/4] feat: unify batchmap Signed-off-by: Sidhant Kohli --- pynumaflow/batchmapper/_dtypes.py | 7 +- pynumaflow/batchmapper/async_server.py | 4 +- .../batchmapper/servicer/async_servicer.py | 195 ++++++++++-------- pynumaflow/proto/batchmapper/__init__.py | 0 pynumaflow/proto/batchmapper/batchmap.proto | 50 ----- pynumaflow/proto/batchmapper/batchmap_pb2.py | 43 ---- pynumaflow/proto/batchmapper/batchmap_pb2.pyi | 79 ------- .../proto/batchmapper/batchmap_pb2_grpc.py | 128 ------------ tests/batchmap/test_async_batch_map.py | 43 ++-- tests/batchmap/test_async_batch_map_err.py | 38 ++-- tests/batchmap/utils.py | 38 ++-- 11 files changed, 176 insertions(+), 449 deletions(-) delete mode 100644 pynumaflow/proto/batchmapper/__init__.py delete mode 100644 pynumaflow/proto/batchmapper/batchmap.proto delete mode 100644 pynumaflow/proto/batchmapper/batchmap_pb2.py delete mode 100644 pynumaflow/proto/batchmapper/batchmap_pb2.pyi delete mode 100644 pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py diff --git a/pynumaflow/batchmapper/_dtypes.py b/pynumaflow/batchmapper/_dtypes.py index 91762a48..edeeb08a 100644 --- a/pynumaflow/batchmapper/_dtypes.py +++ b/pynumaflow/batchmapper/_dtypes.py @@ -4,7 +4,6 @@ from datetime import datetime from typing import TypeVar, Callable, Union, Optional from collections.abc import AsyncIterable -from collections.abc import Awaitable from pynumaflow._constants import DROP @@ -222,5 +221,9 @@ async def handler(self, datums: AsyncIterable[Datum]) -> BatchResponses: pass -BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], Awaitable[BatchResponses]] +BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], BatchResponses] BatchMapCallable = Union[BatchMapper, BatchMapAsyncCallable] + + +class BatchMapError(Exception): + """To Raise an error while executing a BatchMap call""" diff --git a/pynumaflow/batchmapper/async_server.py b/pynumaflow/batchmapper/async_server.py index 0db1e73b..73914bdd 100644 --- a/pynumaflow/batchmapper/async_server.py +++ b/pynumaflow/batchmapper/async_server.py @@ -18,7 +18,7 @@ MINIMUM_NUMAFLOW_VERSION, ContainerType, ) -from pynumaflow.proto.batchmapper import batchmap_pb2_grpc +from pynumaflow.proto.mapper import map_pb2_grpc from pynumaflow.shared.server import NumaflowServer, start_async_server @@ -103,7 +103,7 @@ async def aexec(self): # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - batchmap_pb2_grpc.add_BatchMapServicer_to_server( + map_pb2_grpc.add_MapServicer_to_server( self.servicer, server, ) diff --git a/pynumaflow/batchmapper/servicer/async_servicer.py b/pynumaflow/batchmapper/servicer/async_servicer.py index 6daf7b5a..f16361ae 100644 --- a/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/pynumaflow/batchmapper/servicer/async_servicer.py @@ -5,40 +5,19 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.batchmapper import Datum -from pynumaflow.batchmapper._dtypes import BatchMapCallable -from pynumaflow.proto.batchmapper import batchmap_pb2, batchmap_pb2_grpc +from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError +from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.shared.server import exit_on_error from pynumaflow.types import NumaflowServicerContext from pynumaflow._constants import _LOGGER, STREAM_EOF -async def datum_generator( - request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest], -) -> AsyncIterable[Datum]: - """ - This function is used to create an async generator - from the gRPC request iterator. - It yields a Datum instance for each request received which is then - forwarded to the UDF. - """ - async for d in request_iterator: - request = Datum( - keys=d.keys, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - headers=dict(d.headers), - id=d.id, - ) - yield request - - -class AsyncBatchMapServicer(batchmap_pb2_grpc.BatchMapServicer): +class AsyncBatchMapServicer(map_pb2_grpc.MapServicer): """ This class is used to create a new grpc Batch Map Servicer instance. - It implements the BatchMapServicer interface from the proto - batchmap_pb2_grpc.py file. + It implements the MapServicer interface from the proto + map_pb2_grpc.py file. Provides the functionality for the required rpc methods. """ @@ -49,41 +28,74 @@ def __init__( self.background_tasks = set() self.__batch_map_handler: BatchMapCallable = handler - async def BatchMapFn( + async def MapFn( self, - request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest], + request_iterator: AsyncIterable[map_pb2.MapRequest], context: NumaflowServicerContext, - ) -> batchmap_pb2.BatchMapResponse: + ) -> AsyncIterable[map_pb2.MapResponse]: """ - Applies a batch map function to a BatchMapRequest stream in a batching mode. - The pascal case function name comes from the proto batchmap_pb2_grpc.py file. + Applies a batch map function to a MapRequest stream in a batching mode. + The pascal case function name comes from the proto map_pb2_grpc.py file. """ - # Create an async iterator from the request iterator - datum_iterator = datum_generator(request_iterator=request_iterator) - try: - # invoke the UDF call for batch map - responses, request_counter = await self.invoke_batch_map(datum_iterator) - - # If the number of responses received does not align with the request batch size, - # we will not be able to process the data correctly. - # This should be marked as an error and raised to the user. - if len(responses) != request_counter: - err_msg = "batchMapFn: mismatch between length of batch requests and responses" - raise Exception(err_msg) - - # iterate over the responses received and covert to the required proto format - for batch_response in responses: - single_req_resp = [] - for msg in batch_response.messages: - single_req_resp.append( - batchmap_pb2.BatchMapResponse.Result( - keys=msg.keys, value=msg.value, tags=msg.tags - ) + # The first message to be received should be a valid handshake + req = await request_iterator.__anext__() + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): + raise BatchMapError("BatchMapFn: expected handshake as the first message") + yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True)) + + # cur_task is used to track the task (coroutine) processing + # the current batch of messages. + cur_task = None + # iterate of the incoming messages ot the sink + async for d in request_iterator: + # if we do not have any active task currently processing the batch + # we need to create one and call the User function for processing the same. + if cur_task is None: + req_queue = NonBlockingIterator() + cur_task = asyncio.create_task( + self.__batch_map_handler(req_queue.read_iterator()) ) - - # send the response for a given ID back to the stream - yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp) + self.background_tasks.add(cur_task) + cur_task.add_done_callback(self.background_tasks.discard) + # when we have end of transmission message, we need to stop the processing the + # current batch and wait for the next batch of messages. + # We will also wait for the current task to finish processing the current batch. + # We mark the current task as None to indicate that we are + # ready to process the next batch. + if d.status and d.status.eot: + await req_queue.put(STREAM_EOF) + await cur_task + ret = cur_task.result() + + # iterate over the responses received and covert to the required proto format + for batch_response in ret: + single_req_resp = [] + for msg in batch_response.messages: + single_req_resp.append( + map_pb2.MapResponse.Result( + keys=msg.keys, value=msg.value, tags=msg.tags + ) + ) + # send the response for a given ID back to the stream + yield map_pb2.MapResponse(id=batch_response.id, results=single_req_resp) + + # send EOT after each finishing Batch responses + yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True)) + cur_task = None + continue + + # if we have a valid message, we will add it to the request queue for processing. + datum = Datum( + keys=list(d.request.keys), + value=d.request.value, + event_time=d.request.event_time.ToDatetime(), + watermark=d.request.watermark.ToDatetime(), + headers=dict(d.request.headers), + id=d.id, + ) + await req_queue.put(datum) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) @@ -93,42 +105,59 @@ async def BatchMapFn( exit_on_error(context, repr(err)) return - async def invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]): + # # Create an async iterator from the request iterator + # datum_iterator = datum_generator(request_iterator=request_iterator) + # + # try: + # # invoke the UDF call for batch map + # responses, request_counter = await self.invoke_batch_map(datum_iterator) + # + # # If the number of responses received does not align with the request batch size, + # # we will not be able to process the data correctly. + # # This should be marked as an error and raised to the user. + # if len(responses) != request_counter: + # err_msg = "batchMapFn: mismatch between length of batch requests and responses" + # raise Exception(err_msg) + # + # # iterate over the responses received and covert to the required proto format + # for batch_response in responses: + # single_req_resp = [] + # for msg in batch_response.messages: + # single_req_resp.append( + # batchmap_pb2.BatchMapResponse.Result( + # keys=msg.keys, value=msg.value, tags=msg.tags + # ) + # ) + # + # # send the response for a given ID back to the stream + # yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp) + # + # except BaseException as err: + # _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + # await asyncio.gather( + # context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), return_exceptions=True + # ) + # exit_on_error(context, repr(err)) + # return + + async def __invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]): """ # iterate over the incoming requests, and keep sending to the user code # once all messages have been sent, we wait for the responses """ - # create a message queue to send to the user code - niter = NonBlockingIterator() - riter = niter.read_iterator() - # create a task for invoking the UDF handler - task = asyncio.create_task(self.__batch_map_handler(riter)) - # Save a reference to the result of this function, to avoid a - # task disappearing mid-execution. - self.background_tasks.add(task) - task.add_done_callback(lambda t: self.background_tasks.remove(t)) - - req_count = 0 - # start streaming the messages to the UDF code, and increment the request counter - async for datum in datum_iterator: - await niter.put(datum) - req_count += 1 - - # once all messages have been exhausted, send an EOF to indicate end of messages - # to the UDF - await niter.put(STREAM_EOF) - - # wait for all the responses - await task - - # return the result from the UDF, along with the request_counter - return task.result(), req_count + try: + # invoke the user function with the request queue + return await self.__batch_map_handler(datum_iterator) + except BaseException as err: + err_msg = f"UDBatchMapError: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + raise err async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> batchmap_pb2.ReadyResponse: + ) -> map_pb2.ReadyResponse: """ IsReady is the heartbeat endpoint for gRPC. The pascal case function name comes from the proto batchmap_pb2_grpc.py file. """ - return batchmap_pb2.ReadyResponse(ready=True) + return map_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/proto/batchmapper/__init__.py b/pynumaflow/proto/batchmapper/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pynumaflow/proto/batchmapper/batchmap.proto b/pynumaflow/proto/batchmapper/batchmap.proto deleted file mode 100644 index 82d5672d..00000000 --- a/pynumaflow/proto/batchmapper/batchmap.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto3"; - -import "google/protobuf/empty.proto"; -import "google/protobuf/timestamp.proto"; - -package batchmap.v1; - -service BatchMap { - // IsReady is the heartbeat endpoint for gRPC. - rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); - - // BatchMapFn is a bi-directional streaming rpc which applies a - // Map function on each BatchMapRequest element of the stream and then returns streams - // back BatchMapResponse elements. - rpc BatchMapFn(stream BatchMapRequest) returns (stream BatchMapResponse); -} - -/** - * BatchMapRequest represents a request element. - */ -message BatchMapRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; - // This ID is used uniquely identify a map request - string id = 6; -} - -/** - * BatchMapResponse represents a response element. - */ -message BatchMapResponse { - message Result { - repeated string keys = 1; - bytes value = 2; - repeated string tags = 3; - } - repeated Result results = 1; - // This ID is used to refer the responses to the request it corresponds to. - string id = 2; -} - -/** - * ReadyResponse is the health check result. - */ -message ReadyResponse { - bool ready = 1; -} \ No newline at end of file diff --git a/pynumaflow/proto/batchmapper/batchmap_pb2.py b/pynumaflow/proto/batchmapper/batchmap_pb2.py deleted file mode 100644 index b25383d5..00000000 --- a/pynumaflow/proto/batchmapper/batchmap_pb2.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: batchmap.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x62\x61tchmap.proto\x12\x0b\x62\x61tchmap.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\x85\x02\n\x0f\x42\x61tchMapRequest\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12:\n\x07headers\x18\x05 \x03(\x0b\x32).batchmap.v1.BatchMapRequest.HeadersEntry\x12\n\n\x02id\x18\x06 \x01(\t\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x8a\x01\n\x10\x42\x61tchMapResponse\x12\x35\n\x07results\x18\x01 \x03(\x0b\x32$.batchmap.v1.BatchMapResponse.Result\x12\n\n\x02id\x18\x02 \x01(\t\x1a\x33\n\x06Result\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x0c\n\x04tags\x18\x03 \x03(\t"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\x32\x98\x01\n\x08\x42\x61tchMap\x12=\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x1a.batchmap.v1.ReadyResponse\x12M\n\nBatchMapFn\x12\x1c.batchmap.v1.BatchMapRequest\x1a\x1d.batchmap.v1.BatchMapResponse(\x01\x30\x01\x62\x06proto3' -) - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "batchmap_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._options = None - _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_options = b"8\001" - _globals["_BATCHMAPREQUEST"]._serialized_start = 94 - _globals["_BATCHMAPREQUEST"]._serialized_end = 355 - _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_start = 309 - _globals["_BATCHMAPREQUEST_HEADERSENTRY"]._serialized_end = 355 - _globals["_BATCHMAPRESPONSE"]._serialized_start = 358 - _globals["_BATCHMAPRESPONSE"]._serialized_end = 496 - _globals["_BATCHMAPRESPONSE_RESULT"]._serialized_start = 445 - _globals["_BATCHMAPRESPONSE_RESULT"]._serialized_end = 496 - _globals["_READYRESPONSE"]._serialized_start = 498 - _globals["_READYRESPONSE"]._serialized_end = 528 - _globals["_BATCHMAP"]._serialized_start = 531 - _globals["_BATCHMAP"]._serialized_end = 683 -# @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/batchmapper/batchmap_pb2.pyi b/pynumaflow/proto/batchmapper/batchmap_pb2.pyi deleted file mode 100644 index e51ccb85..00000000 --- a/pynumaflow/proto/batchmapper/batchmap_pb2.pyi +++ /dev/null @@ -1,79 +0,0 @@ -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Iterable as _Iterable, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) - -DESCRIPTOR: _descriptor.FileDescriptor - -class BatchMapRequest(_message.Message): - __slots__ = ("keys", "value", "event_time", "watermark", "headers", "id") - - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - EVENT_TIME_FIELD_NUMBER: _ClassVar[int] - WATERMARK_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - ID_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - event_time: _timestamp_pb2.Timestamp - watermark: _timestamp_pb2.Timestamp - headers: _containers.ScalarMap[str, str] - id: str - def __init__( - self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - headers: _Optional[_Mapping[str, str]] = ..., - id: _Optional[str] = ..., - ) -> None: ... - -class BatchMapResponse(_message.Message): - __slots__ = ("results", "id") - - class Result(_message.Message): - __slots__ = ("keys", "value", "tags") - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - TAGS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - tags: _containers.RepeatedScalarFieldContainer[str] - def __init__( - self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - tags: _Optional[_Iterable[str]] = ..., - ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] - ID_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[BatchMapResponse.Result] - id: str - def __init__( - self, - results: _Optional[_Iterable[_Union[BatchMapResponse.Result, _Mapping]]] = ..., - id: _Optional[str] = ..., - ) -> None: ... - -class ReadyResponse(_message.Message): - __slots__ = ("ready",) - READY_FIELD_NUMBER: _ClassVar[int] - ready: bool - def __init__(self, ready: bool = ...) -> None: ... diff --git a/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py b/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py deleted file mode 100644 index d3614d5a..00000000 --- a/pynumaflow/proto/batchmapper/batchmap_pb2_grpc.py +++ /dev/null @@ -1,128 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from . import batchmap_pb2 as batchmap__pb2 -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - - -class BatchMapStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.IsReady = channel.unary_unary( - "/batchmap.v1.BatchMap/IsReady", - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=batchmap__pb2.ReadyResponse.FromString, - ) - self.BatchMapFn = channel.stream_stream( - "/batchmap.v1.BatchMap/BatchMapFn", - request_serializer=batchmap__pb2.BatchMapRequest.SerializeToString, - response_deserializer=batchmap__pb2.BatchMapResponse.FromString, - ) - - -class BatchMapServicer(object): - """Missing associated documentation comment in .proto file.""" - - def IsReady(self, request, context): - """IsReady is the heartbeat endpoint for gRPC.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def BatchMapFn(self, request_iterator, context): - """BatchMapFn is a bi-directional streaming rpc which applies a - Map function on each BatchMapRequest element of the stream and then returns streams - back BatchMapResponse elements. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_BatchMapServicer_to_server(servicer, server): - rpc_method_handlers = { - "IsReady": grpc.unary_unary_rpc_method_handler( - servicer.IsReady, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=batchmap__pb2.ReadyResponse.SerializeToString, - ), - "BatchMapFn": grpc.stream_stream_rpc_method_handler( - servicer.BatchMapFn, - request_deserializer=batchmap__pb2.BatchMapRequest.FromString, - response_serializer=batchmap__pb2.BatchMapResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "batchmap.v1.BatchMap", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class BatchMap(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def IsReady( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/batchmap.v1.BatchMap/IsReady", - google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - batchmap__pb2.ReadyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def BatchMapFn( - request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.stream_stream( - request_iterator, - target, - "/batchmap.v1.BatchMap/BatchMapFn", - batchmap__pb2.BatchMapRequest.SerializeToString, - batchmap__pb2.BatchMapResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/tests/batchmap/test_async_batch_map.py b/tests/batchmap/test_async_batch_map.py index 4922ecb2..7ea32a07 100644 --- a/tests/batchmap/test_async_batch_map.py +++ b/tests/batchmap/test_async_batch_map.py @@ -17,8 +17,8 @@ BatchResponse, BatchMapAsyncServer, ) -from pynumaflow.proto.batchmapper import batchmap_pb2_grpc -from tests.batchmap.utils import start_request, request_generator +from pynumaflow.proto.mapper import map_pb2_grpc +from tests.batchmap.utils import request_generator LOGGER = setup_logging(__name__) @@ -85,7 +85,7 @@ def NewAsyncBatchMapper(): async def start_server(udfs): server = grpc.aio.server() - batchmap_pb2_grpc.add_BatchMapServicer_to_server(udfs, server) + map_pb2_grpc.add_MapServicer_to_server(udfs, server) server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -125,37 +125,42 @@ def tearDownClass(cls) -> None: def test_batch_map(self) -> None: stub = self.__stub() - request = start_request() generator_response = None try: - generator_response = stub.BatchMapFn( - request_iterator=request_generator(count=10, request=request) - ) + generator_response = stub.MapFn(request_iterator=request_generator(count=10, session=1)) except grpc.RpcError as e: logging.error(e) - # capture the output from the BatchMapFn generator and assert. - count = 0 + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + data_resp = [] for r in generator_response: + data_resp.append(r) + + idx = 0 + while idx < len(data_resp) - 1: self.assertEqual( bytes( "test_mock_message", encoding="utf-8", ), - r.results[0].value, + data_resp[idx].results[0].value, ) - _id = r.id - self.assertEqual(_id, str(count)) - count += 1 - - # in our example we should be return 10 messages which is equal to the number - # of requests - self.assertEqual(10, count) + _id = data_resp[idx].id + self.assertEqual(_id, "test-id-" + str(idx)) + # capture the output from the SinkFn generator and assert. + # self.assertEqual(data_resp[idx].result.status, sink_pb2.Status.SUCCESS) + idx += 1 + # EOT Response + self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True) + # 10 sink responses + 1 EOT response + self.assertEqual(11, len(data_resp)) def test_is_ready(self) -> None: with grpc.insecure_channel(listen_addr) as channel: - stub = batchmap_pb2_grpc.BatchMapStub(channel) + stub = map_pb2_grpc.MapStub(channel) request = _empty_pb2.Empty() response = None @@ -180,7 +185,7 @@ def test_max_threads(self): self.assertEqual(server.max_threads, 4) def __stub(self): - return batchmap_pb2_grpc.BatchMapStub(_channel) + return map_pb2_grpc.MapStub(_channel) if __name__ == "__main__": diff --git a/tests/batchmap/test_async_batch_map_err.py b/tests/batchmap/test_async_batch_map_err.py index 6f9c3bd0..5f8162f9 100644 --- a/tests/batchmap/test_async_batch_map_err.py +++ b/tests/batchmap/test_async_batch_map_err.py @@ -10,9 +10,9 @@ from pynumaflow import setup_logging from pynumaflow.batchmapper import BatchResponses -from pynumaflow.batchmapper import Datum, BatchMapAsyncServer -from pynumaflow.proto.batchmapper import batchmap_pb2_grpc -from tests.batchmap.utils import start_request +from pynumaflow.batchmapper import BatchMapAsyncServer +from pynumaflow.proto.mapper import map_pb2_grpc +from tests.batchmap.utils import request_generator from tests.testing_utils import mock_terminate_on_stop LOGGER = setup_logging(__name__) @@ -20,17 +20,8 @@ raise_error = False -def request_generator(count, request, resetkey: bool = False): - for i in range(count): - # add the id to the datum - request.id = str(i) - if resetkey: - request.payload.keys.extend([f"key-{i}"]) - yield request - - # This handler mimics the scenario where batch map UDF throws a runtime error. -async def err_handler(datums: list[Datum]) -> BatchResponses: +async def err_handler(datums) -> BatchResponses: if raise_error: raise RuntimeError("Got a runtime error from batch map handler.") batch_responses = BatchResponses() @@ -55,7 +46,7 @@ async def start_server(): server = grpc.aio.server() server_instance = BatchMapAsyncServer(err_handler) udfs = server_instance.servicer - batchmap_pb2_grpc.add_BatchMapServicer_to_server(udfs, server) + map_pb2_grpc.add_MapServicer_to_server(udfs, server) server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -99,8 +90,8 @@ def test_batch_map_error(self) -> None: raise_error = True stub = self.__stub() try: - generator_response = stub.BatchMapFn( - request_iterator=request_generator(count=10, request=start_request()) + generator_response = stub.MapFn( + request_iterator=request_generator(count=10, handshake=True, session=1) ) counter = 0 for _ in generator_response: @@ -110,27 +101,24 @@ def test_batch_map_error(self) -> None: return self.fail("Expected an exception.") - def test_batch_map_length_error(self) -> None: + def test_batch_map_error_no_handshake(self) -> None: global raise_error - raise_error = False + raise_error = True stub = self.__stub() try: - generator_response = stub.BatchMapFn( - request_iterator=request_generator(count=10, request=start_request()) + generator_response = stub.MapFn( + request_iterator=request_generator(count=10, handshake=False, session=1) ) counter = 0 for _ in generator_response: counter += 1 except Exception as err: - self.assertTrue( - "batchMapFn: mismatch between length of batch requests and responses" - in err.__str__() - ) + self.assertTrue("BatchMapFn: expected handshake as the first message" in err.__str__()) return self.fail("Expected an exception.") def __stub(self): - return batchmap_pb2_grpc.BatchMapStub(_channel) + return map_pb2_grpc.MapStub(_channel) def test_invalid_input(self): with self.assertRaises(TypeError): diff --git a/tests/batchmap/utils.py b/tests/batchmap/utils.py index 31f87feb..cbcb1e59 100644 --- a/tests/batchmap/utils.py +++ b/tests/batchmap/utils.py @@ -1,22 +1,24 @@ -from pynumaflow.batchmapper import Datum -from pynumaflow.proto.batchmapper import batchmap_pb2 -from tests.testing_utils import get_time_args, mock_message +from pynumaflow.proto.mapper import map_pb2 +from tests.testing_utils import get_time_args, mock_message, mock_headers -def request_generator(count, request, resetkey: bool = False): - for i in range(count): - # add the id to the datum - request.id = str(i) - if resetkey: - request.payload.keys.extend([f"key-{i}"]) - yield request +def request_generator(count, session=1, handshake=True): + event_time_timestamp, watermark_timestamp = get_time_args() + if handshake: + yield map_pb2.MapRequest(handshake=map_pb2.Handshake(sot=True)) -def start_request() -> Datum: - event_time_timestamp, watermark_timestamp = get_time_args() - request = batchmap_pb2.BatchMapRequest( - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ) - return request + for j in range(session): + for i in range(count): + req = map_pb2.MapRequest( + request=map_pb2.MapRequest.Request( + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + headers=mock_headers(), + ), + id="test-id-" + str(i), + ) + yield req + + yield map_pb2.MapRequest(status=map_pb2.TransmissionStatus(eot=True)) From 0df077f33112d6c50c894168f5debee1757b0643 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Wed, 6 Nov 2024 00:22:06 -0800 Subject: [PATCH 2/4] chore: mapstream Signed-off-by: Sidhant Kohli --- pynumaflow/mapstreamer/_dtypes.py | 4 + pynumaflow/mapstreamer/async_server.py | 4 +- .../mapstreamer/servicer/async_servicer.py | 68 +++++----- pynumaflow/proto/mapstreamer/__init__.py | 0 pynumaflow/proto/mapstreamer/mapstream.proto | 45 ------- pynumaflow/proto/mapstreamer/mapstream_pb2.py | 43 ------ .../proto/mapstreamer/mapstream_pb2.pyi | 72 ---------- .../proto/mapstreamer/mapstream_pb2_grpc.py | 125 ------------------ tests/mapstream/test_async_map_stream.py | 40 ++++-- tests/mapstream/test_async_map_stream_err.py | 29 ++-- tests/mapstream/utils.py | 28 ++-- 11 files changed, 106 insertions(+), 352 deletions(-) delete mode 100644 pynumaflow/proto/mapstreamer/__init__.py delete mode 100644 pynumaflow/proto/mapstreamer/mapstream.proto delete mode 100644 pynumaflow/proto/mapstreamer/mapstream_pb2.py delete mode 100644 pynumaflow/proto/mapstreamer/mapstream_pb2.pyi delete mode 100644 pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py diff --git a/pynumaflow/mapstreamer/_dtypes.py b/pynumaflow/mapstreamer/_dtypes.py index d8669193..43089415 100644 --- a/pynumaflow/mapstreamer/_dtypes.py +++ b/pynumaflow/mapstreamer/_dtypes.py @@ -201,3 +201,7 @@ async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message] MapStreamAsyncCallable = Callable[[list[str], Datum], AsyncIterable[Message]] MapStreamCallable = Union[MapStreamer, MapStreamAsyncCallable] + + +class MapStreamError(Exception): + """To Raise an error while executing a MapStream call""" diff --git a/pynumaflow/mapstreamer/async_server.py b/pynumaflow/mapstreamer/async_server.py index e10bb323..5d4bb80a 100644 --- a/pynumaflow/mapstreamer/async_server.py +++ b/pynumaflow/mapstreamer/async_server.py @@ -9,7 +9,7 @@ ContainerType, ) from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer -from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc +from pynumaflow.proto.mapper import map_pb2_grpc from pynumaflow._constants import ( MAP_STREAM_SOCK_PATH, @@ -122,7 +122,7 @@ async def aexec(self): # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - mapstream_pb2_grpc.add_MapStreamServicer_to_server( + map_pb2_grpc.add_MapServicer_to_server( self.servicer, server, ) diff --git a/pynumaflow/mapstreamer/servicer/async_servicer.py b/pynumaflow/mapstreamer/servicer/async_servicer.py index 857ddc56..f2e029e3 100644 --- a/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -3,18 +3,18 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.mapstreamer import Datum -from pynumaflow.mapstreamer._dtypes import MapStreamCallable -from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc, mapstream_pb2 +from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError +from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2 from pynumaflow.shared.server import exit_on_error from pynumaflow.types import NumaflowServicerContext from pynumaflow._constants import _LOGGER -class AsyncMapStreamServicer(mapstream_pb2_grpc.MapStreamServicer): +class AsyncMapStreamServicer(map_pb2_grpc.MapServicer): """ This class is used to create a new grpc Map Stream Servicer instance. It implements the SyncMapServicer interface from the proto - mapstream_pb2_grpc.py file. + map_pb2_grpc.py file. Provides the functionality for the required rpc methods. """ @@ -24,52 +24,58 @@ def __init__( ): self.__map_stream_handler: MapStreamCallable = handler - async def MapStreamFn( + async def MapFn( self, - request: mapstream_pb2.MapStreamRequest, + request_iterator: AsyncIterable[map_pb2.MapRequest], context: NumaflowServicerContext, - ) -> AsyncIterable[mapstream_pb2.MapStreamResponse]: + ) -> AsyncIterable[map_pb2.MapResponse]: """ Applies a map function to a datum stream in streaming mode. - The pascal case function name comes from the proto mapstream_pb2_grpc.py file. + The pascal case function name comes from the proto map_pb2_grpc.py file. """ - try: - async for res in self.__invoke_map_stream( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - headers=dict(request.headers), - ), - context, - ): - yield mapstream_pb2.MapStreamResponse(result=res) + # The first message to be received should be a valid handshake + req = await request_iterator.__anext__() + # check if it is a valid handshake req + if not (req.handshake and req.handshake.sot): + raise MapStreamError("MapStreamFn: expected handshake as the first message") + yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True)) + + # read for each input request + async for req in request_iterator: + # yield messages as received from the UDF + async for res in self.__invoke_map_stream( + list(req.request.keys), + Datum( + keys=list(req.request.keys), + value=req.request.value, + event_time=req.request.event_time.ToDatetime(), + watermark=req.request.watermark.ToDatetime(), + headers=dict(req.request.headers), + ), + ): + yield map_pb2.MapResponse(results=[res], id=req.id) + # send EOT to indicate end of transmission for a given message + yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) exit_on_error(context, repr(err)) return - async def __invoke_map_stream( - self, keys: list[str], req: Datum, context: NumaflowServicerContext - ): + async def __invoke_map_stream(self, keys: list[str], req: Datum): try: + # Invoke the user handler for map stream async for msg in self.__map_stream_handler(keys, req): - yield mapstream_pb2.MapStreamResponse.Result( - keys=msg.keys, value=msg.value, tags=msg.tags - ) + yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - exit_on_error(context, repr(err)) raise err async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> mapstream_pb2.ReadyResponse: + ) -> map_pb2.ReadyResponse: """ IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto mapstream_pb2_grpc.py file. + The pascal case function name comes from the proto map_pb2_grpc.py file. """ - return mapstream_pb2.ReadyResponse(ready=True) + return map_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/proto/mapstreamer/__init__.py b/pynumaflow/proto/mapstreamer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pynumaflow/proto/mapstreamer/mapstream.proto b/pynumaflow/proto/mapstreamer/mapstream.proto deleted file mode 100644 index cbcbe809..00000000 --- a/pynumaflow/proto/mapstreamer/mapstream.proto +++ /dev/null @@ -1,45 +0,0 @@ -syntax = "proto3"; - -import "google/protobuf/empty.proto"; -import "google/protobuf/timestamp.proto"; - - -package mapstream.v1; - -service MapStream { - // MapStreamFn applies a function to each request element and returns a stream. - rpc MapStreamFn(MapStreamRequest) returns (stream MapStreamResponse); - - // IsReady is the heartbeat endpoint for gRPC. - rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); -} - -/** - * MapStreamRequest represents a request element. - */ -message MapStreamRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; -} - -/** - * MapStreamResponse represents a response element. - */ -message MapStreamResponse { - message Result { - repeated string keys = 1; - bytes value = 2; - repeated string tags = 3; - } - Result result = 1; -} - -/** - * ReadyResponse is the health check result. - */ -message ReadyResponse { - bool ready = 1; -} \ No newline at end of file diff --git a/pynumaflow/proto/mapstreamer/mapstream_pb2.py b/pynumaflow/proto/mapstreamer/mapstream_pb2.py deleted file mode 100644 index 8d22bbb3..00000000 --- a/pynumaflow/proto/mapstreamer/mapstream_pb2.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: mapstream.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0fmapstream.proto\x12\x0cmapstream.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xfc\x01\n\x10MapStreamRequest\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12<\n\x07headers\x18\x05 \x03(\x0b\x32+.mapstream.v1.MapStreamRequest.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x80\x01\n\x11MapStreamResponse\x12\x36\n\x06result\x18\x01 \x01(\x0b\x32&.mapstream.v1.MapStreamResponse.Result\x1a\x33\n\x06Result\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x0c\n\x04tags\x18\x03 \x03(\t"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\x32\x9d\x01\n\tMapStream\x12P\n\x0bMapStreamFn\x12\x1e.mapstream.v1.MapStreamRequest\x1a\x1f.mapstream.v1.MapStreamResponse0\x01\x12>\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x1b.mapstream.v1.ReadyResponseb\x06proto3' -) - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "mapstream_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals["_MAPSTREAMREQUEST_HEADERSENTRY"]._options = None - _globals["_MAPSTREAMREQUEST_HEADERSENTRY"]._serialized_options = b"8\001" - _globals["_MAPSTREAMREQUEST"]._serialized_start = 96 - _globals["_MAPSTREAMREQUEST"]._serialized_end = 348 - _globals["_MAPSTREAMREQUEST_HEADERSENTRY"]._serialized_start = 302 - _globals["_MAPSTREAMREQUEST_HEADERSENTRY"]._serialized_end = 348 - _globals["_MAPSTREAMRESPONSE"]._serialized_start = 351 - _globals["_MAPSTREAMRESPONSE"]._serialized_end = 479 - _globals["_MAPSTREAMRESPONSE_RESULT"]._serialized_start = 428 - _globals["_MAPSTREAMRESPONSE_RESULT"]._serialized_end = 479 - _globals["_READYRESPONSE"]._serialized_start = 481 - _globals["_READYRESPONSE"]._serialized_end = 511 - _globals["_MAPSTREAM"]._serialized_start = 514 - _globals["_MAPSTREAM"]._serialized_end = 671 -# @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/mapstreamer/mapstream_pb2.pyi b/pynumaflow/proto/mapstreamer/mapstream_pb2.pyi deleted file mode 100644 index 834ae027..00000000 --- a/pynumaflow/proto/mapstreamer/mapstream_pb2.pyi +++ /dev/null @@ -1,72 +0,0 @@ -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Iterable as _Iterable, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) - -DESCRIPTOR: _descriptor.FileDescriptor - -class MapStreamRequest(_message.Message): - __slots__ = ("keys", "value", "event_time", "watermark", "headers") - - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - EVENT_TIME_FIELD_NUMBER: _ClassVar[int] - WATERMARK_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - event_time: _timestamp_pb2.Timestamp - watermark: _timestamp_pb2.Timestamp - headers: _containers.ScalarMap[str, str] - def __init__( - self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - headers: _Optional[_Mapping[str, str]] = ..., - ) -> None: ... - -class MapStreamResponse(_message.Message): - __slots__ = ("result",) - - class Result(_message.Message): - __slots__ = ("keys", "value", "tags") - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - TAGS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - tags: _containers.RepeatedScalarFieldContainer[str] - def __init__( - self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - tags: _Optional[_Iterable[str]] = ..., - ) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] - result: MapStreamResponse.Result - def __init__( - self, result: _Optional[_Union[MapStreamResponse.Result, _Mapping]] = ... - ) -> None: ... - -class ReadyResponse(_message.Message): - __slots__ = ("ready",) - READY_FIELD_NUMBER: _ClassVar[int] - ready: bool - def __init__(self, ready: bool = ...) -> None: ... diff --git a/pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py b/pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py deleted file mode 100644 index 305c8e05..00000000 --- a/pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py +++ /dev/null @@ -1,125 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from . import mapstream_pb2 as mapstream__pb2 - - -class MapStreamStub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.MapStreamFn = channel.unary_stream( - "/mapstream.v1.MapStream/MapStreamFn", - request_serializer=mapstream__pb2.MapStreamRequest.SerializeToString, - response_deserializer=mapstream__pb2.MapStreamResponse.FromString, - ) - self.IsReady = channel.unary_unary( - "/mapstream.v1.MapStream/IsReady", - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=mapstream__pb2.ReadyResponse.FromString, - ) - - -class MapStreamServicer(object): - """Missing associated documentation comment in .proto file.""" - - def MapStreamFn(self, request, context): - """MapStreamFn applies a function to each request element and returns a stream.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def IsReady(self, request, context): - """IsReady is the heartbeat endpoint for gRPC.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_MapStreamServicer_to_server(servicer, server): - rpc_method_handlers = { - "MapStreamFn": grpc.unary_stream_rpc_method_handler( - servicer.MapStreamFn, - request_deserializer=mapstream__pb2.MapStreamRequest.FromString, - response_serializer=mapstream__pb2.MapStreamResponse.SerializeToString, - ), - "IsReady": grpc.unary_unary_rpc_method_handler( - servicer.IsReady, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=mapstream__pb2.ReadyResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "mapstream.v1.MapStream", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class MapStream(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def MapStreamFn( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_stream( - request, - target, - "/mapstream.v1.MapStream/MapStreamFn", - mapstream__pb2.MapStreamRequest.SerializeToString, - mapstream__pb2.MapStreamResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def IsReady( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/mapstream.v1.MapStream/IsReady", - google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - mapstream__pb2.ReadyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/tests/mapstream/test_async_map_stream.py b/tests/mapstream/test_async_map_stream.py index 1558ccf7..a4b36941 100644 --- a/tests/mapstream/test_async_map_stream.py +++ b/tests/mapstream/test_async_map_stream.py @@ -14,8 +14,8 @@ Datum, MapStreamAsyncServer, ) -from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc -from tests.mapstream.utils import start_request_map_stream +from pynumaflow.proto.mapper import map_pb2_grpc +from tests.mapstream.utils import request_generator LOGGER = setup_logging(__name__) @@ -47,14 +47,14 @@ def startup_callable(loop): def NewAsyncMapStreamer( map_stream_handler=async_map_stream_handler, ): - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) + server = MapStreamAsyncServer(map_stream_instance=map_stream_handler) udfs = server.servicer return udfs async def start_server(udfs): server = grpc.aio.server() - mapstream_pb2_grpc.add_MapStreamServicer_to_server(udfs, server) + map_pb2_grpc.add_MapServicer_to_server(udfs, server) listen_addr = "unix:///tmp/async_map_stream.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) @@ -95,31 +95,43 @@ def tearDownClass(cls) -> None: def test_map_stream(self) -> None: stub = self.__stub() - request = start_request_map_stream() generator_response = None try: - generator_response = stub.MapStreamFn(request=request) + generator_response = stub.MapFn(request_iterator=request_generator(count=1, session=1)) except grpc.RpcError as e: logging.error(e) - counter = 0 - # capture the output from the MapStreamFn generator and assert. + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + data_resp = [] for r in generator_response: - counter += 1 + data_resp.append(r) + + self.assertEqual(11, len(data_resp)) + + idx = 0 + while idx < len(data_resp) - 1: self.assertEqual( bytes( "payload:test_mock_message " "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", encoding="utf-8", ), - r.result.value, + data_resp[idx].results[0].value, ) - """Assert that the generator was called 10 times in the stream""" - self.assertEqual(10, counter) + _id = data_resp[idx].id + self.assertEqual(_id, "test-id-0") + # capture the output from the SinkFn generator and assert. + idx += 1 + # EOT Response + self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True) + # 10 sink responses + 1 EOT response + self.assertEqual(11, len(data_resp)) def test_is_ready(self) -> None: with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel: - stub = mapstream_pb2_grpc.MapStreamStub(channel) + stub = map_pb2_grpc.MapStub(channel) request = _empty_pb2.Empty() response = None @@ -131,7 +143,7 @@ def test_is_ready(self) -> None: self.assertTrue(response.ready) def __stub(self): - return mapstream_pb2_grpc.MapStreamStub(_channel) + return map_pb2_grpc.MapStub(_channel) def test_max_threads(self): # max cap at 16 diff --git a/tests/mapstream/test_async_map_stream_err.py b/tests/mapstream/test_async_map_stream_err.py index 27f57273..a93bbee4 100644 --- a/tests/mapstream/test_async_map_stream_err.py +++ b/tests/mapstream/test_async_map_stream_err.py @@ -11,8 +11,8 @@ from pynumaflow import setup_logging from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer -from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc -from tests.mapstream.utils import start_request_map_stream +from pynumaflow.proto.mapper import map_pb2_grpc +from tests.mapstream.utils import request_generator from tests.testing_utils import mock_terminate_on_stop LOGGER = setup_logging(__name__) @@ -47,7 +47,7 @@ async def start_server(): server = grpc.aio.server() server_instance = MapStreamAsyncServer(err_async_map_stream_handler) udfs = server_instance.servicer - mapstream_pb2_grpc.add_MapStreamServicer_to_server(udfs, server) + map_pb2_grpc.add_MapServicer_to_server(udfs, server) listen_addr = "unix:///tmp/async_map_stream_err.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) @@ -88,20 +88,29 @@ def tearDownClass(cls) -> None: LOGGER.error(e) def test_map_stream_error(self) -> None: - stub = self.__stub() - request = start_request_map_stream() try: - generator_response = stub.MapStreamFn(request=request) - counter = 0 - for _ in generator_response: - counter += 1 + stub = self.__stub() + generator_response = None + try: + generator_response = stub.MapFn( + request_iterator=request_generator(count=1, session=1) + ) + except grpc.RpcError as e: + logging.error(e) + + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + data_resp = [] + for r in generator_response: + data_resp.append(r) except Exception as err: self.assertTrue("Got a runtime error from map stream handler." in err.__str__()) return self.fail("Expected an exception.") def __stub(self): - return mapstream_pb2_grpc.MapStreamStub(_channel) + return map_pb2_grpc.MapStub(_channel) def test_invalid_input(self): with self.assertRaises(TypeError): diff --git a/tests/mapstream/utils.py b/tests/mapstream/utils.py index 4e9e4824..40fed81c 100644 --- a/tests/mapstream/utils.py +++ b/tests/mapstream/utils.py @@ -1,14 +1,22 @@ -from pynumaflow.mapstreamer import Datum -from pynumaflow.proto.mapstreamer import mapstream_pb2 -from tests.testing_utils import get_time_args, mock_message +from pynumaflow.proto.mapper import map_pb2 +from tests.testing_utils import get_time_args, mock_message, mock_headers -def start_request_map_stream() -> (Datum, tuple): +def request_generator(count, session=1, handshake=True): event_time_timestamp, watermark_timestamp = get_time_args() - request = mapstream_pb2.MapStreamRequest( - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ) - return request + if handshake: + yield map_pb2.MapRequest(handshake=map_pb2.Handshake(sot=True)) + + for j in range(session): + for i in range(count): + req = map_pb2.MapRequest( + request=map_pb2.MapRequest.Request( + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + headers=mock_headers(), + ), + id="test-id-" + str(i), + ) + yield req From fc88d803e74ad9e5d7f4e40af7b94c626407a2d9 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Wed, 6 Nov 2024 00:25:04 -0800 Subject: [PATCH 3/4] chore: clean Signed-off-by: Sidhant Kohli --- .../batchmapper/servicer/async_servicer.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/pynumaflow/batchmapper/servicer/async_servicer.py b/pynumaflow/batchmapper/servicer/async_servicer.py index f16361ae..1837288c 100644 --- a/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/pynumaflow/batchmapper/servicer/async_servicer.py @@ -105,41 +105,6 @@ async def MapFn( exit_on_error(context, repr(err)) return - # # Create an async iterator from the request iterator - # datum_iterator = datum_generator(request_iterator=request_iterator) - # - # try: - # # invoke the UDF call for batch map - # responses, request_counter = await self.invoke_batch_map(datum_iterator) - # - # # If the number of responses received does not align with the request batch size, - # # we will not be able to process the data correctly. - # # This should be marked as an error and raised to the user. - # if len(responses) != request_counter: - # err_msg = "batchMapFn: mismatch between length of batch requests and responses" - # raise Exception(err_msg) - # - # # iterate over the responses received and covert to the required proto format - # for batch_response in responses: - # single_req_resp = [] - # for msg in batch_response.messages: - # single_req_resp.append( - # batchmap_pb2.BatchMapResponse.Result( - # keys=msg.keys, value=msg.value, tags=msg.tags - # ) - # ) - # - # # send the response for a given ID back to the stream - # yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp) - # - # except BaseException as err: - # _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # await asyncio.gather( - # context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), return_exceptions=True - # ) - # exit_on_error(context, repr(err)) - # return - async def __invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]): """ # iterate over the incoming requests, and keep sending to the user code From 7ba064d24273a24f5d11421625d84347a64fcb75 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Wed, 6 Nov 2024 10:54:46 -0800 Subject: [PATCH 4/4] chore: clean Signed-off-by: Sidhant Kohli --- .../batchmapper/servicer/async_servicer.py | 13 ------------- tests/mapstream/test_async_map_stream_err.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pynumaflow/batchmapper/servicer/async_servicer.py b/pynumaflow/batchmapper/servicer/async_servicer.py index 1837288c..d9220f5b 100644 --- a/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/pynumaflow/batchmapper/servicer/async_servicer.py @@ -105,19 +105,6 @@ async def MapFn( exit_on_error(context, repr(err)) return - async def __invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]): - """ - # iterate over the incoming requests, and keep sending to the user code - # once all messages have been sent, we wait for the responses - """ - try: - # invoke the user function with the request queue - return await self.__batch_map_handler(datum_iterator) - except BaseException as err: - err_msg = f"UDBatchMapError: {repr(err)}" - _LOGGER.critical(err_msg, exc_info=True) - raise err - async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext ) -> map_pb2.ReadyResponse: diff --git a/tests/mapstream/test_async_map_stream_err.py b/tests/mapstream/test_async_map_stream_err.py index a93bbee4..cb7e0ef6 100644 --- a/tests/mapstream/test_async_map_stream_err.py +++ b/tests/mapstream/test_async_map_stream_err.py @@ -109,6 +109,22 @@ def test_map_stream_error(self) -> None: return self.fail("Expected an exception.") + def test_map_stream_error_no_handshake(self) -> None: + global raise_error + raise_error = True + stub = self.__stub() + try: + generator_response = stub.MapFn( + request_iterator=request_generator(count=10, handshake=False, session=1) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + self.assertTrue("MapStreamFn: expected handshake as the first message" in err.__str__()) + return + self.fail("Expected an exception.") + def __stub(self): return map_pb2_grpc.MapStub(_channel)