Skip to content

Commit

Permalink
Allow non streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Dec 19, 2023
1 parent c663939 commit 9c808ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 7 additions & 2 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def chat():
session_id = request.json.get('sessionId')
history = request.json.get('history', [])
settings = request.json.get('settings', {})
followups = request.json.get('followups', True)
as_stream = request.json.get('stream', True)

if query is None:
if query is None and history:
query = history[-1].get('content')
history = history[:-1]

Expand All @@ -65,7 +67,10 @@ def formatter(item):
return json.dumps(item)

def run(callback):
return run_query(session_id, query, history, Settings(**settings), callback)
return run_query(session_id, query, history, Settings(**settings), callback, followups)

if not as_stream:
return jsonify(run(None)['text'])

return Response(stream_with_context(stream(stream_callback(run, formatter))), mimetype='text/event-stream')

Expand Down
6 changes: 4 additions & 2 deletions api/src/stampy_chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def merge_history(history):
return messages


def run_query(session_id: str, query: str, history: List[Dict], settings: Settings, callback: Callable[[Any], None] = None) -> Dict[str, str]:
def run_query(session_id: str, query: str, history: List[Dict], settings: Settings, callback: Callable[[Any], None] = None, followups=True) -> Dict[str, str]:
"""Execute the query.
:param str query: the phrase that was input by the user
Expand All @@ -252,7 +252,9 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
verbose=False,
prompt=make_prompt(settings, chat_model, callbacks),
memory=make_memory(settings, history, callbacks)
) | StampyChain(callbacks=callbacks)
)
if followups:
chain = chain | StampyChain(callbacks=callbacks)
result = chain.invoke({"query": query, 'history': history}, {'callbacks': []})
if callback:
callback({'state': 'done'})
Expand Down

0 comments on commit 9c808ce

Please sign in to comment.