Skip to content

Commit a805ff1

Browse files
authored
Merge pull request #1049 from xtexChooser/matrix/issue1046
#1046 for matrix
2 parents d8e6733 + 5842f2a commit a805ff1

File tree

15 files changed

+101
-51
lines changed

15 files changed

+101
-51
lines changed

bots/matrix/bot.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from bots.matrix import client
1010
from bots.matrix.client import bot
1111
from bots.matrix.info import client_name
12-
from bots.matrix.message import MessageSession, FetchTarget
12+
from bots.matrix.message import MessageSession, FetchTarget, ReactionMessageSession
1313
from core.builtins import PrivateAssets, Url
1414
from core.logger import Logger
1515
from core.parser.message import parser
@@ -50,36 +50,45 @@ async def on_room_member(room: nio.MatrixRoom, event: nio.RoomMemberEvent):
5050
Logger.info(f"Left empty room {room.room_id}")
5151

5252

53-
async def on_message(room: nio.MatrixRoom, event: nio.RoomMessageFormatted):
53+
async def on_message(room: nio.MatrixRoom, event: nio.Event):
5454
if event.sender != bot.user_id and bot.olm:
5555
for device_id, olm_device in bot.device_store[event.sender].items():
5656
if bot.olm.is_device_verified(olm_device):
5757
continue
5858
bot.verify_device(olm_device)
5959
Logger.info(f"trust olm device for device id {event.sender} -> {device_id}")
60-
if event.source['content']['msgtype'] == 'm.notice':
60+
if isinstance(event, nio.RoomMessageFormatted) and event.source['content']['msgtype'] == 'm.notice':
6161
# https://spec.matrix.org/v1.7/client-server-api/#mnotice
6262
return
6363
is_room = room.member_count != 2 or room.join_rule != 'invite'
6464
target_id = room.room_id if is_room else event.sender
6565
reply_id = None
6666
if 'm.relates_to' in event.source['content'] and 'm.in_reply_to' in event.source['content']['m.relates_to']:
6767
reply_id = event.source['content']['m.relates_to']['m.in_reply_to']['event_id']
68+
6869
resp = await bot.get_displayname(event.sender)
6970
if isinstance(resp, nio.ErrorResponse):
7071
Logger.error(f"Failed to get display name for {event.sender}")
7172
return
7273
sender_name = resp.displayname
7374

74-
msg = MessageSession(MsgInfo(target_id=f'Matrix|{target_id}',
75-
sender_id=f'Matrix|{event.sender}',
76-
target_from=f'Matrix',
77-
sender_from='Matrix',
78-
sender_name=sender_name,
79-
client_name=client_name,
80-
message_id=event.event_id,
81-
reply_id=reply_id),
82-
Session(message=event.source, target=room.room_id, sender=event.sender))
75+
target = MsgInfo(target_id=f'Matrix|{target_id}',
76+
sender_id=f'Matrix|{event.sender}',
77+
target_from=f'Matrix',
78+
sender_from='Matrix',
79+
sender_name=sender_name,
80+
client_name=client_name,
81+
message_id=event.event_id,
82+
reply_id=reply_id)
83+
session = Session(message=event.source, target=room.room_id, sender=event.sender)
84+
85+
msg = None
86+
if isinstance(event, nio.RoomMessageFormatted):
87+
msg = MessageSession(target, session)
88+
elif isinstance(event, nio.ReactionEvent):
89+
msg = ReactionMessageSession(target, session)
90+
else:
91+
raise NotImplemented
8392
asyncio.create_task(parser(msg))
8493

8594

@@ -141,6 +150,7 @@ async def start():
141150
bot.add_event_callback(on_invite, nio.InviteEvent)
142151
bot.add_event_callback(on_room_member, nio.RoomMemberEvent)
143152
bot.add_event_callback(on_message, nio.RoomMessageFormatted)
153+
bot.add_event_callback(on_message, nio.ReactionEvent)
144154
bot.add_to_device_callback(on_verify, nio.KeyVerificationEvent)
145155
bot.add_event_callback(on_in_room_verify, nio.RoomMessageUnknown)
146156

bots/matrix/message.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ async def to_message_chain(self):
201201

