Merge pull request #114 from StampyAI/chat-options

Chat options

mruwnik authored Oct 7, 2023
2 parents 1b43c28 + e4f60fb commit 3a502ca

Showing 7 changed files with 255 additions and 157 deletions.
22 changes: 15 additions & 7 deletions api/src/stampy_chat/settings.py
@@ -1,8 +1,12 @@
+from collections import namedtuple
 import tiktoken
 
 from stampy_chat.env import COMPLETIONS_MODEL
 
 
+Model = namedtuple('Model', ['maxTokens', 'topKBlocks'])
+
+
 SOURCE_PROMPT = (
     "You are a helpful assistant knowledgeable about AI Alignment and Safety. "
     "Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
@@ -44,6 +48,12 @@
     'question': QUESTION_PROMPT,
     'modes': PROMPT_MODES,
 }
+MODELS = {
+    'gpt-3.5-turbo': Model(4097, 10),
+    'gpt-3.5-turbo-16k': Model(16385, 30),
+    'gpt-4': Model(8192, 20),
+    # 'gpt-4-32k': Model(32768, 30),
+}
 
 
 class Settings:
@@ -99,23 +109,21 @@ def encoder(self, value):
         self.encoders[value] = tiktoken.get_encoding(value)
 
     def set_completions(self, completions, numTokens=None, topKBlocks=None):
+        if completions not in MODELS:
+            raise ValueError(f'Unknown model: {completions}')
         self.completions = completions
 
-        # Set the max number of tokens sent in the prompt
+        # Set the max number of tokens sent in the prompt - see https://platform.openai.com/docs/models/gpt-4
         if numTokens is not None:
             self.numTokens = numTokens
-        elif completions == 'gtp-4':
-            self.numTokens = 8191
         else:
-            self.numTokens = 4095
+            self.numTokens = MODELS[completions].maxTokens
 
         # Set the max number of blocks used as citations
         if topKBlocks is not None:
             self.topKBlocks = topKBlocks
-        elif completions == 'gtp-4':
-            self.topKBlocks = 20
         else:
-            self.topKBlocks = 10
+            self.topKBlocks = MODELS[completions].topKBlocks
 
     @property
     def prompt_modes(self):
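
The new MODELS table replaces per-model if/elif chains; note that the old branches compared against the misspelled literal 'gtp-4', so a caller passing 'gpt-4' fell through to the 4095-token default. With the table, gpt-4 gets its real budget and unknown names fail fast. A minimal usage sketch (hypothetical, and assuming Settings() can be constructed with defaults):

```python
# Hypothetical usage sketch of the MODELS-driven defaults.
from stampy_chat.settings import Settings

settings = Settings()

settings.set_completions('gpt-4')
print(settings.numTokens, settings.topKBlocks)  # 8192 20, from MODELS['gpt-4']

settings.set_completions('gpt-3.5-turbo', numTokens=2048)
print(settings.numTokens)  # 2048; an explicit argument overrides the table

settings.set_completions('gpt-4-32k')  # ValueError: Unknown model: gpt-4-32k
                                       # (that entry is commented out in MODELS)
```
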
4 changes: 2 additions & 2 deletions api/tests/stampy_chat/test_chat.py
@@ -255,14 +255,14 @@ def test_check_openai_moderation_not_flagged():


 @pytest.mark.parametrize('prompt, remaining', (
-    ([{'role': 'system', 'content': 'bla'}], 4043),
+    ([{'role': 'system', 'content': 'bla'}], 4045),
     (
         [
             {'role': 'system', 'content': 'bla'},
             {'role': 'user', 'content': 'message 1'},
             {'role': 'assistant', 'content': 'response 1'},
         ],
-        4035
+        4037
     ),
     (
         [
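
The expected values rise by 2 because the gpt-3.5-turbo budget is now read from MODELS (4097 tokens) rather than the old hardcoded 4095. A sketch of the arithmetic; the 52-token prompt cost is inferred from the diff's own numbers, not measured:

```python
# Inferred reconstruction: remaining = model token budget - cost of the prompt so far.
old_budget, new_budget = 4095, 4097   # hardcoded before; MODELS['gpt-3.5-turbo'].maxTokens now
prompt_cost = old_budget - 4043       # 52 tokens: the system message plus fixed overhead
assert new_budget - prompt_cost == 4045   # matches the updated expectation
```
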
21 changes: 12 additions & 9 deletions web/src/components/chat.tsx
@@ -103,30 +103,33 @@ const Chat = ({ sessionId, settings, onQuery, onNewEntry }: ChatParams) => {
   const search = async (
     query: string,
     query_source: "search" | "followups",
-    disable: () => void,
-    enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void
+    enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void,
+    controller: AbortController
   ) => {
     // clear the query box, append to entries
     const userEntry: Entry = {
       role: "user",
       content: query_source === "search" ? query : query.split("\n", 2)[1]!,
     };
-    addEntry(userEntry);
-    disable();
 
     const { result, followups } = await runSearch(
       query,
       query_source,
       settings,
       entries,
       updateCurrent,
-      sessionId
+      sessionId,
+      controller
     );
+    if (result.content !== "aborted") {
+      addEntry(userEntry);
+      addEntry(result);
+      enable(followups || []);
+      scroll30();
+    } else {
+      enable([]);
+    }
     setCurrent(undefined);
-
-    addEntry(result);
-    enable(followups || []);
-    scroll30();
   };
 
   var last_entry = <></>;
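
The search flow now threads an AbortController through to the fetch layer and only commits the user entry and answer when the run was not cancelled; an aborted run comes back with the "aborted" sentinel and just re-enables the followups. A minimal sketch of the calling contract (the wiring here is hypothetical; searchbox.tsx below shows the real caller):

```typescript
// Hypothetical caller of Chat's search(): one controller per request.
const controller = new AbortController();

search(
  "What is AI alignment?",
  "search",
  (followups) => console.log("input re-enabled, followups:", followups),
  controller
);

// e.g. wired to a Cancel button: search() then sees
// result.content === "aborted" and discards the partial entry.
controller.abort();
```
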
65 changes: 35 additions & 30 deletions web/src/components/searchbox.tsx
@@ -32,8 +32,8 @@ const SearchBoxInternal: React.FC<{
   search: (
     query: string,
     query_source: "search" | "followups",
-    disable: () => void,
-    enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void
+    enable: (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => void,
+    controller: AbortController
   ) => void;
   onQuery?: (q: string) => any;
 }> = ({ search, onQuery }) => {
@@ -43,20 +43,21 @@ const SearchBoxInternal: React.FC<{
   const [query, setQuery] = useState(initial_query);
   const [loading, setLoading] = useState(false);
   const [followups, setFollowups] = useState<Followup[]>([]);
+  const [controller, setController] = useState(new AbortController());
 
   const inputRef = React.useRef<HTMLTextAreaElement>(null);
 
   // because everything is async, I can't just manually set state at the
   // point we do a search. Instead it needs to be passed into the search
   // method, for some reason.
-  const enable = (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => {
-    setLoading(false);
-    setFollowups(f_set);
-  };
-  const disable = () => {
-    setLoading(true);
-    setQuery("");
-  };
+  const enable =
+    (controller: AbortController) =>
+    (f_set: Followup[] | ((fs: Followup[]) => Followup[])) => {
+      if (!controller.signal.aborted) setQuery("");
+
+      setLoading(false);
+      setFollowups(f_set);
+    };
 
   useEffect(() => {
     // set focus on the input box
@@ -71,7 +72,17 @@
     inputRef.current.selectionEnd = inputRef.current.textLength;
   }, []);
 
-  if (loading) return <></>;
+  const runSearch =
+    (query: string, searchtype: "search" | "followups") => () => {
+      if (loading || query.trim() === "") return;
+
+      setLoading(true);
+      const controller = new AbortController();
+      setController(controller);
+      search(query, searchtype, enable(controller), controller);
+    };
+  const cancelSearch = () => controller.abort();
 
   return (
     <>
       <div className="mt-1 flex flex-col items-end">
@@ -81,14 +92,10 @@
             <li key={i}>
               <button
                 className="my-1 border border-gray-300 px-1"
-                onClick={() => {
-                  search(
-                    followup.pageid + "\n" + followup.text,
-                    "followups",
-                    disable,
-                    enable
-                  );
-                }}
+                onClick={runSearch(
+                  followup.pageid + "\n" + followup.text,
+                  "followups"
+                )}
               >
                 <span> {followup.text} </span>
               </button>
@@ -97,13 +104,7 @@
         })}
       </div>
 
-      <form
-        className="mt-1 mb-2 flex"
-        onSubmit={(e) => {
-          e.preventDefault();
-          search(query, "search", disable, enable);
-        }}
-      >
+      <div className="mt-1 mb-2 flex">
         <TextareaAutosize
           className="flex-1 resize-none border border-gray-300 px-1"
           ref={inputRef}
@@ -118,14 +119,18 @@
             // if <enter> without <shift>, submit the form (if it's not empty)
             if (e.key === "Enter" && !e.shiftKey) {
               e.preventDefault();
-              if (query.trim() !== "") search(query, "search", disable, enable);
+              runSearch(query, "search")();
             }
           }}
         />
-        <button className="ml-2" type="submit" disabled={loading}>
-          {loading ? "Loading..." : "Search"}
+        <button
+          className="ml-2"
+          type="button"
+          onClick={loading ? cancelSearch : runSearch(query, "search")}
+        >
+          {loading ? "Cancel" : "Search"}
         </button>
-      </form>
+      </div>
     </>
   );
 };
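
Two things are worth noting here. First, enable is now curried over the controller, so the completion handler can ask whether this particular search was aborted (and keep the user's typed query if it was), even after a newer controller has replaced it in state. Second, the <form> becomes a plain <div>, and the single button toggles between Search and Cancel, aborting the in-flight controller instead of submitting. A simplified sketch of the curried pattern (hypothetical names, outside React):

```typescript
// Sketch: bind the controller at search time, so the completion callback
// later asks "was *this* run aborted?", not "is the latest run aborted?".
const makeOnDone =
  (controller: AbortController) =>
  (followups: string[]) => {
    if (!controller.signal.aborted) {
      console.log("search finished, clearing the query box");
    } else {
      console.log("search was cancelled, keeping the query text");
    }
    console.log("loading off; followups:", followups);
  };

const controller = new AbortController();
const onDone = makeOnDone(controller); // captured before any abort
controller.abort();
onDone([]); // reports the cancelled case, query text preserved
```
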
50 changes: 38 additions & 12 deletions web/src/hooks/useSearch.ts
@@ -21,6 +21,12 @@ type HistoryEntry = {
   content: string;
 };
 
+const ignoreAbort = (error: Error) => {
+  if (error.name !== "AbortError") {
+    throw error;
+  }
+};
+
 export async function* iterateData(res: Response) {
   const reader = res.body!.getReader();
   var message = "";
@@ -104,9 +110,11 @@ const fetchLLM = async (
   sessionId: string,
   query: string,
   settings: LLMSettings,
-  history: HistoryEntry[]
-): Promise<Response> =>
+  history: HistoryEntry[],
+  controller: AbortController
+): Promise<Response | void> =>
   fetch(API_URL + "/chat", {
+    signal: controller.signal,
     method: "POST",
     cache: "no-cache",
     keepalive: true,
@@ -116,25 +124,31 @@
     },
 
     body: JSON.stringify({ sessionId, query, history, settings }),
-  });
+  }).catch(ignoreAbort);
 
 export const queryLLM = async (
   query: string,
   settings: LLMSettings,
   history: HistoryEntry[],
   setCurrent: (e?: CurrentSearch) => void,
-  sessionId: string
+  sessionId: string,
+  controller: AbortController
 ): Promise<SearchResult> => {
   // do SSE on a POST request.
-  const res = await fetchLLM(sessionId, query, settings, history);
+  const res = await fetchLLM(sessionId, query, settings, history, controller);
 
-  if (!res.ok) {
+  if (!res) {
+    return { result: { role: "error", content: "No response from server" } };
+  } else if (!res.ok) {
     return { result: { role: "error", content: "POST Error: " + res.status } };
   }
 
   try {
     return await extractAnswer(res, setCurrent);
   } catch (e) {
+    if ((e as Error)?.name === "AbortError") {
+      return { result: { role: "error", content: "aborted" } };
+    }
     return {
       result: { role: "error", content: e ? e.toString() : "unknown error" },
     };
@@ -149,17 +163,21 @@ const cleanStampyContent = (contents: string) =>
 );
 
 export const getStampyContent = async (
-  questionId: string
+  questionId: string,
+  controller: AbortController
 ): Promise<SearchResult> => {
   const res = await fetch(`${STAMPY_CONTENT_URL}/${questionId}`, {
     method: "GET",
+    signal: controller.signal,
     headers: {
       "Content-Type": "application/json",
       Accept: "application/json",
     },
-  });
+  }).catch(ignoreAbort);
 
-  if (!res.ok) {
+  if (!res) {
+    return { result: { role: "error", content: "No response from server" } };
+  } else if (!res.ok) {
     return { result: { role: "error", content: "POST Error: " + res.status } };
   }

@@ -198,7 +216,8 @@ export const runSearch = async (
   settings: LLMSettings,
   entries: Entry[],
   setCurrent: (c: CurrentSearch) => void,
-  sessionId: string
+  sessionId: string,
+  controller: AbortController
 ): Promise<SearchResult> => {
   if (query_source === "search") {
     const history = entries
@@ -208,12 +227,19 @@
       content: entry.content.trim(),
     }));
 
-    return await queryLLM(query, settings, history, setCurrent, sessionId);
+    return await queryLLM(
+      query,
+      settings,
+      history,
+      setCurrent,
+      sessionId,
+      controller
+    );
   } else {
     // ----------------- HUMAN AUTHORED CONTENT RETRIEVAL ------------------
     const [questionId] = query.split("\n", 2);
     if (questionId) {
-      return await getStampyContent(questionId);
+      return await getStampyContent(questionId, controller);
     }
     const result = {
       role: "error",
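
ignoreAbort turns an AbortError rejection into a resolved undefined, so each call site distinguishes three outcomes: cancelled (no Response at all), HTTP error (!res.ok), and success. A self-contained sketch of the same pattern (the URL and helper name are illustrative):

```typescript
// Abort-tolerant fetch, as in the diff: cancellation resolves to void,
// real network failures still throw.
const ignoreAbort = (error: Error) => {
  if (error.name !== "AbortError") {
    throw error;
  }
};

const getJson = async (
  url: string, // any JSON endpoint; this example URL is hypothetical
  controller: AbortController
): Promise<unknown | void> => {
  const res = await fetch(url, { signal: controller.signal }).catch(ignoreAbort);
  if (!res) return; // aborted before a response arrived
  if (!res.ok) throw new Error(`HTTP ${res.status}`);
  return res.json();
};

// Usage: abort() makes getJson resolve to undefined instead of rejecting.
const controller = new AbortController();
getJson("https://example.com/api", controller).then((data) => {
  console.log(data ?? "request was cancelled");
});
controller.abort();
```
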