Skip to content

Commit

Permalink
fix: improve DI performance
Browse files Browse the repository at this point in the history
avoids allocating a dict each time the middleware is called
  • Loading branch information
woile committed Jan 6, 2025
1 parent 206d434 commit ef5a833
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions kstreams/middleware/udf_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ async def anext(async_gen: typing.AsyncGenerator):


class UdfHandler(BaseMiddleware):
"""User Defined Function Handler Middleware
Manages dependency injection for user defined functions (UDFs) that are
defined as coroutines. The UDFs are defined by the user and are passed
to the stream engine to be executed when a consumer record is received.
The UDFs can have different signatures and the middleware is responsible
for managing the dependency injection for the UDFs.
UdfHandler tries to stay small and performant.
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
signature = inspect.signature(self.next_call)
Expand All @@ -23,20 +35,17 @@ def __init__(self, *args, **kwargs) -> None:
for param in signature.parameters.values()
]
self.type: UDFType = setup_type(self.params)

def get_type(self) -> UDFType:
return self.type

def bind_udf_params(self, cr: types.ConsumerRecord) -> typing.List:
# NOTE: Once `no typing` support is deprecated, this can
# be more efficient, as the CR will always be there.
ANNOTATIONS_TO_PARAMS = {
types.ConsumerRecord: cr,
self.annotations_to_params: dict[type, typing.Any] = {
types.ConsumerRecord: None,
Stream: self.stream,
types.Send: self.send,
}

return [ANNOTATIONS_TO_PARAMS[param_type] for param_type in self.params]
def get_type(self) -> UDFType:
    """Return the UDF type stored in ``self.type``.

    The value is computed once in ``__init__`` by calling
    ``setup_type(self.params)``; this method only reads it back.
    """
    return self.type

def bind_cr(self, cr: types.ConsumerRecord) -> None:
    """Register ``cr`` as the value injected for ``types.ConsumerRecord``.

    Overwrites the ``types.ConsumerRecord`` entry of
    ``self.annotations_to_params`` in place, so the mapping is reused
    rather than rebuilt. The previously bound record is replaced.
    """
    self.annotations_to_params[types.ConsumerRecord] = cr

async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
"""
Expand All @@ -58,7 +67,8 @@ async def consume(cr: ConsumerRecord):
async def consume(cr: ConsumerRecord, stream: Stream):
...
"""
params = self.bind_udf_params(cr)
self.bind_cr(cr)
params = [self.annotations_to_params[param_type] for param_type in self.params]

if inspect.isasyncgenfunction(self.next_call):
return await anext(self.next_call(*params))
Expand Down

0 comments on commit ef5a833

Please sign in to comment.