202202
async def delete(self):
203203
try:
204-
await bot.room_redact(self.session.target, self.session.message['event_id'])
204+
await bot.room_redact(self.session.target, self.target.message_id)
205205
except Exception:
206206
Logger.error(traceback.format_exc())
207207

@@ -224,6 +224,37 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
224224
pass
225225

226226

227+
class ReactionMessageSession(MessageSession):
228+
class Feature(MessageSession.Feature):
229+
pass
230+
231+
class Typing(MessageSession.Typing):
232+
pass
233+
234+
def as_display(self, text_only=False):
235+
if text_only:
236+
return ''
237+
return self.session.message['content']['m.relates_to']['key']
238+
239+
async def to_message_chain(self):
240+
return MessageChain([])
241+
242+
def is_quick_confirm(self, target: Union[MessageSession, FinishedSession]) -> bool:
243+
content = self.session.message['content']['m.relates_to']
244+
if content['rel_type'] == 'm.annotation':
245+
if content['key'] in ['👍️', '✔️', '🎉']: # todo: move to config
246+
if target is None:
247+
return True
248+
else:
249+
msg = [target.target.message_id] if isinstance(target, MessageSession) else target.message_id
250+
if content['event_id'] in msg:
251+
return True
252+
return False
253+
254+
asDisplay = as_display
255+
toMessageChain = to_message_chain
256+
257+
227258
class FetchedSession(Bot.FetchedSession):
228259

229260
async def _resolve_matrix_room_(self):

config/config.toml.example

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ debug = false
3131
cache_path = "./cache/"
3232
command_prefix = ["~", "~",]
3333
confirm_command = ["是", "对", "對", "yes", "Yes", "YES", "y", "Y",]
34+
quick_confirm = true
3435
disabled_bots =
3536
locale = "zh_cn"
3637
timezone_offset = "+8"

