diff --git a/api/main.py b/api/main.py index e35ac43..a94fb62 100644 --- a/api/main.py +++ b/api/main.py @@ -54,8 +54,10 @@ def chat(): session_id = request.json.get('sessionId') history = request.json.get('history', []) settings = request.json.get('settings', {}) + followups = request.json.get('followups', True) + as_stream = request.json.get('stream', True) - if query is None: + if query is None and history: query = history[-1].get('content') history = history[:-1] @@ -65,7 +67,10 @@ def formatter(item): return json.dumps(item) def run(callback): - return run_query(session_id, query, history, Settings(**settings), callback) + return run_query(session_id, query, history, Settings(**settings), callback, followups) + + if not as_stream: + return jsonify(run(None)['text']) return Response(stream_with_context(stream(stream_callback(run, formatter))), mimetype='text/event-stream') diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py index c14bd33..bc4537b 100644 --- a/api/src/stampy_chat/chat.py +++ b/api/src/stampy_chat/chat.py @@ -226,7 +226,7 @@ def merge_history(history): return messages -def run_query(session_id: str, query: str, history: List[Dict], settings: Settings, callback: Callable[[Any], None] = None) -> Dict[str, str]: +def run_query(session_id: str, query: str, history: List[Dict], settings: Settings, callback: Callable[[Any], None] = None, followups=True) -> Dict[str, str]: """Execute the query. :param str query: the phrase that was input by the user @@ -252,7 +252,9 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin verbose=False, prompt=make_prompt(settings, chat_model, callbacks), memory=make_memory(settings, history, callbacks) - ) | StampyChain(callbacks=callbacks) + ) + if followups: + chain = chain | StampyChain(callbacks=callbacks) result = chain.invoke({"query": query, 'history': history}, {'callbacks': []}) if callback: callback({'state': 'done'})