Skip to content

Commit

Permalink
新增 ChatMessageAccumulatorWrapper,以便于能够在使用chatCompletion时使用原始 chunk 数据…
Browse files Browse the repository at this point in the history
…。这一需求主要适用于chat API的转发场景。
  • Loading branch information
big-mouth-cn committed Oct 18, 2024
1 parent 05f4673 commit 323d430
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 13 deletions.
6 changes: 3 additions & 3 deletions README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ OpenAi4J是一个非官方的Java库,旨在帮助java开发者与OpenAI的GPT
## 导入依赖
### Gradle

`implementation 'io.github.lambdua:<api|client|service>:0.22.3'`
`implementation 'io.github.lambdua:<api|client|service>:0.22.4'`
### Maven
```xml

<dependency>
<groupId>io.github.lambdua</groupId>
<artifactId>service</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</dependency>
```

Expand Down Expand Up @@ -61,7 +61,7 @@ static void simpleChat() {
<dependency>
<groupId>io.github.lambdua</groupId>
<artifactId>api</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</dependency>
```

Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ applications effortlessly.
## Import
### Gradle

`implementation 'io.github.lambdua:<api|client|service>:0.22.3'`
`implementation 'io.github.lambdua:<api|client|service>:0.22.4'`
### Maven
```xml

<dependency>
<groupId>io.github.lambdua</groupId>
<artifactId>service</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</dependency>
```

Expand Down Expand Up @@ -67,7 +67,7 @@ To utilize pojos, import the api module:
<dependency>
<groupId>io.github.lambdua</groupId>
<artifactId>api</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</dependency>
```

Expand Down
2 changes: 1 addition & 1 deletion api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.github.lambdua</groupId>
<artifactId>openai-java</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</parent>
<packaging>jar</packaging>
<artifactId>api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.theokanning.openai.Usage;
import lombok.Data;
Expand Down Expand Up @@ -51,4 +52,13 @@ public class ChatCompletionChunk {
*/
Usage usage;

/**
* The original data packet returned by chat completion.
* the value like this:
* <pre>
* data:{"id":"chatcmpl-A0QiHfuacgBSbvd8Ld1Por1HojY31","object":"chat.completion.chunk","created":1724666049,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
* </pre>
*/
@JsonIgnore
String source;
}
2 changes: 1 addition & 1 deletion client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.github.lambdua</groupId>
<artifactId>openai-java</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</parent>
<packaging>jar</packaging>

Expand Down
4 changes: 2 additions & 2 deletions example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>io.github.lambdua</groupId>
<artifactId>example</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
<name>example</name>

<properties>
Expand All @@ -17,7 +17,7 @@
<dependency>
<groupId>io.github.lambdua</groupId>
<artifactId>service</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</dependency>

</dependencies>
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>io.github.lambdua</groupId>
<artifactId>openai-java</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
<packaging>pom</packaging>
<description>openai java 版本</description>
<url>https://github.com/Lambdua/openai-java</url>
Expand Down
2 changes: 1 addition & 1 deletion service/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.github.lambdua</groupId>
<artifactId>openai-java</artifactId>
<version>0.22.3</version>
<version>0.22.4</version>
</parent>
<packaging>jar</packaging>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.theokanning.openai.service;

import com.theokanning.openai.completion.chat.ChatCompletionChunk;

/**
 * Pairs a {@link ChatMessageAccumulator} with the raw {@link ChatCompletionChunk}
 * it was derived from, so that consumers of a streamed chat completion can read
 * the original upstream chunk (e.g. to forward it verbatim in a proxy scenario)
 * alongside the accumulated message state.
 *
 * @author Allen Hu
 * @date 2024/10/18
 */
public class ChatMessageAccumulatorWrapper {

    private final ChatMessageAccumulator chatMessageAccumulator;
    private final ChatCompletionChunk chatCompletionChunk;

    /**
     * Creates a wrapper binding an accumulator snapshot to its source chunk.
     *
     * @param chatMessageAccumulator the accumulated message state for the stream so far
     * @param chatCompletionChunk    the raw chunk from which that state was built
     */
    public ChatMessageAccumulatorWrapper(ChatMessageAccumulator chatMessageAccumulator, ChatCompletionChunk chatCompletionChunk) {
        this.chatMessageAccumulator = chatMessageAccumulator;
        this.chatCompletionChunk = chatCompletionChunk;
    }

    /**
     * @return the accumulated message state
     */
    public ChatMessageAccumulator getChatMessageAccumulator() {
        return chatMessageAccumulator;
    }

    /**
     * @return the raw chat completion chunk this wrapper was built from
     */
    public ChatCompletionChunk getChatCompletionChunk() {
        return chatCompletionChunk;
    }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -72,6 +73,8 @@
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

public class OpenAiService {

Expand Down Expand Up @@ -190,7 +193,17 @@ public ChatCompletionResult createChatCompletion(ChatCompletionRequest request)

/**
 * Streams a chat completion, preserving the raw SSE payload of each event via
 * {@code ChatCompletionChunk.setSource(String)} so callers (e.g. proxies that
 * forward the upstream response verbatim) can access the original data packet.
 *
 * @param request the chat completion request; {@code stream} is forced to {@code true}
 * @return a Flowable emitting each {@link ChatCompletionChunk} as it arrives
 */
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
    request.setStream(true);
    // Lambda and constructor reference replace the original anonymous
    // BiConsumer/Supplier inner classes — the file already uses lambda style
    // for the other stream(...) overloads.
    return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class,
            (chunk, sse) -> chunk.setSource(sse.getData()),
            ChatCompletionChunk::new);
}


Expand Down Expand Up @@ -692,6 +705,31 @@ public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
}