core/builtins/message/__init__.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from core.builtins.message.internal import *
88
from core.builtins.tasks import MessageTaskManager
99
from core.builtins.temp import ExecutionLockList
10-
from core.builtins.utils import confirm_command
10+
from core.builtins.utils import confirm_command, quick_confirm
1111
from core.exceptions import WaitCancelException
12-
from core.types.message import MessageSession as MessageSessionT, MsgInfo, Session
12+
from core.types.message import MessageSession as MessageSessionT, FinishedSession, MsgInfo, Session
1313
from core.utils.i18n import Locale
1414
from core.utils.text import parse_time_string
1515
from database import BotDBUtil
@@ -56,12 +56,14 @@ async def wait_confirm(self, message_chain=None, quote=True, delete=True, timeou
5656
await send.delete()
5757
if result.as_display(text_only=True) in confirm_command:
5858
return True
59+
if quick_confirm and result.is_quick_confirm(send):
60+
return True
5961
return False
6062
else:
6163
raise WaitCancelException
6264

6365
async def wait_next_message(self, message_chain=None, quote=True, delete=False, timeout=120,
64-
append_instruction=True) -> MessageSessionT:
66+
append_instruction=True) -> (MessageSessionT, FinishedSession):
6567
sent = None
6668
ExecutionLockList.remove(self)
6769
if message_chain:
@@ -79,7 +81,7 @@ async def wait_next_message(self, message_chain=None, quote=True, delete=False,
7981
if delete and sent:
8082
await sent.delete()
8183
if result:
82-
return result
84+
return (result, sent)
8385
else:
8486
raise WaitCancelException
8587

core/builtins/tasks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def add_callback(cls, message_id, callback):
2929
cls._callback_list[message_id] = {'callback': callback, 'ts': datetime.now().timestamp()}
3030

3131
@classmethod
32-
def get_result(cls, session: MessageSession):
32+
def get_result(cls, session: MessageSession) -> MessageSession:
3333
if 'result' in cls._list[session.target.target_id][session.target.sender_id][session]:
3434
return cls._list[session.target.target_id][session.target.sender_id][session]['result']
3535
else:

core/builtins/utils/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44

55
confirm_command = Config('confirm_command', default=["是", "对", "對", "yes", "Yes", "YES", "y", "Y"])
6+
quick_confirm = Config('quick_confirm', default=True)
67
command_prefix = Config('command_prefix', default=['~', '~']) # 消息前缀
78

89

910
class EnableDirtyWordCheck:
1011
status = False
1112

1213

13-
__all__ = ["confirm_command", "command_prefix", "EnableDirtyWordCheck", "PrivateAssets", "Secret"]
14+
__all__ = ["confirm_command", "quick_confirm", "command_prefix", "EnableDirtyWordCheck", "PrivateAssets", "Secret"]

core/types/message/__init__.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import List, Union, Dict, Coroutine
2+
from typing import List, Union, Dict, Coroutine, Self
33

44
from core.exceptions import FinishedException
55
from .chain import MessageChain
@@ -149,14 +149,14 @@ async def wait_confirm(self, message_chain=None, quote=True, delete=True, timeou
149149
raise NotImplementedError
150150

151151
async def wait_next_message(self, message_chain=None, quote=True, delete=False, timeout=120,
152-
append_instruction=True):
152+
append_instruction=True) -> (Self, FinishedSession):
153153
"""
154154
一次性模板,用于等待对象的下一条消息。
155155
:param message_chain: 需要发送的确认消息,可不填
156156
:param quote: 是否引用传入dict中的消息(默认为True)
157157
:param delete: 是否在触发后删除消息(默认为False)
158158
:param timeout: 超时时间
159-
:return: 下一条消息的MessageChain对象
159+
:return: 下一条消息的MessageChain对象和发出的提示消息
160160
"""
161161
raise NotImplementedError
162162

@@ -215,6 +215,13 @@ async def check_native_permission(self):
215215
"""
216216
raise NotImplementedError
217217

218+
def is_quick_confirm(self, target: Union[Self, FinishedSession] = None) -> bool:
219+
"""
220+
用于检查消息是否可用作快速确认事件。
221+
:param target: 确认的目标消息
222+
"""
223+
return False
224+
218225
async def fake_forward_msg(self, nodelist):
219226
"""
220227
用于发送假转发消息(QQ)。

modules/chemical_code/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ async def timer(start):
243243

244244
await asyncio.gather(ans(msg, csr['name'], random_mode), timer(time_start))
245245
else:
246-
result = await msg.wait_next_message([Plain(msg.locale.t('chemical_code.message.showid', id=csr["id"])),
246+
result, _ = await msg.wait_next_message([Plain(msg.locale.t('chemical_code.message.showid', id=csr["id"])),
247247
Image(newpath), Plain(msg.locale.t('chemical_code.message.captcha',
248248
times=set_timeout))], timeout=3600, append_instruction=False)
249249
if play_state[msg.target.target_id]['active']:

modules/ncmusic/__init__.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ async def search(msg: Bot.MessageSession, keyword: str):
3232
f"{' / '.join(artist['name'] for artist in song['artists'])}",
3333
f"{song['album']['name']}" + (f" ({' / '.join(song['album']['transNames'])})" if 'transNames' in song['album'] else ''),
3434
f"{song['id']}"
35-
] for i, song in enumerate(songs, start=1)
36-
]
35+
] for i, song in enumerate(songs, start=1)
36+
]
3737

3838
tables = ImageTable(data, [
3939
msg.locale.t('ncmusic.message.search.table.header.id'),
4040
msg.locale.t('ncmusic.message.search.table.header.name'),
4141
msg.locale.t('ncmusic.message.search.table.header.artists'),
4242
msg.locale.t('ncmusic.message.search.table.header.album'),
4343
'ID'
44-
])
44+
])
4545

4646
img = await image_table_render(tables)
4747
if img:
@@ -62,7 +62,7 @@ async def search(msg: Bot.MessageSession, keyword: str):
6262

6363
else:
6464
send_msg.append(Plain(msg.locale.t('ncmusic.message.search.prompt')))
65-
query = await msg.wait_reply(send_msg)
65+
query, _ = await msg.wait_next_message(send_msg)
6666
query = query.as_display(text_only=True)
6767

6868
if query.isdigit():
@@ -89,21 +89,20 @@ async def search(msg: Bot.MessageSession, keyword: str):
8989
if 'transNames' in song['album']:
9090
send_msg += f"({' / '.join(song['album']['transNames'])})"
9191
send_msg += f"({song['id']}\n"
92-
9392
if song_count > 10:
9493
song_count = 10
9594
send_msg += msg.locale.t("message.collapse", amount="10")
9695

