diff --git a/pom.xml b/pom.xml
index a0fb7af..da331ea 100644
--- a/pom.xml
+++ b/pom.xml
@@ -49,9 +49,33 @@
1.2.7
5.8.18
4.10.0
+ 2.4.4
+ 5.2.1
0.4.0
+
+
+
+
+
+ org.apache.httpcomponents.client5
+ httpclient5
+ 5.2.1
+
+
+ org.apache.httpcomponents.core5
+ httpcore5
+ 5.2.1
+
+
+ org.apache.httpcomponents.core5
+ httpcore5-h2
+ 5.2.1
+
+
+
+
org.springframework.boot
@@ -110,6 +134,12 @@
${okhttp.version}
+
+ io.github.lunasaw
+ luna-common
+ ${luna-common.version}
+
+
com.knuddels
jtokkit
diff --git a/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java
index 99a9a2d..44b15e2 100644
--- a/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java
+++ b/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java
@@ -49,6 +49,7 @@
import com.lzhpo.chatgpt.entity.moderations.ModerationRequest;
import com.lzhpo.chatgpt.entity.moderations.ModerationResponse;
import com.lzhpo.chatgpt.entity.users.UserResponse;
+import com.lzhpo.chatgpt.sse.Listener;
import com.lzhpo.chatgpt.utils.JsonUtils;
import java.net.URI;
import java.util.Map;
@@ -92,10 +93,10 @@ public CompletionResponse completions(CompletionRequest request) {
}
@Override
- public void streamCompletions(CompletionRequest request, EventSourceListener listener) {
+ public void streamCompletions(CompletionRequest request, Listener listener) {
request.setStream(true);
Request clientRequest = createRequest(COMPLETIONS, createRequestBody(request));
- RealEventSource realEventSource = new RealEventSource(clientRequest, listener);
+ RealEventSource realEventSource = new RealEventSource(clientRequest, (EventSourceListener) listener);
realEventSource.connect(okHttpClient);
}
@@ -110,10 +111,10 @@ public ChatCompletionResponse chatCompletions(ChatCompletionRequest request) {
}
@Override
- public void streamChatCompletions(ChatCompletionRequest request, EventSourceListener listener) {
+ public void streamChatCompletions(ChatCompletionRequest request, Listener listener) {
request.setStream(true);
Request clientRequest = createRequest(CHAT_COMPLETIONS, createRequestBody(request));
- RealEventSource realEventSource = new RealEventSource(clientRequest, listener);
+ RealEventSource realEventSource = new RealEventSource(clientRequest, (EventSourceListener) listener);
realEventSource.connect(okHttpClient);
}
diff --git a/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java
new file mode 100644
index 0000000..c99b4e9
--- /dev/null
+++ b/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java
@@ -0,0 +1,376 @@
+package com.lzhpo.chatgpt;
+
+import cn.hutool.core.io.file.FileNameUtil;
+import cn.hutool.core.lang.WeightRandom;
+import com.google.common.collect.Maps;
+import com.luna.common.file.FileTools;
+import com.luna.common.net.HttpUtils;
+import com.luna.common.net.HttpUtilsConstant;
+import com.luna.common.net.async.CustomSseAsyncConsumer;
+import com.luna.common.net.hander.AbstactEventFutureCallback;
+import com.luna.common.net.high.AsyncHttpUtils;
+import com.luna.common.net.sse.Event;
+import com.luna.common.net.sse.SseResponse;
+import com.luna.common.text.CharsetUtil;
+import com.luna.common.thread.AsyncEngineUtils;
+import com.lzhpo.chatgpt.entity.audio.CreateAudioRequest;
+import com.lzhpo.chatgpt.entity.audio.CreateAudioResponse;
+import com.lzhpo.chatgpt.entity.billing.CreditGrantsResponse;
+import com.lzhpo.chatgpt.entity.billing.SubscriptionResponse;
+import com.lzhpo.chatgpt.entity.billing.UsageResponse;
+import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest;
+import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse;
+import com.lzhpo.chatgpt.entity.completions.CompletionRequest;
+import com.lzhpo.chatgpt.entity.completions.CompletionResponse;
+import com.lzhpo.chatgpt.entity.edit.EditRequest;
+import com.lzhpo.chatgpt.entity.edit.EditResponse;
+import com.lzhpo.chatgpt.entity.embeddings.EmbeddingRequest;
+import com.lzhpo.chatgpt.entity.embeddings.EmbeddingResponse;
+import com.lzhpo.chatgpt.entity.files.DeleteFileResponse;
+import com.lzhpo.chatgpt.entity.files.ListFileResponse;
+import com.lzhpo.chatgpt.entity.files.RetrieveFileResponse;
+import com.lzhpo.chatgpt.entity.files.UploadFileResponse;
+import com.lzhpo.chatgpt.entity.finetunes.*;
+import com.lzhpo.chatgpt.entity.image.CreateImageRequest;
+import com.lzhpo.chatgpt.entity.image.CreateImageResponse;
+import com.lzhpo.chatgpt.entity.image.CreateImageVariationRequest;
+import com.lzhpo.chatgpt.entity.model.ListModelsResponse;
+import com.lzhpo.chatgpt.entity.model.RetrieveModelResponse;
+import com.lzhpo.chatgpt.entity.moderations.ModerationRequest;
+import com.lzhpo.chatgpt.entity.moderations.ModerationResponse;
+import com.lzhpo.chatgpt.entity.users.UserResponse;
+import com.lzhpo.chatgpt.sse.Listener;
+import com.lzhpo.chatgpt.utils.JsonUtils;
+import lombok.RequiredArgsConstructor;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.io.IOUtils;
+import org.apache.hc.client5.http.entity.mime.HttpMultipartMode;
+import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder;
+import org.apache.hc.client5.http.impl.classic.BasicHttpClientResponseHandler;
+import org.apache.hc.core5.http.ContentType;
+import org.apache.hc.core5.http.HttpEntity;
+import org.apache.hc.core5.http.HttpHeaders;
+import org.apache.hc.core5.http.io.entity.StringEntity;
+import org.apache.hc.core5.http.nio.AsyncRequestProducer;
+import org.apache.hc.core5.http.nio.entity.StringAsyncEntityProducer;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.boot.context.properties.PropertyMapper;
+import org.springframework.core.io.Resource;
+import org.springframework.http.HttpMethod;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.validation.annotation.Validated;
+import org.springframework.web.util.UriTemplateHandler;
+
+import java.io.IOException;
+import java.net.URI;
+import java.nio.charset.Charset;
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.lzhpo.chatgpt.OpenAiConstant.*;
+import static com.lzhpo.chatgpt.OpenAiUrl.*;
+
+/**
+ * @author luna
+ * @description
+ * @date 2023/4/22
+ */
+@Slf4j
+@Validated
+@RequiredArgsConstructor
+public class HttpOpenAiClient implements OpenAiClient {
+
+ private final OpenAiProperties openAiProperties;
+ private final UriTemplateHandler uriTemplateHandler;
+ private final OpenAiKeyWrapper openAiKeyWrapper;
+
+
+ @Override
+ public ModerationResponse moderations(ModerationRequest request) {
+ return execute(MODERATIONS, createRequestBody(request), ModerationResponse.class);
+ }
+
+ @Override
+ public CompletionResponse completions(CompletionRequest request) {
+ return execute(COMPLETIONS, createRequestBody(request), CompletionResponse.class);
+ }
+
+ @Override
+ public void streamCompletions(CompletionRequest request, Listener eventListener) {
+ request.setStream(true);
+ doRequestAsync(CHAT_COMPLETIONS, JsonUtils.toJsonString(request), eventListener);
+ }
+
+ @Override
+ public EditResponse edits(EditRequest request) {
+ return execute(EDITS, createRequestBody(request), EditResponse.class);
+ }
+
+ @Override
+ public ChatCompletionResponse chatCompletions(ChatCompletionRequest request) {
+ return execute(CHAT_COMPLETIONS, createRequestBody(request), ChatCompletionResponse.class);
+ }
+
+ @Override
+ public void streamChatCompletions(ChatCompletionRequest request, Listener eventListener) {
+ request.setStream(true);
+ doRequestAsync(CHAT_COMPLETIONS, JsonUtils.toJsonString(request), eventListener);
+ }
+
+ @Override
+ public ListModelsResponse models() {
+ return execute(LIST_MODELS, null, ListModelsResponse.class);
+ }
+
+ @Override
+ public RetrieveModelResponse retrieveModel(String modelId) {
+ return execute(RETRIEVE_MODEL, null, RetrieveModelResponse.class, modelId);
+ }
+
+ @Override
+ public EmbeddingResponse embeddings(EmbeddingRequest request) {
+ return execute(EMBEDDINGS, createRequestBody(request), EmbeddingResponse.class);
+ }
+
+ @Override
+ public ListFileResponse listFiles() {
+ return execute(LIST_FILES, null, ListFileResponse.class);
+ }
+
+ @Override
+ @SneakyThrows
+ public UploadFileResponse uploadFile(Resource fileResource, String purpose) {
+ HashMap map = Maps.newHashMap();
+ map.put("purpose", purpose);
+ map.put("file", fileResource.getFile().getAbsolutePath());
+ HttpEntity requestBody = createRequestBody(map);
+ return execute(UPLOAD_FILE, requestBody, UploadFileResponse.class);
+ }
+
+ @Override
+ public DeleteFileResponse deleteFile(String fileId) {
+ return execute(DELETE_FILE, null, DeleteFileResponse.class, fileId);
+ }
+
+ @Override
+ public RetrieveFileResponse retrieveFile(String fileId) {
+ return execute(RETRIEVE_FILE, null, RetrieveFileResponse.class, fileId);
+ }
+
+ @Override
+ public CreateFineTuneResponse createFineTune(CreateFineTuneRequest request) {
+ return execute(CREATE_FINE_TUNE, createRequestBody(request), CreateFineTuneResponse.class);
+ }
+
+ @Override
+ public ListFineTuneResponse listFineTunes() {
+ return execute(LIST_FINE_TUNE, null, ListFineTuneResponse.class);
+ }
+
+ @Override
+ public RetrieveFineTuneResponse retrieveFineTunes(String fineTuneId) {
+ return execute(RETRIEVE_FINE_TUNE, null, RetrieveFineTuneResponse.class, fineTuneId);
+ }
+
+ @Override
+ public CancelFineTuneResponse cancelFineTune(String fineTuneId) {
+ return execute(CANCEL_FINE_TUNE, createRequestBody(null), CancelFineTuneResponse.class, fineTuneId);
+ }
+
+ @Override
+ public ListFineTuneEventResponse listFineTuneEvents(String fineTuneId) {
+ return execute(LIST_FINE_TUNE_EVENTS, null, ListFineTuneEventResponse.class, fineTuneId);
+ }
+
+ @Override
+ public DeleteFineTuneModelResponse deleteFineTuneModel(String model) {
+ return execute(DELETE_FINE_TUNE_EVENTS, null, DeleteFineTuneModelResponse.class, model);
+ }
+
+ @Override
+ public CreateAudioResponse createTranscription(Resource fileResource, CreateAudioRequest request) {
+ HttpEntity audioBody = createAudioBody(fileResource, request);
+ return execute(CREATE_TRANSCRIPTION, audioBody, CreateAudioResponse.class);
+ }
+
+ @Override
+ public CreateAudioResponse createTranslation(Resource fileResource, CreateAudioRequest request) {
+ HttpEntity multipartBody = createAudioBody(fileResource, request);
+ return execute(CREATE_TRANSLATION, multipartBody, CreateAudioResponse.class);
+ }
+
+ @Override
+ public CreateImageResponse createImage(CreateImageRequest request) {
+ return execute(CREATE_IMAGE, createRequestBody(request), CreateImageResponse.class);
+ }
+
+ @Override
+ @SneakyThrows
+ public CreateImageResponse createImageEdit(Resource image, Resource mask, CreateImageRequest request) {
+ HttpEntity imageBody = createImageBody(image, mask, request);
+ return execute(CREATE_TRANSCRIPTION, imageBody, CreateImageResponse.class);
+ }
+
+ @Override
+ @SneakyThrows
+ public CreateImageResponse createImageVariation(Resource image, CreateImageVariationRequest request) {
+
+ HttpEntity requestBody = buildImageFormBody(image, request);
+ return execute(CREATE_IMAGE_VARIATION, requestBody, CreateImageResponse.class);
+ }
+
+
+ @Override
+ public CreditGrantsResponse billingCreditGrants() {
+ return execute(BILLING_CREDIT_GRANTS, null, CreditGrantsResponse.class);
+ }
+
+ @Override
+ public UserResponse users(String organizationId) {
+ return execute(USERS, null, UserResponse.class, organizationId);
+ }
+
+ @Override
+ public SubscriptionResponse billingSubscription() {
+ return execute(BILLING_SUBSCRIPTION, null, SubscriptionResponse.class);
+ }
+
+ @Override
+ public UsageResponse billingUsage(String startDate, String endDate) {
+ return execute(BILLING_USAGE, null, UsageResponse.class, startDate, endDate);
+ }
+
+ @SneakyThrows
+ private S execute(OpenAiUrl openAiUrl, HttpEntity requestBody, Class responseType, Object... uriVariables) {
+ return execute(openAiUrl, new HashMap<>(), requestBody, responseType, uriVariables);
+ }
+
+ @SneakyThrows
+ private S execute(OpenAiUrl openAiUrl, Map body, HttpEntity requestBody, Class responseType, Object... uriVariables) {
+ String request = doRequest(openAiUrl, body, requestBody, uriVariables);
+ Assert.notNull(request, "Resolve response body failed.");
+ return JsonUtils.parse(request, responseType);
+ }
+
+ private String doRequest(OpenAiUrl openAiUrl, Map body, HttpEntity httpEntity, Object... uriVariables) {
+ WeightRandom weightRandom = openAiKeyWrapper.wrap();
+ Map configUrls = openAiProperties.getUrls();
+
+ URI requestUri = getUri(configUrls, openAiUrl, uriVariables);
+
+ Map header = Maps.newHashMap();
+ header.put(HttpHeaders.AUTHORIZATION, BEARER.concat(weightRandom.next()));
+
+ String result;
+ if (HttpMethod.POST.toString().equals(openAiUrl.getMethod())) {
+ header.put(HttpHeaders.CONTENT_TYPE, HttpUtilsConstant.JSON);
+ result = HttpUtils.doPost(openAiProperties.getDomain(), requestUri.getPath(), header, body, httpEntity, new BasicHttpClientResponseHandler());
+ } else if (HttpMethod.GET.toString().equals(openAiUrl.getMethod())) {
+ result = HttpUtils.doGet(openAiProperties.getDomain(), requestUri.getPath(), header, body, new BasicHttpClientResponseHandler());
+ } else {
+ throw new OpenAiException("不支持的请求方式");
+ }
+
+ return result;
+ }
+
+ public void doRequestAsync(OpenAiUrl openAiUrl, String body, Listener callback, Object... uriVariable) {
+ WeightRandom weightRandom = openAiKeyWrapper.wrap();
+ Map configUrls = openAiProperties.getUrls();
+ URI requestUri = getUri(configUrls, openAiUrl, uriVariable);
+ Map header = Maps.newHashMap();
+ header.put(HttpHeaders.AUTHORIZATION, BEARER.concat(weightRandom.next()));
+
+ StringAsyncEntityProducer bodyProducer = new StringAsyncEntityProducer(body);
+ AsyncRequestProducer producer = AsyncHttpUtils.getProducer(openAiProperties.getDomain(), requestUri.getPath(), header, new HashMap<>(), bodyProducer, HttpMethod.GET.toString());
+
+ // 事件处理器
+ CustomSseAsyncConsumer customSseAsyncConsumer = new CustomSseAsyncConsumer((AbstactEventFutureCallback) callback);
+ AsyncEngineUtils.execute(() -> AsyncHttpUtils.doAsyncRequest(producer, customSseAsyncConsumer, null));
+ }
+
+ @NotNull
+ private URI getUri(Map configUrls, OpenAiUrl openAiUrl, Object... uriVariables) {
+ String url = configUrls.get(openAiUrl);
+ if (!StringUtils.hasText(url)) {
+ url = openAiUrl.getSuffix();
+ }
+ URI requestUri = uriTemplateHandler.expand(url, uriVariables);
+ return requestUri;
+ }
+
+
+ private HttpEntity createRequestBody(Object request) {
+ String jsonString = JsonUtils.toJsonString(request);
+ return new StringEntity(jsonString, Charset.defaultCharset());
+ }
+
+ private HttpEntity createRequestBody(Map bodies) {
+ MultipartEntityBuilder builder = MultipartEntityBuilder.create();
+ builder.setMode(HttpMultipartMode.LEGACY);
+ builder.setCharset(CharsetUtil.defaultCharset());
+ builder.setContentType(ContentType.MULTIPART_FORM_DATA);
+ if (MapUtils.isNotEmpty(bodies)) {
+ bodies.forEach((k, v) -> {
+ if (FileTools.isExists(v)) {
+ builder.addBinaryBody(k, IOUtils.toInputStream(v, CharsetUtil.defaultCharset()));
+ } else {
+ builder.addTextBody(k, v);
+ }
+ });
+ }
+ return builder.build();
+ }
+
+ @SneakyThrows
+ private HttpEntity createImageBody(Resource image, Resource mask, CreateImageRequest request) {
+ boolean imageIsPng = FileNameUtil.isType(image.getFilename(), EXPECTED_IMAGE_TYPE);
+ boolean maskIsPng = FileNameUtil.isType(mask.getFilename(), EXPECTED_IMAGE_TYPE);
+ Assert.isTrue(imageIsPng, "The image must png type.");
+ Assert.isTrue(maskIsPng, "The mask must png type.");
+
+ Assert.isTrue(image.contentLength() < MAX_IMAGE_SIZE, "The image must less than 4MB.");
+ Assert.isTrue(mask.contentLength() < MAX_IMAGE_SIZE, "The mask must less than 4MB.");
+ PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
+ HashMap map = Maps.newHashMap();
+
+ map.put("image", image.getFile().getAbsolutePath());
+ map.put("mask", mask.getFile().getAbsolutePath());
+ mapper.from(request.getPrompt()).to(prompt -> map.put("prompt", prompt));
+ return createRequestBody(map);
+ }
+
+ @NotNull
+ private HttpEntity buildImageFormBody(Resource image, CreateImageVariationRequest request) throws IOException {
+ boolean imageIsPng = FileNameUtil.isType(image.getFilename(), EXPECTED_IMAGE_TYPE);
+ Assert.isTrue(imageIsPng, "The image must png type.");
+ Assert.isTrue(image.contentLength() < MAX_IMAGE_SIZE, "The image must less than 4MB.");
+
+ HashMap hashMap = Maps.newHashMap();
+ PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
+ mapper.from(request.getN()).to(n -> hashMap.put("n", n.toString()));
+ mapper.from(request.getSize()).to(size -> hashMap.put("size", size.getValue()));
+ mapper.from(request.getResponseFormat()).to(obj -> hashMap.put("response_format", obj.getValue()));
+ mapper.from(request.getUser()).to(user -> hashMap.put("user", user));
+ hashMap.put("images", image.getFile().getAbsolutePath());
+ return createRequestBody(hashMap);
+ }
+
+ @SneakyThrows
+ private HttpEntity createAudioBody(Resource fileResource, CreateAudioRequest request) {
+
+ HashMap map = Maps.newHashMap();
+ map.put("file", fileResource.getFile().getAbsolutePath());
+ PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
+ mapper.from(request.getModel()).to(model -> map.put("model", model));
+ mapper.from(request.getPrompt()).to(prompt -> map.put("prompt", prompt));
+ mapper.from(request.getResponseFormat()).to(format -> map.put("response_format", format));
+ mapper.from(request.getTemperature()).to(obj -> map.put("temperature", obj.toString()));
+ mapper.from(request.getLanguage()).to(language -> map.put("language", language));
+
+ return createRequestBody(map);
+ }
+}
diff --git a/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java b/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java
index 70fb6b9..13259ff 100644
--- a/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java
+++ b/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java
@@ -5,16 +5,23 @@
import java.net.Proxy;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.luna.common.net.HttpUtils;
+import com.luna.common.net.high.AsyncHttpUtils;
import lombok.RequiredArgsConstructor;
import okhttp3.Credentials;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Scope;
import org.springframework.web.util.DefaultUriBuilderFactory;
import org.springframework.web.util.UriTemplateHandler;
@@ -52,7 +59,7 @@ public OkHttpClient okHttpClient(List interceptors) {
@Bean
@ConditionalOnMissingBean
- public OpenAiClient openAiService(
+ public DefaultOpenAiClient openAiService(
OkHttpClient okHttpClient,
OpenAiKeyWrapper openAiKeyWrapper,
ObjectProvider uriTemplateHandlerObjectProvider) {
@@ -70,6 +77,24 @@ public OpenAiKeyWrapper openAiKeyWrapper(OpenAiKeyProvider openAiKeyProvider) {
return new OpenAiKeyWrapper(openAiKeyProvider);
}
+ @Bean(name = "httpOpenAiService")
+ @ConditionalOnMissingBean
+ public HttpOpenAiClient openAiService(
+ OpenAiKeyWrapper openAiKeyWrapper,
+ ObjectProvider uriTemplateHandlerObjectProvider) {
+ UriTemplateHandler uriTemplateHandler = uriTemplateHandlerObjectProvider.getIfAvailable(() -> {
+ DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory();
+ uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.URI_COMPONENT);
+ return uriBuilderFactory;
+ });
+ PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
+ mapper.from(openAiProperties::getReadTimeout).to(e->HttpUtils.setResponseTimeout((int) e.getSeconds()));
+ mapper.from(openAiProperties::getWriteTimeout).to(e->HttpUtils.setSocketTimeOut((int) e.getSeconds()));
+ mapper.from(openAiProperties::getConnectTimeout).to(e->HttpUtils.setConnectTimeout((int) e.getSeconds()));
+ Optional.ofNullable(openAiProperties.getProxy()).ifPresent(e-> AsyncHttpUtils.setProxy(e.getHost(), e.getPort()));
+ return new HttpOpenAiClient(openAiProperties, uriTemplateHandler, openAiKeyWrapper);
+ }
+
@Bean
@ConditionalOnMissingBean
public OpenAiKeyProvider openAiKeyProvider() {
diff --git a/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java
index 0fd7d18..656943d 100644
--- a/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java
+++ b/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java
@@ -45,6 +45,8 @@
import javax.validation.Valid;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
+
+import com.lzhpo.chatgpt.sse.Listener;
import okhttp3.sse.EventSourceListener;
import org.springframework.core.io.Resource;
@@ -75,7 +77,7 @@ public interface OpenAiClient {
* @param request {@link CompletionRequest}
* @param listener {@link EventSourceListener}
*/
- void streamCompletions(@Valid CompletionRequest request, @NotNull EventSourceListener listener);
+ void streamCompletions(@Valid CompletionRequest request, @NotNull Listener listener);
/**
* Create edit.
@@ -99,7 +101,7 @@ public interface OpenAiClient {
* @param request {@link ChatCompletionRequest}
* @param listener {@link EventSourceListener}
*/
- void streamChatCompletions(@Valid ChatCompletionRequest request, @NotNull EventSourceListener listener);
+ void streamChatCompletions(@Valid ChatCompletionRequest request, @NotNull Listener listener);
/**
* List models.
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java b/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java
new file mode 100644
index 0000000..95c4586
--- /dev/null
+++ b/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java
@@ -0,0 +1,11 @@
+package com.lzhpo.chatgpt.sse;
+
+import com.luna.common.net.hander.AbstactEventFutureCallback;
+
+/**
+ * @author luna
+ * @description
+ * @date 2023/4/28
+ */
+public class AbstractFutureCallback extends AbstactEventFutureCallback implements Listener {
+}
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java
index 44f18d2..9329355 100644
--- a/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java
+++ b/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java
@@ -30,7 +30,7 @@
* @author lzhpo
*/
@Slf4j
-public class CountDownLatchEventSourceListener extends AbstractEventSourceListener {
+public class CountDownLatchEventSourceListener extends AbstractEventSourceListener implements Listener{
private final CountDownLatch countDownLatch;
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java b/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java
new file mode 100644
index 0000000..558d347
--- /dev/null
+++ b/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java
@@ -0,0 +1,42 @@
+package com.lzhpo.chatgpt.sse;
+
+import com.alibaba.fastjson2.JSON;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.util.Assert;
+
+import java.util.concurrent.CountDownLatch;
+
+/**
+ * @author luna
+ * @description
+ * @date 2023/4/23
+ */
+@Slf4j
+public class CustomDownLatchEventFutureCallback extends AbstractFutureCallback implements Listener {
+
+ private final CountDownLatch countDownLatch;
+
+ public CustomDownLatchEventFutureCallback(CountDownLatch countDownLatch) {
+ Assert.notNull(countDownLatch, "countDownLatch cannot null.");
+ this.countDownLatch = countDownLatch;
+ }
+
+ @Override
+ public void completed(T result) {
+ log.info("completed::result = {}", JSON.toJSONString(result));
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void failed(Exception ex) {
+ super.failed(ex);
+ cancelled();
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void cancelled() {
+ super.cancelled();
+ countDownLatch.countDown();
+ }
+}
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java
new file mode 100644
index 0000000..6383223
--- /dev/null
+++ b/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java
@@ -0,0 +1,53 @@
+package com.lzhpo.chatgpt.sse;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.io.IOException;
+
+/**
+ * @author luna
+ * @description
+ * @date 2023/4/28
+ */
+@Slf4j
+public class HttpSseEventSourceListener extends AbstractFutureCallback {
+
+ private SseEmitter sseEmitter;
+
+ public HttpSseEventSourceListener(SseEmitter sseEmitter) {
+ this.sseEmitter = sseEmitter;
+ }
+
+ @Override
+ public void onEvent(R result) {
+ try {
+ sseEmitter.send(result);
+ } catch (IOException e) {
+ log.error("onEvent::result = {} ", result, e);
+ }
+ }
+
+ @Override
+ public void completed(T result) {
+ sseEmitter.complete();
+ }
+
+ @Override
+ public void failed(Exception ex) {
+ }
+
+ @Override
+ public void cancelled() {
+ super.cancelled();
+ }
+
+ public SseEmitter getSseEmitter() {
+ return sseEmitter;
+ }
+
+ public void setSseEmitter(SseEmitter sseEmitter) {
+ this.sseEmitter = sseEmitter;
+ }
+
+
+}
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/Listener.java b/src/main/java/com/lzhpo/chatgpt/sse/Listener.java
new file mode 100644
index 0000000..971cbd4
--- /dev/null
+++ b/src/main/java/com/lzhpo/chatgpt/sse/Listener.java
@@ -0,0 +1,9 @@
+package com.lzhpo.chatgpt.sse;
+
+/**
+ * @author luna
+ * @description
+ * @date 2023/4/23
+ */
+public interface Listener {
+}
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java
index 4f9d069..5e2d52d 100644
--- a/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java
+++ b/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java
@@ -30,7 +30,7 @@
* @author lzhpo
*/
@Slf4j
-public class SseEventSourceListener extends AbstractEventSourceListener {
+public class SseEventSourceListener extends AbstractEventSourceListener implements Listener{
private final SseEmitter sseEmitter;
diff --git a/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java
index eb97e05..46bfc56 100644
--- a/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java
+++ b/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java
@@ -30,7 +30,7 @@
* @author lzhpo
*/
@Slf4j
-public class WebSocketEventSourceListener extends AbstractEventSourceListener {
+public class WebSocketEventSourceListener extends AbstractEventSourceListener implements Listener{
private final Session session;
diff --git a/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java b/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java
new file mode 100644
index 0000000..401b66f
--- /dev/null
+++ b/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java
@@ -0,0 +1,420 @@
+/*
+ * Copyright 2023 luna
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.lzhpo.chatgpt;
+
+import cn.hutool.core.collection.ListUtil;
+import cn.hutool.core.date.DatePattern;
+import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.lang.Console;
+import com.luna.common.net.sse.Event;
+import com.luna.common.net.sse.SseResponse;
+import com.lzhpo.chatgpt.entity.audio.CreateAudioRequest;
+import com.lzhpo.chatgpt.entity.audio.CreateAudioResponse;
+import com.lzhpo.chatgpt.entity.billing.CreditGrantsResponse;
+import com.lzhpo.chatgpt.entity.billing.SubscriptionResponse;
+import com.lzhpo.chatgpt.entity.billing.UsageResponse;
+import com.lzhpo.chatgpt.entity.chat.ChatCompletionMessage;
+import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest;
+import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse;
+import com.lzhpo.chatgpt.entity.completions.CompletionRequest;
+import com.lzhpo.chatgpt.entity.completions.CompletionResponse;
+import com.lzhpo.chatgpt.entity.edit.EditRequest;
+import com.lzhpo.chatgpt.entity.edit.EditResponse;
+import com.lzhpo.chatgpt.entity.embeddings.EmbeddingRequest;
+import com.lzhpo.chatgpt.entity.embeddings.EmbeddingResponse;
+import com.lzhpo.chatgpt.entity.files.DeleteFileResponse;
+import com.lzhpo.chatgpt.entity.files.ListFileResponse;
+import com.lzhpo.chatgpt.entity.files.RetrieveFileResponse;
+import com.lzhpo.chatgpt.entity.files.UploadFileResponse;
+import com.lzhpo.chatgpt.entity.finetunes.*;
+import com.lzhpo.chatgpt.entity.image.*;
+import com.lzhpo.chatgpt.entity.model.ListModelsResponse;
+import com.lzhpo.chatgpt.entity.model.RetrieveModelResponse;
+import com.lzhpo.chatgpt.entity.moderations.ModerationRequest;
+import com.lzhpo.chatgpt.entity.moderations.ModerationResponse;
+import com.lzhpo.chatgpt.entity.users.UserResponse;
+import com.lzhpo.chatgpt.sse.CountDownLatchEventSourceListener;
+import com.lzhpo.chatgpt.sse.CustomDownLatchEventFutureCallback;
+import com.lzhpo.chatgpt.utils.JsonUtils;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.core.io.FileSystemResource;
+import org.springframework.web.socket.server.standard.ServerEndpointExporter;
+
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+/**
+ * @author lzhpo
+ */
+@SpringBootTest
+@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
+class HttpOpenAiClientTest {
+
+ @Autowired
+ private HttpOpenAiClient openAiService;
+
+ @MockBean
+ private ServerEndpointExporter serverEndpointExporter;
+
+ @Test
+ @Order(1)
+ void moderations() {
+ ModerationRequest request = new ModerationRequest();
+ request.setInput(ListUtil.of("I want to kill them."));
+
+ ModerationResponse response = openAiService.moderations(request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(2)
+ void completions() {
+ CompletionRequest request = new CompletionRequest();
+ request.setModel("text-davinci-003");
+ request.setPrompt("Say this is a test");
+ request.setMaxTokens(7);
+ request.setTemperature(0);
+
+ CompletionResponse response = openAiService.completions(request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(3)
+ void streamCompletions() throws InterruptedException {
+ CompletionRequest request = new CompletionRequest();
+ request.setStream(true);
+ request.setModel("text-davinci-003");
+ request.setPrompt("Say this is a test");
+ request.setMaxTokens(7);
+ request.setTemperature(0);
+
+ final CountDownLatch countDownLatch = new CountDownLatch(1);
+ CustomDownLatchEventFutureCallback futureCallback = new CustomDownLatchEventFutureCallback<>(countDownLatch);
+ assertDoesNotThrow(() -> openAiService.streamCompletions(request, futureCallback));
+ countDownLatch.await();
+ }
+
+ @Test
+ @Order(4)
+ void edits() {
+ EditRequest request = new EditRequest();
+ request.setModel("text-davinci-edit-001");
+ request.setInput("What day of the wek is it?");
+ request.setInstruction("Fix the spelling mistakes");
+
+ EditResponse response = openAiService.edits(request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(5)
+ void chatCompletions() {
+ List messages = new ArrayList<>();
+ ChatCompletionMessage message = new ChatCompletionMessage();
+ message.setRole("user");
+ message.setContent("Hello");
+ messages.add(message);
+
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setModel("gpt-3.5-turbo");
+ request.setMessages(messages);
+
+ ChatCompletionResponse response = openAiService.chatCompletions(request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(6)
+ void streamChatCompletions() throws InterruptedException {
+ List messages = new ArrayList<>();
+ ChatCompletionMessage message = new ChatCompletionMessage();
+ message.setRole("user");
+ message.setContent("Hello");
+ messages.add(message);
+
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setStream(true);
+ request.setModel("gpt-3.5-turbo");
+ request.setMessages(messages);
+
+ final CountDownLatch countDownLatch = new CountDownLatch(1);
+ CustomDownLatchEventFutureCallback eventSourceListener = new CustomDownLatchEventFutureCallback<>(countDownLatch);
+ assertDoesNotThrow(() -> openAiService.streamChatCompletions(request, eventSourceListener));
+ countDownLatch.await();
+ }
+
+ @Test
+ @Order(7)
+ void models() {
+ ListModelsResponse response = openAiService.models();
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(8)
+ void retrieveModel() {
+ RetrieveModelResponse response = openAiService.retrieveModel("babbage");
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(9)
+ void embeddings() {
+ EmbeddingRequest request = new EmbeddingRequest();
+ request.setModel("text-embedding-ada-002");
+ request.setInput(ListUtil.of("The food was delicious and the waiter..."));
+ EmbeddingResponse response = openAiService.embeddings(request);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(10)
+ void uploadFile() {
+ final String path = "C:\\Users\\lzhpo\\Desktop\\xxx.txt";
+ FileSystemResource fileResource = new FileSystemResource(path);
+
+ UploadFileResponse response = openAiService.uploadFile(fileResource, "fine-tune");
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(11)
+ void listFiles() {
+ ListFileResponse response = openAiService.listFiles();
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(12)
+ void retrieveFile() {
+ final String fileId = "file-xxx";
+ RetrieveFileResponse response = openAiService.retrieveFile(fileId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(13)
+ void deleteFile() {
+ final String fileId = "file-xxx";
+ DeleteFileResponse response = openAiService.deleteFile(fileId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(14)
+ void createFineTune() {
+ CreateFineTuneRequest request = new CreateFineTuneRequest();
+ request.setTrainingFile("file-xxx");
+
+ CreateFineTuneResponse response = openAiService.createFineTune(request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(15)
+ void listFineTunes() {
+ ListFineTuneResponse response = openAiService.listFineTunes();
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(16)
+ void retrieveFineTunes() {
+ final String fineTuneId = "ft-xxx";
+ RetrieveFineTuneResponse response = openAiService.retrieveFineTunes(fineTuneId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(17)
+ void cancelFineTune() {
+ final String fineTuneId = "ft-xxx";
+ CancelFineTuneResponse response = openAiService.cancelFineTune(fineTuneId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(18)
+ void listFineTuneEvents() {
+ final String fineTuneId = "ft-xxx";
+ ListFineTuneEventResponse response = openAiService.listFineTuneEvents(fineTuneId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(19)
+ void deleteFineTuneModel() {
+ final String modelId = "curie:ft-xxx";
+ DeleteFineTuneModelResponse response = openAiService.deleteFineTuneModel(modelId);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(20)
+ void createTranscription() {
+ final String path = "C:\\Users\\lzhpo\\Downloads\\xxx.mp3";
+ FileSystemResource fileResource = new FileSystemResource(path);
+ CreateAudioRequest request = new CreateAudioRequest();
+ request.setModel("whisper-1");
+
+ CreateAudioResponse response = openAiService.createTranscription(fileResource, request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(21)
+ void createTranslation() {
+ final String path = "C:\\Users\\lzhpo\\Downloads\\xxx.mp3";
+ FileSystemResource fileResource = new FileSystemResource(path);
+ CreateAudioRequest request = new CreateAudioRequest();
+ request.setModel("whisper-1");
+
+ CreateAudioResponse response = openAiService.createTranslation(fileResource, request);
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(22)
+ void createImage() {
+ CreateImageRequest request = new CreateImageRequest();
+ request.setPrompt("A cute baby sea otter.");
+ request.setN(2);
+ request.setSize(CreateImageSize.X_512_512);
+ request.setResponseFormat(CreateImageResponseFormat.URL);
+ CreateImageResponse response = openAiService.createImage(request);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(23)
+ void createImageEdit() {
+ final String imagePath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png";
+ FileSystemResource imageResource = new FileSystemResource(imagePath);
+
+ final String markPath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png";
+ FileSystemResource markResource = new FileSystemResource(markPath);
+
+ CreateImageRequest request = new CreateImageRequest();
+ request.setPrompt("A cute baby sea otter.");
+ request.setN(2);
+ request.setSize(CreateImageSize.X_512_512);
+ request.setResponseFormat(CreateImageResponseFormat.URL);
+ CreateImageResponse response = openAiService.createImageEdit(imageResource, markResource, request);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(24)
+ void createImageVariation() {
+ final String imagePath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png";
+ FileSystemResource imageResource = new FileSystemResource(imagePath);
+
+ CreateImageVariationRequest request = new CreateImageVariationRequest();
+ request.setN(2);
+ request.setSize(CreateImageSize.X_512_512);
+ request.setResponseFormat(CreateImageResponseFormat.URL);
+ CreateImageResponse response = openAiService.createImageVariation(imageResource, request);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(25)
+ void billingCreditGrants() {
+ CreditGrantsResponse response = openAiService.billingCreditGrants();
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(26)
+ void users() {
+ UserResponse response = openAiService.users("org-xxx");
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(27)
+ void billingSubscription() {
+ SubscriptionResponse response = openAiService.billingSubscription();
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+
+ @Test
+ @Order(28)
+ void billingUsage() {
+ Date nowDate = new Date();
+ String startDate = DateUtil.format(DateUtil.offsetDay(nowDate, -100), DatePattern.NORM_DATE_PATTERN);
+ String endDate = DateUtil.format(nowDate, DatePattern.NORM_DATE_PATTERN);
+ UsageResponse response = openAiService.billingUsage(startDate, endDate);
+
+ assertNotNull(response);
+ Console.log(JsonUtils.toJsonPrettyString(response));
+ }
+}
diff --git a/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java b/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java
index 8190637..1065b73 100644
--- a/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java
+++ b/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java
@@ -35,4 +35,5 @@ public static void main(String[] args) {
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
+
}
diff --git a/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java b/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java
index 460b3bd..4b103b4 100644
--- a/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java
+++ b/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java
@@ -16,8 +16,10 @@
package com.lzhpo.chatgpt;
+import com.luna.common.thread.AsyncEngineUtils;
import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest;
import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse;
+import com.lzhpo.chatgpt.sse.HttpSseEventSourceListener;
import com.lzhpo.chatgpt.sse.SseEventSourceListener;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@@ -26,6 +28,10 @@
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import java.time.LocalTime;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
/**
* @author lzhpo
*/
@@ -35,8 +41,9 @@
@RequiredArgsConstructor
public class OpenAiTestController {
- private final OpenAiClient openAiClient;
private final OpenAiKeyWrapper openAiKeyWrapper;
+ private final DefaultOpenAiClient openAiClient;
+ private final HttpOpenAiClient httpOpenAiClient;
@GetMapping("/page/chat")
public ModelAndView chatView() {
@@ -69,4 +76,35 @@ public SseEmitter sseStreamChat(@RequestParam String message) {
openAiClient.streamChatCompletions(request, new SseEventSourceListener(sseEmitter));
return sseEmitter;
}
+
+ @ResponseBody
+ @GetMapping("/chat/http5/sse")
+ public SseEmitter sseStreamHttpChat(@RequestParam String message) {
+ SseEmitter sseEmitter = new SseEmitter();
+ ChatCompletionRequest request = ChatCompletionRequest.create(message);
+ httpOpenAiClient.streamChatCompletions(request, new HttpSseEventSourceListener<>(sseEmitter));
+ return sseEmitter;
+ }
+
+ @GetMapping("/stream-sse-mvc")
+ public SseEmitter streamSseMvc() {
+ SseEmitter emitter = new SseEmitter();
+ ExecutorService sseMvcExecutor = Executors.newSingleThreadExecutor();
+ sseMvcExecutor.execute(() -> {
+ try {
+ for (int i = 0; i < 5; i++) {
+ SseEmitter.SseEventBuilder event = SseEmitter.event()
+ .data("SSE MVC - " + LocalTime.now().toString())
+ .id(String.valueOf(i))
+ .name("sse event - mvc");
+ emitter.send(event);
+ Thread.sleep(1000);
+ }
+ emitter.complete();
+ } catch (Exception ex) {
+ emitter.completeWithError(ex);
+ }
+ });
+ return emitter;
+ }
}