Skip to content

Commit

Permalink
feat: add from_variable_selector for stream chunk / message event (la…
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost authored Sep 10, 2024
1 parent fdbbdb7 commit cee0c51
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 6 deletions.
4 changes: 3 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ def _process_stream_response(
tts_publisher.publish(message=queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
Expand Down
11 changes: 8 additions & 3 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,9 @@ def _process_stream_response(
tts_publisher.publish(message=queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
else:
continue

Expand Down Expand Up @@ -412,14 +414,17 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
db.session.commit()
db.session.close()

def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse:
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)

return response
2 changes: 2 additions & 0 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None


class MessageAudioStreamResponse(StreamResponse):
Expand Down Expand Up @@ -479,6 +480,7 @@ class Data(BaseModel):
"""

text: str
from_variable_selector: Optional[list[str]] = None

event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
Expand Down
11 changes: 9 additions & 2 deletions api/core/app/task_pipeline/message_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,21 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti

return None

def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse:
def _message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:return:
"""
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
)

def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
"""
Expand Down
1 change: 1 addition & 0 deletions api/core/workflow/nodes/answer/answer_stream_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _generate_stream_outputs_when_node_finished(
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
Expand Down

0 comments on commit cee0c51

Please sign in to comment.