diff --git a/README.md b/README.md index 89da01d..64e7690 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ You'll need at least these: - `DISCORD_GUILD`: your server ID - `DATABASE_PATH`: the path to the Q&A database (normally in `./database/stampy.db`). - `STAMPY_MODULES`: list of your desired modules, or leave unset to load all modules in the `./modules/` directory. You probably don't want all, as some of them aren't applicable to servers other than Rob's. +- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates and info are sent Not required: @@ -68,7 +69,6 @@ Not required: - `BOT_DEV_ROLES`: list of roles representing bot devs. - `BOT_DEV_IDS`: list of user ids of bot devs. You may want to include `BOT_VIP_IDS` here. - `BOT_CONTROL_CHANNEL_IDS`: list of channels where control commands are accepted. -- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates are sent - `BOT_ERROR_CHANNEL_ID`: (defaults to private channel) low level error tracebacks from Python. with this variable they can be shunted to a seperate channel. - `CODA_API_TOKEN`: token to access Coda. Without it, modules `Questions` and `QuestionSetter` will not be available and `StampyControls` will have limited functionality. - `BOT_REBOOT`: how Stampy reboots himself. Unset, he only quits, expecting an external `while true` loop (like in `runstampy`/Dockerfile). Set to `exec` he will try to relaunch himself from his own CLI arguments. diff --git a/config.py b/config.py index 48d379e..c3c8620 100644 --- a/config.py +++ b/config.py @@ -1,5 +1,6 @@ import os from typing import Literal, TypeVar, Optional, Union, cast, get_args, overload, Any, Tuple +from pathlib import Path import dotenv from structlog import get_logger @@ -10,16 +11,15 @@ dotenv.load_dotenv() NOT_PROVIDED = "__NOT_PROVIDED__" -module_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modules") +module_dir = Path(__file__).parent / 'modules' def get_all_modules() -> frozenset[str]: - modules = set() - for file_name in os.listdir(module_dir): - if file_name.endswith(".py") and file_name not in ("__init__.py", "module.py"): - modules.add(file_name[:-3]) - - return frozenset(modules) + return frozenset({ + filename.stem + for filename in module_dir.glob('*.py') + if filename.suffix == '.py' and filename.name not in ('__init__.py', 'module.py') + }) ALL_STAMPY_MODULES = get_all_modules() @@ -47,8 +47,7 @@ def getenv(env_var: str, default = NOT_PROVIDED) -> str: def getenv_bool(env_var: str) -> bool: - e = getenv(env_var, default="UNDEFINED") - return e != "UNDEFINED" + return getenv(env_var, default="UNDEFINED") != "UNDEFINED" # fmt:off @@ -64,12 +63,12 @@ def getenv_unique_set(var_name: str, default: T) -> Union[frozenset[str], T]:... def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozenset, T]: - l = getenv(var_name, default="EMPTY_SET").split(" ") - if l == ["EMPTY_SET"]: + var = getenv(var_name, default='') + if not var.strip(): return default - s = frozenset(l) - assert len(l) == len(s), f"{var_name} has duplicate members! {l}" - return s + items = var.split() + assert len(items) == len(set(items)), f"{var_name} has duplicate members! {sorted(items)}" + return frozenset(items) maximum_recursion_depth = 30 @@ -150,6 +149,11 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense channel_whitelist: Optional[frozenset[str]] disable_prompt_moderation: bool +## Flask settings +if flask_port := getenv('FLASK_PORT', '2300'): + flask_port = int(flask_port) +flask_address = getenv('FLASK_ADDRESS', "0.0.0.0") + is_rob_server = getenv_bool("IS_ROB_SERVER") if is_rob_server: # use robmiles server defaults @@ -233,7 +237,7 @@ def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozense bot_dev_roles = getenv_unique_set("BOT_DEV_ROLES", frozenset()) bot_dev_ids = getenv_unique_set("BOT_DEV_IDS", frozenset()) bot_control_channel_ids = getenv_unique_set("BOT_CONTROL_CHANNEL_IDS", frozenset()) - bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID", '') + bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID") bot_error_channel_id = getenv("BOT_ERROR_CHANNEL_ID", bot_private_channel_id) # NOTE: Rob's invite/member management functions, not ported yet member_role_id = getenv("MEMBER_ROLE_ID", default=None) diff --git a/modules/module.py b/modules/module.py index aa33747..e5c3602 100644 --- a/modules/module.py +++ b/modules/module.py @@ -243,6 +243,8 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]: r",? @?[sS](tampy)?(?P[.!?]*)$", r"\g", text ) at_me = True + elif re.search(r'^[sS]tamp[ys]?\?', text): + at_me = True if message.is_dm: # DMs are always at you @@ -255,10 +257,7 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]: ) at_me = True - if at_me: - return text - else: - return False + return at_me and text def get_guild_and_invite_role(self): return get_guild_and_invite_role() diff --git a/modules/stampy_chat.py b/modules/stampy_chat.py new file mode 100644 index 0000000..e472ee6 --- /dev/null +++ b/modules/stampy_chat.py @@ -0,0 +1,172 @@ +""" +Queries chat.stampy.ai with the user's question. + +""" + +import json +import re +from collections import deque, defaultdict +from typing import Iterable, List, Dict, Any +from uuid import uuid4 + +import requests +from structlog import get_logger + +from modules.module import Module, Response +from servicemodules.serviceConstants import italicise +from utilities.serviceutils import ServiceChannel, ServiceMessage +from utilities.utilities import Utilities + +log = get_logger() +utils = Utilities.get_instance() + + +LOG_MAX_MESSAGES = 15 # don't store more than X messages back +DATA_HEADER = 'data: ' + +STAMPY_CHAT_ENDPOINT = "https://chat.stampy.ai:8443/chat" +NLP_SEARCH_ENDPOINT = "https://nlp.stampy.ai" + +STAMPY_ANSWER_MIN_SCORE = 0.75 +STAMPY_CHAT_MIN_SCORE = 0.4 + + +def stream_lines(stream: Iterable): + line = '' + for item in stream: + item = item.decode('utf8') + line += item + if '\n' in line: + lines = line.split('\n') + line = lines[-1] + for l in lines[:-1]: + yield l + yield line + + +def parse_data_items(stream: Iterable): + for item in stream: + if item.strip().startswith(DATA_HEADER): + yield json.loads(item.split(DATA_HEADER)[1]) + + +def top_nlp_search(query: str) -> Dict[str, Any]: + resp = requests.get(NLP_SEARCH_ENDPOINT + '/api/search', params={'query': query, 'status': 'all'}) + if not resp: + return {} + + items = resp.json() + if not items: + return {} + return items[0] + + +def chunk_text(text: str, chunk_limit=2000, delimiter='.'): + chunk = '' + for sentence in text.split(delimiter): + if len(chunk + sentence) + 1 >= chunk_limit and chunk and sentence: + yield chunk + chunk = sentence + delimiter + elif sentence: + chunk += sentence + delimiter + yield chunk + + +def filter_citations(text, citations): + used_citations = re.findall(r'\[([a-z],? ?)*?\]', text) + return [c for c in citations if c.get('reference') in used_citations] + + +class StampyChat(Module): + + def __init__(self): + self.utils = Utilities.get_instance() + self._messages: dict[ServiceChannel, deque[ServiceMessage]] = defaultdict(lambda: deque(maxlen=LOG_MAX_MESSAGES)) + self.session_id = str(uuid4()) + super().__init__() + + @property + def class_name(self): + return 'stampy_chat' + + def format_message(self, message: ServiceMessage): + return { + 'content': message.content, + 'role': 'assistant' if self.utils.stampy_is_author(message) else 'user', + } + + def stream_chat_response(self, query: str, history: List[ServiceMessage]): + return parse_data_items(stream_lines(requests.post(STAMPY_CHAT_ENDPOINT, stream=True, json={ + 'query': query, + 'history': [self.format_message(m) for m in history], + 'sessionId': self.session_id, + 'settings': {'mode': 'discord'}, + }))) + + def get_chat_response(self, query: str, history: List[ServiceMessage]): + response = {'citations': [], 'content': '', 'followups': []} + for item in self.stream_chat_response(query, history): + if item.get('state') == 'citations': + response['citations'] += item.get('citations', []) + elif item.get('state') == 'streaming': + response['content'] += item.get('content', '') + elif item.get('state') == 'followups': + response['followups'] += item.get('followups', []) + response['citations'] = filter_citations(response['content'], response['citations']) + return response + + async def query(self, query: str, history: List[ServiceMessage], message: ServiceMessage): + log.info('calling %s', query) + chat_response = self.get_chat_response(query, history) + content_chunks = list(chunk_text(chat_response['content'])) + citations = [f'[{c["reference"]}] - {c["title"]} ({c["url"]})' for c in chat_response['citations'] if c.get('reference')] + if citations: + citations = ['Citations: \n' + '\n'.join(citations)] + followups = [] + if follows := chat_response['followups']: + followups = [ + 'Checkout these articles for more info: \n' + '\n'.join( + f'{f["text"]} - https://aisafety.info?state={f["pageid"]}' for f in follows + ) + ] + + log.info('response: %s', content_chunks + citations + followups) + return Response( + confidence=10, + text=[italicise(text, message) for text in content_chunks + citations + followups], + why='This is what the chat bot returned' + ) + + def _add_message(self, message: ServiceMessage) -> deque[ServiceMessage]: + self._messages[message.channel].append(message) + return self._messages[message.channel] + + def make_query(self, messages): + if not messages: + return '', messages + + current = messages[-1] + query, history = '', list(messages) + while message := (history and history.pop()): + if message.author != current.author: + break + query = message.content + ' ' + query + current = message + return query, history + + def process_message(self, message: ServiceMessage) -> Response: + history = self._add_message(message) + + if not self.is_at_me(message): + return Response() + + query, history = self.make_query(history) + nlp = top_nlp_search(query) + if nlp.get('score', 0) > STAMPY_ANSWER_MIN_SCORE and nlp.get('status') == 'Live on site': + return Response(confidence=5, text=f'Check out {nlp.get("url")} ({nlp.get("title")})') + if nlp.get('score', 0) > STAMPY_CHAT_MIN_SCORE: + return Response(confidence=6, callback=self.query, args=[query, history, message]) + return Response() + + def process_message_from_stampy(self, message: ServiceMessage): + self._add_message(message) diff --git a/servicemodules/discord.py b/servicemodules/discord.py index 337410b..7afb706 100644 --- a/servicemodules/discord.py +++ b/servicemodules/discord.py @@ -47,13 +47,15 @@ # TODO: store long responses temporarily for viewing outside of discord -def limit_text_and_notify(response: Response, why_traceback: list[str]) -> str: +def limit_text_and_notify(response: Response, why_traceback: list[str]) -> Union[str, Iterable]: if isinstance(response.text, str): wastrimmed = False wastrimmed, text_to_return = limit_text(response.text, discordLimit) if wastrimmed: why_traceback.append(f"I had to trim the output from {response.module}") return text_to_return + elif isinstance(response.text, (list, tuple)): + return response.text return "" diff --git a/servicemodules/flask.py b/servicemodules/flask.py index df2561f..fff658c 100644 --- a/servicemodules/flask.py +++ b/servicemodules/flask.py @@ -1,6 +1,6 @@ from flask import Response as FlaskResponse from collections.abc import Iterable -from config import TEST_RESPONSE_PREFIX, maximum_recursion_depth +from config import TEST_RESPONSE_PREFIX, maximum_recursion_depth, flask_port, flask_address from flask import Flask, request from modules.module import Response from structlog import get_logger @@ -46,29 +46,51 @@ def process_event(self) -> FlaskResponse: Keys are currently defined in utilities.flaskutils """ if request.is_json: - message = request.get_json() - message[ - "content" - ] += " s" # This plus s should make it always trigger the is_at_me functions. + message = FlaskMessage.from_dict(request.get_json()) + elif request.form: + message = FlaskMessage.from_dict(request.form) else: - content = ( - request.form.get("content") + " s" - ) # This plus s should make it always trigger the is_at_me functions. - key = request.form.get("key") - modules = json.loads( - request.form.get("modules", json.dumps(list(self.modules.keys()))) - ) - message = {"content": content, "key": key, "modules": modules} - response = self.on_message(FlaskMessage(message)) + return FlaskResponse("No data provided - aborting", 400) + + try: + response = self.on_message(message) + except Exception as e: + response = FlaskResponse(str(e), 400) + log.debug(class_name, response=response, type=type(response)) return response def process_list_modules(self) -> FlaskResponse: return FlaskResponse(json.dumps(list(self.modules.keys()))) + def _module_responses(self, message): + if message.modules is None: + message.modules = list(self.modules.keys()) + elif not message.modules: + raise LookupError('No modules specified') + + responses = [Response()] + for key, module in self.modules.items(): + if key not in message.modules: + log.info(class_name, msg=f"# Skipping module: {key}") + continue # Skip this module if it's not requested. + + log.info(class_name, msg=f"# Asking module: {module}") + response = module.process_message(message) + if response: + response.module = module + if response.callback: + response.confidence -= 0.001 + responses.append(response) + return responses + def on_message(self, message: FlaskMessage) -> FlaskResponse: if is_test_message(message.content) and self.utils.test_mode: log.info(class_name, type="TEST MESSAGE", message_content=message.content) + elif self.utils.stampy_is_author(message): + for module in self.modules.values(): + module.process_message_from_stampy(message) + return FlaskResponse("ok - if that's what I said", 200) log.info( class_name, @@ -80,19 +102,7 @@ def on_message(self, message: FlaskMessage) -> FlaskResponse: message_content=message.content, ) - responses = [Response()] - for key in self.modules: - if message.modules and key not in message.modules: - log.info(class_name, msg=f"# Skipping module: {key}") - continue # Skip this module if it's not requested. - module = self.modules[key] - log.info(class_name, msg=f"# Asking module: {module}") - response = module.process_message(message) - if response: - response.module = module - if response.callback: - response.confidence -= 0.001 - responses.append(response) + responses = self._module_responses(message) for i in range(maximum_recursion_depth): responses = sorted(responses, key=(lambda x: x.confidence), reverse=True) @@ -100,10 +110,10 @@ def on_message(self, message: FlaskMessage) -> FlaskResponse: for response in responses: args_string = "" if response.callback: - args_string = ", ".join([a.__repr__() for a in response.args]) + args_string = ", ".join([repr(a) for a in response.args]) if response.kwargs: args_string += ", " + ", ".join( - [f"{k}={v.__repr__()}" for k, v in response.kwargs.items()] + [f"{k}={repr(v)}" for k, v in response.kwargs.items()] ) log.info( class_name, @@ -159,7 +169,7 @@ def run(self): app.add_url_rule( "/list_modules", view_func=self.process_list_modules, methods=["GET"] ) - app.run(host="0.0.0.0", port=2300, debug=False) + app.run(host=flask_address, port=flask_port) def stop(self): exit() diff --git a/servicemodules/serviceConstants.py b/servicemodules/serviceConstants.py index 7bbc1dd..b000fb1 100644 --- a/servicemodules/serviceConstants.py +++ b/servicemodules/serviceConstants.py @@ -26,3 +26,10 @@ def __hash__(self): default_italics_mark = "*" + + +def italicise(text: str, message) -> str: + if not text.strip(): + return text + im = service_italics_marks.get(message.service, default_italics_mark) + return f'{im}{text}{im}' diff --git a/stam.py b/stam.py index d00d2ec..584fb5a 100644 --- a/stam.py +++ b/stam.py @@ -44,12 +44,11 @@ def get_stampy_modules() -> dict[str, Module]: loaded_module_filenames = set() # filenames of modules that were skipped because not enabled - skipped_module_filenames = set(ALL_STAMPY_MODULES - enabled_modules) + skipped_module_filenames = ALL_STAMPY_MODULES - enabled_modules + if invalid_modules := enabled_modules - ALL_STAMPY_MODULES: + raise AssertionError(f"Non existent modules enabled!: {', '.join(invalid_modules)}") for filename in enabled_modules: - if filename not in ALL_STAMPY_MODULES: - raise AssertionError(f"Module {filename} enabled but doesn't exist!") - log.info("import", filename=filename) mod = __import__(f"modules.{filename}", fromlist=[filename]) log.info("import", module_name=mod) @@ -60,7 +59,7 @@ def get_stampy_modules() -> dict[str, Module]: # try instantiating it if it is a `Module`... if isinstance(cls, type) and issubclass(cls, Module) and cls is not Module: log.info("import Module Found", module_name=attr_name) - # unless it has a classmethod is_available, which in this particular situation returns False + # unless it has a staticmethod is_available, which in this particular situation returns False if ( (is_available := getattr(cls, "is_available", None)) and callable(is_available) diff --git a/utilities/flaskutils.py b/utilities/flaskutils.py index 782a45d..feeb3f4 100644 --- a/utilities/flaskutils.py +++ b/utilities/flaskutils.py @@ -1,9 +1,10 @@ -from servicemodules.serviceConstants import Services -from utilities.serviceutils import ServiceUser, ServiceServer, ServiceChannel, ServiceMessage -from typing import TYPE_CHECKING +import json import threading import time - +from typing import TYPE_CHECKING +from utilities.serviceutils import ServiceUser, ServiceServer, ServiceChannel, ServiceMessage +from servicemodules.serviceConstants import Services +from servicemodules.discordConstants import wiki_feed_channel_id if TYPE_CHECKING: from servicemodules.flask import FlaskHandler @@ -47,8 +48,7 @@ def kill_thread(event: threading.Event, thread: "FlaskHandler"): class FlaskUser(ServiceUser): def __init__(self, key: str): - id = str(key) - super().__init__("User", "User", id) + super().__init__("User", "User", str(key)) class FlaskServer(ServiceServer): @@ -59,15 +59,38 @@ def __init__(self, key: str): class FlaskChannel(ServiceChannel): - def __init__(self, server: FlaskServer): - super().__init__("Web Interface", "flask_api", server) + def __init__(self, server: FlaskServer, channel=None): + super().__init__("Web Interface", channel or "flask_api", server) class FlaskMessage(ServiceMessage): - def __init__(self, msg): - self._message = msg - server = FlaskServer(msg["key"]) - id = str(time.time()) - service = Services.FLASK - super().__init__(id, msg["content"], FlaskUser(msg["key"]), FlaskChannel(server), service) - self.modules = msg["modules"] + + @staticmethod + def from_dict(data): + key = data.get('key') + if not key: + raise ValueError('No key provided') + + # FIXME: A very hacky way of allowing HTTP requests to claim to come from stampy + author = data.get('author') + if author == 'stampy': + author = FlaskUser(wiki_feed_channel_id) + else: + author = FlaskUser(key) + + modules = data.get('modules') + if not modules: + raise ValueError('No modules provided') + if isinstance(modules, str): + modules = json.loads(modules) + + msg = FlaskMessage( + content=data['content'], + service=Services.FLASK, + author=author, + channel=FlaskChannel(FlaskServer(key), data.get('channel')), + id=str(time.time()), + ) + msg.modules = modules + msg.clean_content = msg.content + return msg diff --git a/utilities/utilities.py b/utilities/utilities.py index ad1880e..f0223ee 100644 --- a/utilities/utilities.py +++ b/utilities/utilities.py @@ -135,13 +135,14 @@ def stampy_is_author(self, message: ServiceMessage) -> bool: return self.is_stampy(message.author) def is_stampy(self, user: ServiceUser) -> bool: - if ( - user.id == wiki_feed_channel_id - ): # consider wiki-feed ID as stampy to ignore -- is it better to set a wiki user? + if not user: + return False + # consider wiki-feed ID as stampy to ignore -- is it better to set a wiki user? + if user.id == wiki_feed_channel_id: return True if self.discord_user: return user == self.discord_user - if user.id == str(cast(discord.ClientUser, self.client.user).id): + if self.client.user and user.id == str(cast(discord.ClientUser, self.client.user).id): self.discord_user = user return True return False