Skip to content

Commit

Permalink
实现current_rate_limit_token (#23)
Browse files Browse the repository at this point in the history
* 实现current_rate_limit_token

* 升级版本号
  • Loading branch information
ssttkkl authored Nov 26, 2023
1 parent 9043c5b commit c407ce7
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 50 deletions.
57 changes: 12 additions & 45 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "nonebot-plugin-access-control"
version = "1.1.3"
version = "1.1.4"
description = ""
authors = ["ssttkkl <[email protected]>"]
license = "MIT"
Expand All @@ -13,7 +13,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.9"
nonebot2 = "^2.1.0"
nonebot-plugin-access-control-api = "^1.1.1"
nonebot-plugin-access-control-api = "^1.1.2"
nonebot-plugin-apscheduler = ">=0.3.0"
nonebot-plugin-session = "^0.2.0"
nonebot-plugin-orm = ">=0.5.0, <1.0.0"
Expand Down
21 changes: 21 additions & 0 deletions src/nonebot_plugin_ac_demo/matcher_demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from nonebot import on_command
from nonebot.internal.matcher import Matcher
from nonebot_plugin_access_control_api.service.contextvars import (
current_rate_limit_token,
)

from .plugin_service import plugin_service

Expand Down Expand Up @@ -33,3 +36,21 @@ async def _(matcher: Matcher):
@c_service.patch_handler()
async def _(matcher: Matcher):
await matcher.send("c")


d_counter = 0

d_matcher = on_command("d", priority=99)
d_service = plugin_service.create_subservice("d")


@d_matcher.handle()
@d_service.patch_handler()
async def _(matcher: Matcher):
global d_counter
d_counter += 1
if d_counter % 2 == 0:
await current_rate_limit_token.get().retire()
await matcher.send("retired")
else:
await matcher.send("d")
10 changes: 7 additions & 3 deletions src/nonebot_plugin_access_control/service/_impl/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from nonebot_plugin_access_control_api.service import get_nonebot_service
from nonebot_plugin_access_control_api.models.rate_limit import AcquireTokenResult
from nonebot_plugin_access_control_api.service.interface.patcher import IServicePatcher
from nonebot_plugin_access_control_api.service.contextvars import (
current_rate_limit_token,
)
from nonebot_plugin_access_control_api.errors import (
RateLimitedError,
PermissionDeniedError,
Expand Down Expand Up @@ -76,22 +79,23 @@ async def wrapped_func(*args, **kwargs):
await self.handle_rate_limited(matcher, result)
return

matcher.state["ac_token"] = result.token

t = current_rate_limit_token.set(result.token)
try:
return await func(*args, **kwargs)
except BaseException as e:
if retire_on_throw:
await result.token.retire()
raise e
finally:
current_rate_limit_token.reset(t)

return wrapped_func

return decorator


@run_preprocessor
async def check(matcher: Matcher, bot: Bot, event: Event):
async def check(bot: Bot, event: Event, matcher: Matcher):
service = ServicePatcherImpl._matcher_service_mapping.get(type(matcher), None)
if service is None:
return
Expand Down
32 changes: 32 additions & 0 deletions src/tests/test_rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,35 @@ async def test_rate_limit_overwrite(app: App):
event = fake_ob11_group_message_event("/c")
ctx.receive_event(bot, event)
ctx.should_call_send(event, "c")


@pytest.mark.asyncio
async def test_rate_limit_retire(app: App):
from nonebot.adapters.onebot.v11 import Bot
from nonebot_plugin_access_control_api.service import get_nonebot_service

from nonebot_plugin_ac_demo.matcher_demo import d_matcher

# service: nonebot
# subject: all
# span: 1s
# limit: 2
await get_nonebot_service().add_rate_limit_rule("all", timedelta(seconds=1), 2)

async with app.test_matcher(d_matcher) as ctx:
bot = ctx.create_bot(base=Bot, self_id=str(SELF_ID))
event = fake_ob11_group_message_event("/d")
ctx.receive_event(bot, event)
ctx.should_call_send(event, "d")

async with app.test_matcher(d_matcher) as ctx:
bot = ctx.create_bot(base=Bot, self_id=str(SELF_ID))
event = fake_ob11_group_message_event("/d")
ctx.receive_event(bot, event)
ctx.should_call_send(event, "retired")

async with app.test_matcher(d_matcher) as ctx:
bot = ctx.create_bot(base=Bot, self_id=str(SELF_ID))
event = fake_ob11_group_message_event("/d")
ctx.receive_event(bot, event)
ctx.should_call_send(event, "d")
1 change: 1 addition & 0 deletions src/tests/test_service_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ async def test_service_ls_handler(app: App):
"│ ├── a\n"
"│ └── b\n"
"├── c\n"
"├── d\n"
"└── tick"
)

Expand Down

0 comments on commit c407ce7

Please sign in to comment.