9796
if song_count == 1:
9897
send_msg += '\n' + msg.locale.t('ncmusic.message.search.confirm')
99-
query = await msg.wait_confirm(send_msg, delete=False)
98+
query, _ = await msg.wait_next_message(send_msg)
10099
if query:
101100
sid = result['result']['songs'][0]['id']
102101
else:
103102
return
104103
else:
105104
send_msg += '\n' + msg.locale.t('ncmusic.message.search.prompt')
106-
query = await msg.wait_reply(send_msg)
105+
query, _ = await msg.wait_next_message(send_msg)
107106
query = query.as_display(text_only=True)
108107

109108
if query.isdigit():

modules/summary/__init__.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
openai.api_key = Config('openai_api_key')
1313

14-
s = module('summary',
15-
developers=['Dianliang233', 'OasisAkari'],
16-
desc='{summary.help.desc}',
17-
available_for=['QQ', 'QQ|Group'])
14+
s = module('summary',
15+
developers=['Dianliang233', 'OasisAkari'],
16+
desc='{summary.help.desc}',
17+
available_for=['QQ', 'QQ|Group'])
1818

1919

2020
@s.handle('{{summary.help}}')
@@ -28,7 +28,7 @@ async def _(msg: Bot.MessageSession):
2828
qc = CoolDown('call_openai', msg)
2929
c = qc.check(60)
3030
if c == 0 or msg.target.target_from == 'TEST|Console' or is_superuser:
31-
f_msg = await msg.wait_next_message(msg.locale.t('summary.message'), append_instruction=False)
31+
f_msg, _ = await msg.wait_next_message(msg.locale.t('summary.message'), append_instruction=False)
3232
try:
3333
f = re.search(r'\[Ke:forward,id=(.*?)\]', f_msg.as_display()).group(1)
3434
except AttributeError:
@@ -86,5 +86,3 @@ async def _(msg: Bot.MessageSession):
8686
await msg.finish(output, disable_secret_check=True)
8787
else:
8888
await msg.finish(msg.locale.t('message.cooldown', time=int(c), cd_time='60'))
89-
90-

modules/twenty_four/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ async def _(msg: Bot.MessageSession):
107107

108108
numbers = [random.randint(1, 13) for _ in range(4)]
109109
has_solution_flag = await has_solution(numbers)
110-
111-
answer = await msg.wait_next_message(msg.locale.t('twenty_four.message', numbers=numbers), timeout=3600, append_instruction=False)
110+
111+
answer, _ = await msg.wait_next_message(msg.locale.t('twenty_four.message', numbers=numbers), timeout=3600, append_instruction=False)
112112
expression = answer.as_display(text_only=True)
113113
if play_state[msg.target.target_id]['active']:
114114
if expression.lower() in no_solution:

modules/wiki/wiki.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import filetype
66

7-
from core.builtins import Bot, Plain, Image, Voice, Url, confirm_command
7+
from core.builtins import Bot, Plain, Image, Voice, Url, confirm_command, quick_confirm
88
from core.utils.image_table import image_table_render, ImageTable
99
from core.component import module
1010
from core.exceptions import AbuseWarning
@@ -377,11 +377,13 @@ async def image_and_voice():
377377

378378
async def wait_confirm():
379379
if wait_msg_list and session.Feature.wait:
380-
confirm = await session.wait_next_message(wait_msg_list, delete=True, append_instruction=False)
380+
confirm, sent = await session.wait_next_message(wait_msg_list, delete=True, append_instruction=False)
381381
auto_index = False
382382
index = 0
383383
if confirm.as_display(text_only=True) in confirm_command:
384384
auto_index = True
385+
elif quick_confirm and confirm.is_quick_confirm(sent):
386+
auto_index = True
385387
elif confirm.as_display(text_only=True).isdigit():
386388
index = int(confirm.as_display()) - 1
387389
else:

poetry.lock

+5-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pycryptodome = "^3.18.0"
5454
langconv = "^0.2.0"
5555
toml = "^0.10.2"
5656
khl-py = "^0.3.16"
57-
matrix-nio = "^0.21.2"
57+
matrix-nio = "^0.22.0"
5858
attrs = "^23.1.0"
5959
uvicorn = {extras = ["standard"], version = "^0.23.2"}
6060
pyjwt = {extras = ["crypto"], version = "^2.8.0"}

0 commit comments

Comments
 (0)