Skip to content

Commit

Permalink
Session ids
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Sep 29, 2023
1 parent a5fdee2 commit 382430d
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 18 deletions.
16 changes: 9 additions & 7 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from flask import Flask, jsonify, request, Response
from flask_cors import CORS, cross_origin
from urllib.parse import unquote
import dataclasses
import json
import re
from urllib.parse import unquote

from flask import Flask, jsonify, request, Response
from flask_cors import CORS, cross_origin

from stampy_chat import logging
from stampy_chat.env import PINECONE_INDEX, FLASK_PORT
Expand Down Expand Up @@ -42,11 +43,12 @@ def semantic():
@cross_origin()
def chat():

query = request.json['query']
mode = request.json['mode']
history = request.json['history']
query = request.json.get('query')
mode = request.json.get('mode', 'default')
session_id = request.json.get('sessionId')
history = request.json.get('history', [])

return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history)), mimetype='text/event-stream')
return Response(stream(talk_to_robot(PINECONE_INDEX, query, mode, history, session_id)), mimetype='text/event-stream')


# ------------- simplified non-streaming chat for internal testing -------------
Expand Down
12 changes: 9 additions & 3 deletions api/src/stampy_chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,19 @@ def remaining_tokens(prompt: Prompt):
return NUM_TOKENS - used_tokens - TOKENS_BUFFER


def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int = STANDARD_K):
def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, session_id: str, k: int = STANDARD_K):
try:
# 1. Find the most relevant blocks from the Alignment Research Dataset
yield {"state": "loading", "phase": "semantic"}
top_k_blocks = get_top_k_blocks(index, query, k)

yield {"state": "loading", "phase": "semantic", 'citations': [{'title': block.title, 'author': block.authors, 'date': block.date, 'url': block.url} for block in top_k_blocks]}
yield {
"state": "loading", "phase": "semantic",
"citations": [
{'title': block.title, 'author': block.authors, 'date': block.date, 'url': block.url}
for block in top_k_blocks
]
}

# 2. Generate a prompt
yield {"state": "loading", "phase": "prompt"}
Expand Down Expand Up @@ -205,7 +211,7 @@ def talk_to_robot_internal(index, query: str, mode: str, history: Prompt, k: int
logger.debug(' ------------------------------ response: -----------------------------')
logger.debug(response)

logger.interaction(query, response, history, prompt, top_k_blocks)
logger.interaction(session_id, query, response, history, prompt, top_k_blocks)

# yield done state, possibly with followup questions
fin_json = {'state': 'done'}
Expand Down
4 changes: 2 additions & 2 deletions api/src/stampy_chat/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def __init__(self, *args, **kwargs):
def is_debug(self):
return self.isEnabledFor(DEBUG)

def interaction(self, query, response, history, prompt, blocks):
def interaction(self, session_id, query, response, history, prompt, blocks):
prompt = [i for i in prompt if i.get('role') == 'system']
prompt = prompt[0].get('content') if prompt else None

self.item_adder.add(
Interaction(
# session_id=session_id,
session_id=session_id,
interaction_no=len([i for i in history if i.get('role') == 'user']),
query=query,
prompt=prompt,
Expand Down
14 changes: 9 additions & 5 deletions web/src/hooks/useSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ export const extractAnswer = async (
};

const fetchLLM = async (
sessionId: string,
query: string,
mode: string,
history: HistoryEntry[]
Expand All @@ -110,18 +111,19 @@ const fetchLLM = async (
Accept: "text/event-stream",
},

body: JSON.stringify({ query, mode, history }),
body: JSON.stringify({ sessionId, query, mode, history }),
});

export const queryLLM = async (
query: string,
mode: string,
history: HistoryEntry[],
baseReferencesIndex: number,
setCurrent: (e?: CurrentSearch) => void
setCurrent: (e?: CurrentSearch) => void,
sessionId: string
): Promise<SearchResult> => {
// do SSE on a POST request.
const res = await fetchLLM(query, mode, history);
const res = await fetchLLM(sessionId, query, mode, history);

if (!res.ok) {
return { result: { role: "error", content: "POST Error: " + res.status } };
Expand Down Expand Up @@ -191,7 +193,8 @@ export const runSearch = async (
mode: string,
baseReferencesIndex: number,
entries: Entry[],
setCurrent: (c: CurrentSearch) => void
setCurrent: (c: CurrentSearch) => void,
sessionId: string
): SearchResult => {
if (query_source === "search") {
const history = entries
Expand All @@ -206,7 +209,8 @@ export const runSearch = async (
mode,
history,
baseReferencesIndex,
setCurrent
setCurrent,
sessionId
);
} else {
// ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------
Expand Down
6 changes: 5 additions & 1 deletion web/src/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,22 @@ const Home: NextPage = () => {
const [entries, setEntries] = useState<Entry[]>([]);
const [runningIndex, setRunningIndex] = useState(0);
const [current, setCurrent] = useState<CurrentSearch>();
const [sessionId, setSessionId] = useState()

// [state, ready to save to localstorage]
const [mode, setMode] = useState<[Mode, boolean]>(["default", false]);

// store mode in localstorage
useEffect(() => {
if (mode[1]) localStorage.setItem("chat_mode", mode[0]);

}, [mode]);

// initial load
useEffect(() => {
const mode = localStorage.getItem("chat_mode") as Mode || "default";
setMode([mode, true]);
setSessionId(crypto.randomUUID());
}, []);

const updateCurrent = (current: CurrentSearch) => {
Expand Down Expand Up @@ -88,7 +91,8 @@ const Home: NextPage = () => {
mode[0],
runningIndex,
entries,
updateCurrent
updateCurrent,
sessionId,
);
setCurrent(undefined);

Expand Down

0 comments on commit 382430d

Please sign in to comment.