Stampy chat module #325
""" | ||
Queries chat.stampy.ai with the user's question. | ||
|
||
""" | ||
|
||
import json | ||
import re | ||
from collections import deque, defaultdict | ||
from typing import Iterable, List, Dict, Any | ||
from uuid import uuid4 | ||
|
||
import requests | ||
from structlog import get_logger | ||
|
||
from modules.module import Module, Response | ||
from servicemodules.serviceConstants import italicise | ||
from utilities.serviceutils import ServiceChannel, ServiceMessage | ||
from utilities.utilities import Utilities | ||
|
||
log = get_logger() | ||
utils = Utilities.get_instance() | ||
|
||
|
||
LOG_MAX_MESSAGES = 15 # don't store more than X messages back | ||
DATA_HEADER = 'data: ' | ||
|
||
STAMPY_CHAT_ENDPOINT = "https://chat.stampy.ai:8443/chat" | ||
NLP_SEARCH_ENDPOINT = "https://nlp.stampy.ai" | ||
|
||
STAMPY_ANSWER_MIN_SCORE = 0.75 | ||
STAMPY_CHAT_MIN_SCORE = 0.4 | ||

Review comment: NLP must return something with at least this score for the module to do anything. It should filter out most messages, but might need to be fiddled with a bit. Or maybe it would be worth there also being an explicit way of triggering this module with a specific phrase or something?

Review comment: The chatbot server returns messages as server-sent events, so these two functions basically transform a raw byte stream into a stream of parsed JSON messages.

def stream_lines(stream: Iterable):
    """Reassemble a stream of byte chunks into complete text lines."""
    line = ''
    for item in stream:
        item = item.decode('utf8')
        line += item
        if '\n' in line:
            lines = line.split('\n')
            line = lines[-1]
            for l in lines[:-1]:
                yield l
    yield line  # whatever is left once the stream ends


def parse_data_items(stream: Iterable):
    """Yield the JSON payload of each 'data: ...' server-sent event line."""
    for item in stream:
        if item.strip().startswith(DATA_HEADER):
            yield json.loads(item.split(DATA_HEADER, 1)[1])
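
A minimal sketch of how the two helpers compose; the byte chunks below are made up and only illustrate the 'data: ...' server-sent event format this code expects:

raw_chunks = [
    b'data: {"state": "streaming", "content": "Hel',
    b'lo"}\n',
    b'data: {"state": "citations", "citations": []}\n',
]
for payload in parse_data_items(stream_lines(iter(raw_chunks))):
    print(payload['state'])  # -> streaming, then citations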


def top_nlp_search(query: str) -> Dict[str, Any]:
    """Return the top semantic search hit for the query, or {} when there are none."""
    resp = requests.get(NLP_SEARCH_ENDPOINT + '/api/search', params={'query': query, 'status': 'all'})
    if not resp:
        return {}

    items = resp.json()
    if not items:
        return {}
    return items[0]
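
The search response shape isn't documented in this file; the fields sketched below are inferred from how the hit is used in process_message (score, status, url, title), so treat them as an assumption rather than the full API contract:

# hypothetical single hit from /api/search
# {'score': 0.83, 'status': 'Live on site',
#  'title': 'What is AI alignment?', 'url': 'https://aisafety.info/?state=1234'}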


Review comment: Discord has a limit of 2000 characters per message (or at least other places in the code claim this), so this function splits the LLM's answer into smaller chunks, splitting on full stops to make sure that sentences don't get chopped up. Though maybe newlines would be better, so it doesn't split on decimal points?

def chunk_text(text: str, chunk_limit=2000, delimiter='.'):
    chunk = ''
    for sentence in text.split(delimiter):
        if len(chunk + sentence) + 1 >= chunk_limit and chunk and sentence:
            yield chunk
            chunk = sentence + delimiter
        elif sentence:
            chunk += sentence + delimiter
    yield chunk  # the final, possibly undersized, chunk
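
A quick sketch of the chunking behaviour with a deliberately tiny limit (the real default is 2000):

print(list(chunk_text('First sentence. Second one. Third.', chunk_limit=20)))
# -> ['First sentence.', ' Second one. Third.']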


Review comment: the chatbot returns a whole bunch of potential citations, but not all of them will be referenced in the text. This will remove the unused ones.

def filter_citations(text, citations):
    # collect every reference letter used in '[a]' or '[a, b]' style markers
    used_citations = re.findall(r'[a-z]', ' '.join(re.findall(r'\[([a-z, ]+)\]', text)))
    return [c for c in citations if c.get('reference') in used_citations]
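
A hypothetical illustration: only the citation actually referenced as [a] in the answer text survives the filter.

answer = 'RLHF has known failure modes [a].'
cites = [{'reference': 'a', 'title': 'Paper A', 'url': 'https://example.com/a'},
         {'reference': 'b', 'title': 'Paper B', 'url': 'https://example.com/b'}]
print([c['reference'] for c in filter_citations(answer, cites)])
# -> ['a']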


class StampyChat(Module):

    def __init__(self):
        self.utils = Utilities.get_instance()
        # per-channel log of recent messages, used as the chat history
        self._messages: dict[ServiceChannel, deque[ServiceMessage]] = defaultdict(lambda: deque(maxlen=LOG_MAX_MESSAGES))
        self.session_id = str(uuid4())
        super().__init__()

    @property
    def class_name(self):
        return 'stampy_chat'

    def format_message(self, message: ServiceMessage):
        return {
            'content': message.content,
            'role': 'assistant' if self.utils.stampy_is_author(message) else 'user',
        }
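
The role mapping is what lets the chat endpoint reconstruct the conversation; for example, a reply Stampy itself sent earlier would (hypothetically) serialise as:

# {'content': 'Happy to help!', 'role': 'assistant'}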

    def stream_chat_response(self, query: str, history: List[ServiceMessage]):
        return parse_data_items(stream_lines(requests.post(STAMPY_CHAT_ENDPOINT, stream=True, json={
            'query': query,
            'history': [self.format_message(m) for m in history],
            'sessionId': self.session_id,
            'settings': {'mode': 'discord'},
        })))

    def get_chat_response(self, query: str, history: List[ServiceMessage]):
        response = {'citations': [], 'content': '', 'followups': []}
        for item in self.stream_chat_response(query, history):
            if item.get('state') == 'citations':
                response['citations'] += item.get('citations', [])
            elif item.get('state') == 'streaming':
                response['content'] += item.get('content', '')
            elif item.get('state') == 'followups':
                response['followups'] += item.get('followups', [])
        response['citations'] = filter_citations(response['content'], response['citations'])
        return response
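
Once fully accumulated, the response would (hypothetically, based on the fields consumed by query below) look something like:

# {'content': 'RLHF is a training technique... [a]',
#  'citations': [{'reference': 'a', 'title': 'Paper A', 'url': 'https://example.com/a'}],
#  'followups': [{'text': 'What is RLHF?', 'pageid': '8EL7'}]}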

    async def query(self, query: str, history: List[ServiceMessage], message: ServiceMessage):
        log.info('calling %s', query)
        chat_response = self.get_chat_response(query, history)
        content_chunks = list(chunk_text(chat_response['content']))
        citations = [f'[{c["reference"]}] - {c["title"]} ({c["url"]})' for c in chat_response['citations'] if c.get('reference')]
        if citations:
            citations = ['Citations: \n' + '\n'.join(citations)]
        followups = []
        if follows := chat_response['followups']:
            followups = [
                'Check out these articles for more info: \n' + '\n'.join(
                    f'{f["text"]} - https://aisafety.info?state={f["pageid"]}' for f in follows
                )
            ]

        log.info('response: %s', content_chunks + citations + followups)
        return Response(
            confidence=10,
            text=[italicise(text, message) for text in content_chunks + citations + followups],
            why='This is what the chat bot returned',
        )

    def _add_message(self, message: ServiceMessage) -> deque[ServiceMessage]:
        self._messages[message.channel].append(message)
        return self._messages[message.channel]

    def make_query(self, messages):
        """Merge the trailing run of same-author messages into a single query string."""
        if not messages:
            return '', messages

        current = messages[-1]
        query, history = '', list(messages)
        while history:
            message = history.pop()
            if message.author != current.author:
                history.append(message)  # keep the other author's message in the history
                break
            query = message.content + ' ' + query
            current = message
        return query.strip(), history
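
A hypothetical illustration of the merging behaviour, using a stand-in for ServiceMessage (only the author and content attributes matter to make_query, so the method is called unbound here purely for the sketch):

from dataclasses import dataclass

@dataclass
class FakeMessage:
    author: str
    content: str

msgs = [FakeMessage('bot', 'Hi!'), FakeMessage('alice', 'How does'), FakeMessage('alice', 'RLHF work?')]
query, history = StampyChat.make_query(None, msgs)
print(query)    # -> 'How does RLHF work?'
print(history)  # -> [FakeMessage(author='bot', content='Hi!')]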

    def process_message(self, message: ServiceMessage) -> Response:
        history = self._add_message(message)

        if not self.is_at_me(message):
            return Response()

        query, history = self.make_query(history)
        nlp = top_nlp_search(query)
        if nlp.get('score', 0) > STAMPY_ANSWER_MIN_SCORE and nlp.get('status') == 'Live on site':
            return Response(confidence=5, text=f'Check out {nlp.get("url")} ({nlp.get("title")})')

Review comment: this will also get picked up by the semanticsearch module, which has a lower threshold (0.5) and higher confidence (8).

        if nlp.get('score', 0) > STAMPY_CHAT_MIN_SCORE:
            return Response(confidence=6, callback=self.query, args=[query, history, message])

Review comment: is this the right confidence?

        return Response()

    def process_message_from_stampy(self, message: ServiceMessage):
        self._add_message(message)

Review comment: these are Discord messages, counted per channel.