Skip to content

Commit

Permalink
Merge pull request #64 from Lambdua/bug_uage_mapStream_iof
Browse files Browse the repository at this point in the history
bug: StreamOption.includeUsage = true causes OpenAiService.mapStreamT…
  • Loading branch information
Lambdua authored Sep 19, 2024
2 parents 8c4affc + a8a6f45 commit a7db34f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.theokanning.openai.service;

import com.theokanning.openai.Usage;
import com.theokanning.openai.completion.chat.AssistantMessage;
import com.theokanning.openai.completion.chat.ChatFunctionCall;

Expand All @@ -15,15 +16,18 @@ public class ChatMessageAccumulator {
private final AssistantMessage messageChunk;
private final AssistantMessage accumulatedMessage;

private final Usage usage;

/**
* Constructor that initializes the message chunk and accumulated message.
*
* @param messageChunk The message chunk.
* @param accumulatedMessage The accumulated message.
*/
public ChatMessageAccumulator(AssistantMessage messageChunk, AssistantMessage accumulatedMessage) {
public ChatMessageAccumulator(AssistantMessage messageChunk, AssistantMessage accumulatedMessage,Usage usage) {
this.messageChunk = messageChunk;
this.accumulatedMessage = accumulatedMessage;
this.usage=usage;
}

/**
Expand Down Expand Up @@ -64,6 +68,14 @@ public AssistantMessage getAccumulatedMessage() {
return accumulatedMessage;
}


/**
* 只有{@link com.theokanning.openai.completion.chat.StreamOption#INCLUDE} 时,usage才不为null
*/
public Usage getUsage() {
return usage;
}

/**
* Retrieves the function call from the message chunk.
* This is equivalent to getMessageChunk().getFunctionCall().
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,15 +742,19 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp
ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
AssistantMessage accumulatedMessage = new AssistantMessage();
return flowable.map(chunk -> {
ChatCompletionChoice firstChoice = chunk.getChoices().get(0);
AssistantMessage messageChunk = firstChoice.getMessage();
appendContent(messageChunk, accumulatedMessage);
processFunctionCall(messageChunk, functionCall, accumulatedMessage);
processToolCalls(messageChunk, accumulatedMessage);
if (firstChoice.getFinishReason() != null) {
handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage);
List<ChatCompletionChoice> choices = chunk.getChoices();
AssistantMessage messageChunk=new AssistantMessage();
if (choices!=null && !choices.isEmpty()){
ChatCompletionChoice firstChoice = choices.get(0);
messageChunk = firstChoice.getMessage();
appendContent(messageChunk, accumulatedMessage);
processFunctionCall(messageChunk, functionCall, accumulatedMessage);
processToolCalls(messageChunk, accumulatedMessage);
if (firstChoice.getFinishReason() != null) {
handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage);
}
}
return new ChatMessageAccumulator(messageChunk, accumulatedMessage);
return new ChatMessageAccumulator(messageChunk, accumulatedMessage,chunk.getUsage());
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,15 @@ void zeroArgStreamToolTest() {
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.streamOptions(StreamOption.INCLUDE)
.build();
AssistantMessage accumulatedMessage = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
.blockingLast().getAccumulatedMessage();
ChatMessageAccumulator chatMessageAccumulator = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
.blockingLast();
AssistantMessage accumulatedMessage = chatMessageAccumulator.getAccumulatedMessage();
List<ChatToolCall> toolCalls = accumulatedMessage.getToolCalls();
assertNotNull(toolCalls);
assertEquals(1, toolCalls.size());
assertNotNull(chatMessageAccumulator.getUsage());
ChatToolCall chatToolCall = toolCalls.get(0);
ChatFunctionCall functionCall = chatToolCall.getFunction();
assertEquals("get_today", functionCall.getName());
Expand Down Expand Up @@ -954,5 +957,4 @@ void toolCallingStrictTest(){
assertEquals("asc",arguments.get("order_by").asText());
}


}

0 comments on commit a7db34f

Please sign in to comment.