Skip to content

Commit

Permalink
add test for middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
zhqu1148980644 committed May 25, 2021
1 parent 4b7d28d commit 27eec05
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 4 deletions.
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pip install aiosaber

```python
from aiosaber import *

@task
def add(self, num):
for i in range(100000):
Expand All @@ -67,3 +68,47 @@ num_ch = Channel.values(*list(range(100)))
f = my_flow(num_ch)
asyncio.run(f.start())
```

## Middleware example

```python
from aiosaber import *

class NameBuilder(BaseBuilder):
def __call__(self, com, *args, **kwargs):
super().__call__(com, *args, **kwargs)
com.context['name'] = type(com).__name__ + str(id(com))

class ClientProvider(BaseExecutor):
async def __call__(self, com, **kwargs):
if not context.context.get('client'):
context.context['client'] = 'client'
return await super().__call__(com, **kwargs)

class Filter(BaseHandler):
async def __call__(self, com, get, put, **kwargs):
async def filter_put(data):
if data is END or data > 3:
await put(data)

return await super().__call__(com, get, filter_put, **kwargs)

@task
async def add(self, num):
print(self.context['name'])
print(context.context['client'])
return num + 1

@flow
def myflow(num_ch):
return num_ch | add | view

context.context.update({
'builders': [NameBuilder],
'executors': [ClientProvider],
'handlers': [Filter]
})
f = myflow(Channel.values(1, 2, 3, 4, 5))
context.context.clear()
asyncio.run(f.start())
```
4 changes: 3 additions & 1 deletion aiosaber/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.0.1'
__version__ = '0.0.1.1'

import warnings

Expand All @@ -9,6 +9,8 @@
# noinspection PyUnresolvedReferences
from .flow import *
# noinspection PyUnresolvedReferences
from .middleware import BaseHandler, BaseExecutor, BaseBuilder
# noinspection PyUnresolvedReferences
from .operators import *
# noinspection PyUnresolvedReferences
from .task import *
Expand Down
2 changes: 1 addition & 1 deletion aiosaber/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aiosaber import context
from aiosaber.channel import Consumer
from aiosaber.plugins import ContextExecutor
from aiosaber.middleware import ContextExecutor
from aiosaber.utility.typings import ChannelOutput, Builder, Executor

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion aiosaber/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aiosaber import context
from aiosaber.channel import Channel
from aiosaber.component import Component
from aiosaber.plugins import DaskExecutorProvider
from aiosaber.middleware import DaskExecutorProvider
from aiosaber.utility.typings import ChannelOutput
from aiosaber.utility.utils import class_deco

Expand Down
File renamed without changes.
6 changes: 5 additions & 1 deletion aiosaber/utility/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ def first_exception(fut: asyncio.Future):
return_when = return_when or first_exception

def _done_callback(fut: asyncio.Future):
q.task_done()
sem.release()
futures.remove(fut)
# TODO may over called
try:
q.task_done()
except Exception:
pass
# set waiter done once
if not waiter.done() and return_when(fut):
# propagate all state to current coro
Expand Down
45 changes: 45 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from aiosaber import *


def test_middleware():
class NameBuilder(BaseBuilder):
def __call__(self, com, *args, **kwargs):
super().__call__(com, *args, **kwargs)
com.context['name'] = type(com).__name__ + str(id(com))

class ClientProvider(BaseExecutor):
async def __call__(self, com, **kwargs):
if not context.context.get('client'):
context.context['client'] = 'client'
return await super().__call__(com, **kwargs)

class Filter(BaseHandler):
async def __call__(self, com, get, put, **kwargs):
async def filter_put(data):
if data is END or data > 3:
await put(data)

return await super().__call__(com, get, filter_put, **kwargs)

@task
async def add(self, num):
print(self.context['name'])
print(context.context['client'])
return num + 1

@flow
def myflow(num_ch):
return num_ch | add | view

context.context.update({
'builders': [NameBuilder],
'executors': [ClientProvider],
'handlers': [Filter]
})
f = myflow(Channel.values(1, 2, 3, 4, 5))
context.context.clear()
asyncio.run(f.start())


if __name__ == "__main__":
test_middleware()

0 comments on commit 27eec05

Please sign in to comment.