/**
 * Calls the OpenAI api and returns a Flowable of type T for streaming.
 * Unlike {@link #stream(Call, Class)}, this overload also emits the terminal
 * event (it calls {@code stream(apiCall, true)}) and hands the raw SSE event to
 * {@code consumer} after each instance is created, so the original payload can
 * be captured (e.g. for request forwarding).
 *
 * @param apiCall     the api call
 * @param cl          class of type T to return
 * @param consumer    invoked with the deserialized (or fallback) instance and the raw SSE event; may be {@code null}
 * @param newInstance supplies a fallback instance when deserialization fails
 * @return a Flowable emitting one T per SSE event
 */
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl, BiConsumer<T, SSE> consumer,
                                     Supplier<T> newInstance) {
    return stream(apiCall, true).map(sse -> {
        T t;
        try {
            t = mapper.readValue(sse.getData(), cl);
        } catch (JsonProcessingException e) {
            // Deserialization failed (e.g. a non-JSON terminal event); fall back to
            // a fresh instance so the raw payload can still be delivered below.
            t = newInstance.get();
        }
        // Bug fix: the null-check now covers the fallback path too — previously the
        // catch branch called consumer.accept(...) unguarded, so a null consumer
        // caused an NPE whenever a chunk failed to parse.
        if (Objects.nonNull(consumer)) {
            consumer.accept(t, sse);
        }
        return t;
    });
}

/**
* Shuts down the OkHttp ExecutorService.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
Expand Down Expand Up @@ -758,6 +796,26 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp
});
}

/**
 * Maps a stream of raw chunks to {@link ChatMessageAccumulatorWrapper}s, each pairing
 * the per-stream accumulated message state with the raw chunk it came from.
 * Mirrors {@code mapStreamToAccumulator} but additionally retains the original chunk,
 * which is useful when the raw payload must be forwarded downstream.
 *
 * NOTE(review): {@code functionCall} and {@code accumulatedMessage} are shared mutable
 * state captured by the lambda and updated on every emission — the returned Flowable
 * is therefore single-subscription/stateful; do not resubscribe or share it.
 *
 * @param flowable the stream of chat completion chunks to accumulate
 * @return a Flowable emitting one wrapper per incoming chunk
 */
public Flowable<ChatMessageAccumulatorWrapper> mapStreamToAccumulatorWrapper(Flowable<ChatCompletionChunk> flowable) {
    ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
    AssistantMessage accumulatedMessage = new AssistantMessage();
    return flowable.map(chunk -> {
        List<ChatCompletionChoice> choices = chunk.getChoices();
        AssistantMessage messageChunk = null;
        // Chunks without choices (e.g. a usage-only final chunk) skip accumulation
        // but are still wrapped and emitted below.
        if (null != choices && !choices.isEmpty()) {
            ChatCompletionChoice firstChoice = choices.get(0);
            messageChunk = firstChoice.getMessage();
            // Order matters: content is appended before function-call and tool-call
            // fragments are folded into the accumulated message.
            appendContent(messageChunk, accumulatedMessage);
            processFunctionCall(messageChunk, functionCall, accumulatedMessage);
            processToolCalls(messageChunk, accumulatedMessage);
            if (firstChoice.getFinishReason() != null) {
                // Finalizes the accumulated function call once the stream reports completion.
                handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage);
            }
        }
        ChatMessageAccumulator chatMessageAccumulator = new ChatMessageAccumulator(messageChunk, accumulatedMessage, chunk.getUsage());
        return new ChatMessageAccumulatorWrapper(chatMessageAccumulator, chunk);
    });
}

/**
* 处理消息块中的函数调用。
Expand Down

0 comments on commit 323d430

Please sign in to comment.