Stampy chat module #325

Merged · 6 commits · Nov 12, 2023

README.md (2 changes: 1 addition & 1 deletion)

@@ -61,14 +61,14 @@ You'll need at least these:
 - `DISCORD_GUILD`: your server ID
 - `DATABASE_PATH`: the path to the Q&A database (normally in `./database/stampy.db`).
 - `STAMPY_MODULES`: list of your desired modules, or leave unset to load all modules in the `./modules/` directory. You probably don't want all of them, as some aren't applicable to servers other than Rob's.
+- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates and info are sent

 Not required:

 - `BOT_VIP_IDS`: list of user IDs. VIPs have full access and some special permissions.
 - `BOT_DEV_ROLES`: list of roles representing bot devs.
 - `BOT_DEV_IDS`: list of user IDs of bot devs. You may want to include `BOT_VIP_IDS` here.
 - `BOT_CONTROL_CHANNEL_IDS`: list of channels where control commands are accepted.
-- `BOT_PRIVATE_CHANNEL_ID`: single channel where private Stampy status updates are sent
 - `BOT_ERROR_CHANNEL_ID`: (defaults to private channel) low-level error tracebacks from Python; with this variable set they can be shunted to a separate channel.
 - `CODA_API_TOKEN`: token to access Coda. Without it, modules `Questions` and `QuestionSetter` will not be available and `StampyControls` will have limited functionality.
 - `BOT_REBOOT`: how Stampy reboots himself. Unset, he only quits, expecting an external `while true` loop (like in `runstampy`/Dockerfile). Set to `exec` he will try to relaunch himself from his own CLI arguments.
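
For illustration, a minimal `.env` sketch using only the variables documented above (the IDs are placeholders):

```sh
DISCORD_GUILD="123456789012345678"
DATABASE_PATH="./database/stampy.db"
# space-separated list; leave unset to load everything in ./modules/
STAMPY_MODULES="stampy_chat"
BOT_PRIVATE_CHANNEL_ID="123456789012345678"
```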

config.py (34 changes: 19 additions & 15 deletions)

@@ -1,5 +1,6 @@
 import os
 from typing import Literal, TypeVar, Optional, Union, cast, get_args, overload, Any, Tuple
+from pathlib import Path

 import dotenv
 from structlog import get_logger
@@ -10,16 +11,15 @@
 dotenv.load_dotenv()
 NOT_PROVIDED = "__NOT_PROVIDED__"

-module_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modules")
+module_dir = Path(__file__).parent / 'modules'


 def get_all_modules() -> frozenset[str]:
-    modules = set()
-    for file_name in os.listdir(module_dir):
-        if file_name.endswith(".py") and file_name not in ("__init__.py", "module.py"):
-            modules.add(file_name[:-3])
-
-    return frozenset(modules)
+    return frozenset({
+        filename.stem
+        for filename in module_dir.glob('*.py')
+        if filename.suffix == '.py' and filename.name not in ('__init__.py', 'module.py')
+    })


 ALL_STAMPY_MODULES = get_all_modules()
@@ -47,8 +47,7 @@ def getenv(env_var: str, default = NOT_PROVIDED) -> str:


 def getenv_bool(env_var: str) -> bool:
-    e = getenv(env_var, default="UNDEFINED")
-    return e != "UNDEFINED"
+    return getenv(env_var, default="UNDEFINED") != "UNDEFINED"


# fmt:off
@@ -64,12 +63,12 @@ def getenv_unique_set(var_name: str, default: T) -> Union[frozenset[str], T]: ...


 def getenv_unique_set(var_name: str, default: T = frozenset()) -> Union[frozenset, T]:
-    l = getenv(var_name, default="EMPTY_SET").split(" ")
-    if l == ["EMPTY_SET"]:
+    var = getenv(var_name, default='')
+    if not var.strip():
         return default
-    s = frozenset(l)
-    assert len(l) == len(s), f"{var_name} has duplicate members! {l}"
-    return s
+    items = var.split()
+    assert len(items) == len(set(items)), f"{var_name} has duplicate members! {sorted(items)}"
+    return frozenset(items)
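
A quick sketch of the new parsing behavior (illustrative; assumes `config.py` is imported and the variable is set in the environment):

```python
import os

os.environ["BOT_VIP_IDS"] = "123 456"
assert getenv_unique_set("BOT_VIP_IDS") == frozenset({"123", "456"})

os.environ["BOT_VIP_IDS"] = "   "  # blank still falls back to the default
assert getenv_unique_set("BOT_VIP_IDS") == frozenset()

os.environ["BOT_VIP_IDS"] = "123 123"
# getenv_unique_set("BOT_VIP_IDS") would now raise:
# AssertionError: BOT_VIP_IDS has duplicate members! ['123', '123']
```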


maximum_recursion_depth = 30
@@ -150,6 +149,11 @@
 channel_whitelist: Optional[frozenset[str]]
 disable_prompt_moderation: bool

+## Flask settings
+if flask_port := getenv('FLASK_PORT', '2300'):
+    flask_port = int(flask_port)
+flask_address = getenv('FLASK_ADDRESS', "0.0.0.0")
+
 is_rob_server = getenv_bool("IS_ROB_SERVER")
 if is_rob_server:
     # use robmiles server defaults
@@ -233,7 +237,7 @@
 bot_dev_roles = getenv_unique_set("BOT_DEV_ROLES", frozenset())
 bot_dev_ids = getenv_unique_set("BOT_DEV_IDS", frozenset())
 bot_control_channel_ids = getenv_unique_set("BOT_CONTROL_CHANNEL_IDS", frozenset())
-bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID", '')
+bot_private_channel_id = getenv("BOT_PRIVATE_CHANNEL_ID")
 bot_error_channel_id = getenv("BOT_ERROR_CHANNEL_ID", bot_private_channel_id)
 # NOTE: Rob's invite/member management functions, not ported yet
 member_role_id = getenv("MEMBER_ROLE_ID", default=None)

modules/module.py (7 changes: 3 additions & 4 deletions)

@@ -243,6 +243,8 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]:
r",? @?[sS](tampy)?(?P<punctuation>[.!?]*)$", r"\g<punctuation>", text
)
at_me = True
elif re.search(r'^[sS]tamp[ys]?\?', text):
at_me = True

if message.is_dm:
# DMs are always at you
@@ -255,10 +257,7 @@ def is_at_me(self, message: ServiceMessage) -> Union[str, Literal[False]]:
             )
             at_me = True

-        if at_me:
-            return text
-        else:
-            return False
+        return at_me and text

     def get_guild_and_invite_role(self):
         return get_guild_and_invite_role()
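
A quick check of what the new trigger pattern does and doesn't match (illustrative):

```python
import re

for text in ("stampy? what do you think", "Stamps? how do I get them", "timestamp?"):
    print(text, "->", bool(re.search(r'^[sS]tamp[ys]?\?', text)))
# stampy? what do you think -> True
# Stamps? how do I get them -> True
# timestamp? -> False (the pattern is anchored to the start of the message)
```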

modules/stampy_chat.py (172 changes: 172 additions & 0 deletions, new file)

@@ -0,0 +1,172 @@
"""
Queries chat.stampy.ai with the user's question.

"""

import json
import re
from collections import deque, defaultdict
from typing import Iterable, List, Dict, Any
from uuid import uuid4

import requests
from structlog import get_logger

from modules.module import Module, Response
from servicemodules.serviceConstants import italicise
from utilities.serviceutils import ServiceChannel, ServiceMessage
from utilities.utilities import Utilities

log = get_logger()
utils = Utilities.get_instance()


LOG_MAX_MESSAGES = 15 # don't store more than X messages back
Comment (Collaborator Author): these are Discord messages, counted per channel
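
The per-channel buffer backing this is a `deque` with `maxlen` (see `__init__` below), so the oldest messages drop off automatically; a toy demonstration:

```python
from collections import defaultdict, deque

messages = defaultdict(lambda: deque(maxlen=3))  # 3 instead of 15, for brevity
for i in range(5):
    messages['some-channel'].append(f'message {i}')
print(list(messages['some-channel']))  # ['message 2', 'message 3', 'message 4']
```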

DATA_HEADER = 'data: '

STAMPY_CHAT_ENDPOINT = "https://chat.stampy.ai:8443/chat"
NLP_SEARCH_ENDPOINT = "https://nlp.stampy.ai"

STAMPY_ANSWER_MIN_SCORE = 0.75
STAMPY_CHAT_MIN_SCORE = 0.4
Comment (Collaborator Author): NLP must return something with at least this score for the module to do anything. Should filter out most messages, but might need to be fiddled with a bit. Or maybe it would also be worth having an explicit way of triggering this module with a specific phrase or something?

Comment (Collaborator): I suspect STAMPY_CHAT_MIN_SCORE = 0.4 can actually be an even lower value (0.2?), but you'll want to experiment.



Comment (Collaborator Author): The chatbot server returns messages as server-sent events, so these two functions transform a requests stream into a generator of JSON objects ready for use.

def stream_lines(stream: Iterable):
    line = ''
    for item in stream:
        item = item.decode('utf8')
        line += item
        if '\n' in line:
            lines = line.split('\n')
            line = lines[-1]
            for l in lines[:-1]:
                yield l
    yield line


def parse_data_items(stream: Iterable):
    for item in stream:
        if item.strip().startswith(DATA_HEADER):
            yield json.loads(item.split(DATA_HEADER)[1])
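
A toy run of these two helpers on fake byte chunks shaped like the server's SSE output (payloads invented for illustration):

```python
chunks = [
    b'data: {"state": "streaming", "content"',
    b': "Hello"}\n',
    b'data: {"state": "followups", "followups": []}\n',
]
for obj in parse_data_items(stream_lines(chunks)):
    print(obj['state'])
# streaming
# followups
```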


Comment (mruwnik, Nov 1, 2023): @ccstan99 this 'status': 'all' does the same as showLive=0

def top_nlp_search(query: str) -> Dict[str, Any]:
    resp = requests.get(NLP_SEARCH_ENDPOINT + '/api/search', params={'query': query, 'status': 'all'})
    if not resp:
        return {}

    items = resp.json()
    if not items:
        return {}
    return items[0]
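
A usage sketch; the response shape is inferred from how the result is consumed in `process_message` below, so treat the field names as assumptions:

```python
hit = top_nlp_search('what is instrumental convergence?')
# e.g. {'score': 0.82, 'status': 'Live on site', 'url': '...', 'title': '...'}
if hit.get('score', 0) > STAMPY_ANSWER_MIN_SCORE:
    print(hit.get('url'), hit.get('title'))
```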


Comment (Collaborator Author): Discord has a limit of 2000 characters per message (or at least other places in the code claim this), so this function splits the LLM's answer into smaller chunks, breaking on full stops so that sentences don't get chopped up. Though maybe newlines would be better, so it doesn't split on decimal points?

def chunk_text(text: str, chunk_limit=2000, delimiter='.'):
    chunk = ''
    for sentence in text.split(delimiter):
        if len(chunk + sentence) + 1 >= chunk_limit and chunk and sentence:
            yield chunk
            chunk = sentence + delimiter
        elif sentence:
            chunk += sentence + delimiter
    yield chunk
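
For example, with a deliberately tiny limit:

```python
text = "First sentence. Second sentence. Third."
print(list(chunk_text(text, chunk_limit=25)))
# ['First sentence.', ' Second sentence. Third.']
```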


Comment (Collaborator Author): the chatbot returns a whole bunch of potential citations, but not all of them will be referenced in the text. This will remove the unused ones.

def filter_citations(text, citations):
    # collect the reference letters actually used in the text, e.g. "[a]" or "[a, c]"
    groups = re.findall(r'\[([a-z](?:, ?[a-z])*)\]', text)
    used_citations = {ref.strip() for group in groups for ref in group.split(',')}
    return [c for c in citations if c.get('reference') in used_citations]
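
For instance (citation dicts trimmed to the one field that matters here):

```python
citations = [{'reference': 'a'}, {'reference': 'b'}, {'reference': 'c'}]
text = 'Deception is one failure mode [a]. Power-seeking is another [a, c].'
print(filter_citations(text, citations))
# [{'reference': 'a'}, {'reference': 'c'}] - 'b' was never cited, so it is dropped
```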


class StampyChat(Module):

    def __init__(self):
        self.utils = Utilities.get_instance()
        self._messages: dict[ServiceChannel, deque[ServiceMessage]] = defaultdict(lambda: deque(maxlen=LOG_MAX_MESSAGES))
        self.session_id = str(uuid4())
        super().__init__()

    @property
    def class_name(self):
        return 'stampy_chat'

    def format_message(self, message: ServiceMessage):
        return {
            'content': message.content,
            'role': 'assistant' if self.utils.stampy_is_author(message) else 'user',
        }

    def stream_chat_response(self, query: str, history: List[ServiceMessage]):
        return parse_data_items(stream_lines(requests.post(STAMPY_CHAT_ENDPOINT, stream=True, json={
            'query': query,
            'history': [self.format_message(m) for m in history],
            'sessionId': self.session_id,
            'settings': {'mode': 'discord'},
        })))

    def get_chat_response(self, query: str, history: List[ServiceMessage]):
        response = {'citations': [], 'content': '', 'followups': []}
        for item in self.stream_chat_response(query, history):
            if item.get('state') == 'citations':
                response['citations'] += item.get('citations', [])
            elif item.get('state') == 'streaming':
                response['content'] += item.get('content', '')
            elif item.get('state') == 'followups':
                response['followups'] += item.get('followups', [])
        response['citations'] = filter_citations(response['content'], response['citations'])
        return response

    async def query(self, query: str, history: List[ServiceMessage], message: ServiceMessage):
        log.info('calling %s', query)
        chat_response = self.get_chat_response(query, history)
        content_chunks = list(chunk_text(chat_response['content']))
        citations = [f'[{c["reference"]}] - {c["title"]} ({c["url"]})' for c in chat_response['citations'] if c.get('reference')]
        if citations:
            citations = ['Citations: \n' + '\n'.join(citations)]
        followups = []
        if follows := chat_response['followups']:
            followups = [
                'Check out these articles for more info: \n' + '\n'.join(
                    f'{f["text"]} - https://aisafety.info?state={f["pageid"]}' for f in follows
                )
            ]

        log.info('response: %s', content_chunks + citations + followups)
        return Response(
            confidence=10,
            text=[italicise(text, message) for text in content_chunks + citations + followups],
            why='This is what the chat bot returned'
        )

    def _add_message(self, message: ServiceMessage) -> deque[ServiceMessage]:
        self._messages[message.channel].append(message)
        return self._messages[message.channel]

    def make_query(self, messages):
        if not messages:
            return '', messages

        current = messages[-1]
        query, history = '', list(messages)
        # walk back over the trailing run of messages from the same author,
        # merging them into a single query; everything earlier stays in history
        while history and history[-1].author == current.author:
            current = history.pop()
            query = current.content + ' ' + query
        return query, history
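
A sketch of the merging behavior with stand-in message objects (`make_query` never touches `self`, so plain namespaces are enough to illustrate):

```python
from types import SimpleNamespace

msgs = [
    SimpleNamespace(author='alice', content='hey Stampy'),
    SimpleNamespace(author='bob', content='what is'),
    SimpleNamespace(author='bob', content='AI safety?'),
]
query, history = StampyChat.make_query(None, msgs)
print(repr(query))   # 'what is AI safety? ' - bob's consecutive messages, oldest first
print(len(history))  # 1 - alice's earlier message stays in the history
```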

    def process_message(self, message: ServiceMessage) -> Response:
        history = self._add_message(message)

        if not self.is_at_me(message):
            return Response()

        query, history = self.make_query(history)
        nlp = top_nlp_search(query)
        if nlp.get('score', 0) > STAMPY_ANSWER_MIN_SCORE and nlp.get('status') == 'Live on site':
            return Response(confidence=5, text=f'Check out {nlp.get("url")} ({nlp.get("title")})')
Comment (Collaborator Author): this will also get picked up by the semanticsearch module, which has a lower threshold (0.5) and higher confidence (8)

        if nlp.get('score', 0) > STAMPY_CHAT_MIN_SCORE:
            return Response(confidence=6, callback=self.query, args=[query, history, message])
Comment (Collaborator Author): is this the right confidence?

        return Response()

    def process_message_from_stampy(self, message: ServiceMessage):
        self._add_message(message)

servicemodules/discord.py (4 changes: 3 additions & 1 deletion)

@@ -47,13 +47,15 @@


 # TODO: store long responses temporarily for viewing outside of discord
-def limit_text_and_notify(response: Response, why_traceback: list[str]) -> str:
+def limit_text_and_notify(response: Response, why_traceback: list[str]) -> Union[str, Iterable]:
     if isinstance(response.text, str):
         wastrimmed = False
         wastrimmed, text_to_return = limit_text(response.text, discordLimit)
         if wastrimmed:
             why_traceback.append(f"I had to trim the output from {response.module}")
         return text_to_return
+    elif isinstance(response.text, (list, tuple)):
+        return response.text
     return ""

