diff --git a/pom.xml b/pom.xml index a0fb7af..da331ea 100644 --- a/pom.xml +++ b/pom.xml @@ -49,9 +49,33 @@ 1.2.7 5.8.18 4.10.0 + 2.4.4 + 5.2.1 0.4.0 + + + + + + org.apache.httpcomponents.client5 + httpclient5 + 5.2.1 + + + org.apache.httpcomponents.core5 + httpcore5 + 5.2.1 + + + org.apache.httpcomponents.core5 + httpcore5-h2 + 5.2.1 + + + + org.springframework.boot @@ -110,6 +134,12 @@ ${okhttp.version} + + io.github.lunasaw + luna-common + ${luna-common.version} + + com.knuddels jtokkit diff --git a/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java index 99a9a2d..44b15e2 100644 --- a/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java +++ b/src/main/java/com/lzhpo/chatgpt/DefaultOpenAiClient.java @@ -49,6 +49,7 @@ import com.lzhpo.chatgpt.entity.moderations.ModerationRequest; import com.lzhpo.chatgpt.entity.moderations.ModerationResponse; import com.lzhpo.chatgpt.entity.users.UserResponse; +import com.lzhpo.chatgpt.sse.Listener; import com.lzhpo.chatgpt.utils.JsonUtils; import java.net.URI; import java.util.Map; @@ -92,10 +93,10 @@ public CompletionResponse completions(CompletionRequest request) { } @Override - public void streamCompletions(CompletionRequest request, EventSourceListener listener) { + public void streamCompletions(CompletionRequest request, Listener listener) { request.setStream(true); Request clientRequest = createRequest(COMPLETIONS, createRequestBody(request)); - RealEventSource realEventSource = new RealEventSource(clientRequest, listener); + RealEventSource realEventSource = new RealEventSource(clientRequest, (EventSourceListener) listener); realEventSource.connect(okHttpClient); } @@ -110,10 +111,10 @@ public ChatCompletionResponse chatCompletions(ChatCompletionRequest request) { } @Override - public void streamChatCompletions(ChatCompletionRequest request, EventSourceListener listener) { + public void streamChatCompletions(ChatCompletionRequest request, Listener listener) { request.setStream(true); Request clientRequest = createRequest(CHAT_COMPLETIONS, createRequestBody(request)); - RealEventSource realEventSource = new RealEventSource(clientRequest, listener); + RealEventSource realEventSource = new RealEventSource(clientRequest, (EventSourceListener) listener); realEventSource.connect(okHttpClient); } diff --git a/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java new file mode 100644 index 0000000..c99b4e9 --- /dev/null +++ b/src/main/java/com/lzhpo/chatgpt/HttpOpenAiClient.java @@ -0,0 +1,376 @@ +package com.lzhpo.chatgpt; + +import cn.hutool.core.io.file.FileNameUtil; +import cn.hutool.core.lang.WeightRandom; +import com.google.common.collect.Maps; +import com.luna.common.file.FileTools; +import com.luna.common.net.HttpUtils; +import com.luna.common.net.HttpUtilsConstant; +import com.luna.common.net.async.CustomSseAsyncConsumer; +import com.luna.common.net.hander.AbstactEventFutureCallback; +import com.luna.common.net.high.AsyncHttpUtils; +import com.luna.common.net.sse.Event; +import com.luna.common.net.sse.SseResponse; +import com.luna.common.text.CharsetUtil; +import com.luna.common.thread.AsyncEngineUtils; +import com.lzhpo.chatgpt.entity.audio.CreateAudioRequest; +import com.lzhpo.chatgpt.entity.audio.CreateAudioResponse; +import com.lzhpo.chatgpt.entity.billing.CreditGrantsResponse; +import com.lzhpo.chatgpt.entity.billing.SubscriptionResponse; +import com.lzhpo.chatgpt.entity.billing.UsageResponse; +import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest; +import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse; +import com.lzhpo.chatgpt.entity.completions.CompletionRequest; +import com.lzhpo.chatgpt.entity.completions.CompletionResponse; +import com.lzhpo.chatgpt.entity.edit.EditRequest; +import com.lzhpo.chatgpt.entity.edit.EditResponse; +import com.lzhpo.chatgpt.entity.embeddings.EmbeddingRequest; +import com.lzhpo.chatgpt.entity.embeddings.EmbeddingResponse; +import com.lzhpo.chatgpt.entity.files.DeleteFileResponse; +import com.lzhpo.chatgpt.entity.files.ListFileResponse; +import com.lzhpo.chatgpt.entity.files.RetrieveFileResponse; +import com.lzhpo.chatgpt.entity.files.UploadFileResponse; +import com.lzhpo.chatgpt.entity.finetunes.*; +import com.lzhpo.chatgpt.entity.image.CreateImageRequest; +import com.lzhpo.chatgpt.entity.image.CreateImageResponse; +import com.lzhpo.chatgpt.entity.image.CreateImageVariationRequest; +import com.lzhpo.chatgpt.entity.model.ListModelsResponse; +import com.lzhpo.chatgpt.entity.model.RetrieveModelResponse; +import com.lzhpo.chatgpt.entity.moderations.ModerationRequest; +import com.lzhpo.chatgpt.entity.moderations.ModerationResponse; +import com.lzhpo.chatgpt.entity.users.UserResponse; +import com.lzhpo.chatgpt.sse.Listener; +import com.lzhpo.chatgpt.utils.JsonUtils; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; +import org.apache.commons.io.IOUtils; +import org.apache.hc.client5.http.entity.mime.HttpMultipartMode; +import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder; +import org.apache.hc.client5.http.impl.classic.BasicHttpClientResponseHandler; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.nio.AsyncRequestProducer; +import org.apache.hc.core5.http.nio.entity.StringAsyncEntityProducer; +import org.jetbrains.annotations.NotNull; +import org.springframework.boot.context.properties.PropertyMapper; +import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.util.UriTemplateHandler; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; + +import static com.lzhpo.chatgpt.OpenAiConstant.*; +import static com.lzhpo.chatgpt.OpenAiUrl.*; + +/** + * @author luna + * @description + * @date 2023/4/22 + */ +@Slf4j +@Validated +@RequiredArgsConstructor +public class HttpOpenAiClient implements OpenAiClient { + + private final OpenAiProperties openAiProperties; + private final UriTemplateHandler uriTemplateHandler; + private final OpenAiKeyWrapper openAiKeyWrapper; + + + @Override + public ModerationResponse moderations(ModerationRequest request) { + return execute(MODERATIONS, createRequestBody(request), ModerationResponse.class); + } + + @Override + public CompletionResponse completions(CompletionRequest request) { + return execute(COMPLETIONS, createRequestBody(request), CompletionResponse.class); + } + + @Override + public void streamCompletions(CompletionRequest request, Listener eventListener) { + request.setStream(true); + doRequestAsync(CHAT_COMPLETIONS, JsonUtils.toJsonString(request), eventListener); + } + + @Override + public EditResponse edits(EditRequest request) { + return execute(EDITS, createRequestBody(request), EditResponse.class); + } + + @Override + public ChatCompletionResponse chatCompletions(ChatCompletionRequest request) { + return execute(CHAT_COMPLETIONS, createRequestBody(request), ChatCompletionResponse.class); + } + + @Override + public void streamChatCompletions(ChatCompletionRequest request, Listener eventListener) { + request.setStream(true); + doRequestAsync(CHAT_COMPLETIONS, JsonUtils.toJsonString(request), eventListener); + } + + @Override + public ListModelsResponse models() { + return execute(LIST_MODELS, null, ListModelsResponse.class); + } + + @Override + public RetrieveModelResponse retrieveModel(String modelId) { + return execute(RETRIEVE_MODEL, null, RetrieveModelResponse.class, modelId); + } + + @Override + public EmbeddingResponse embeddings(EmbeddingRequest request) { + return execute(EMBEDDINGS, createRequestBody(request), EmbeddingResponse.class); + } + + @Override + public ListFileResponse listFiles() { + return execute(LIST_FILES, null, ListFileResponse.class); + } + + @Override + @SneakyThrows + public UploadFileResponse uploadFile(Resource fileResource, String purpose) { + HashMap map = Maps.newHashMap(); + map.put("purpose", purpose); + map.put("file", fileResource.getFile().getAbsolutePath()); + HttpEntity requestBody = createRequestBody(map); + return execute(UPLOAD_FILE, requestBody, UploadFileResponse.class); + } + + @Override + public DeleteFileResponse deleteFile(String fileId) { + return execute(DELETE_FILE, null, DeleteFileResponse.class, fileId); + } + + @Override + public RetrieveFileResponse retrieveFile(String fileId) { + return execute(RETRIEVE_FILE, null, RetrieveFileResponse.class, fileId); + } + + @Override + public CreateFineTuneResponse createFineTune(CreateFineTuneRequest request) { + return execute(CREATE_FINE_TUNE, createRequestBody(request), CreateFineTuneResponse.class); + } + + @Override + public ListFineTuneResponse listFineTunes() { + return execute(LIST_FINE_TUNE, null, ListFineTuneResponse.class); + } + + @Override + public RetrieveFineTuneResponse retrieveFineTunes(String fineTuneId) { + return execute(RETRIEVE_FINE_TUNE, null, RetrieveFineTuneResponse.class, fineTuneId); + } + + @Override + public CancelFineTuneResponse cancelFineTune(String fineTuneId) { + return execute(CANCEL_FINE_TUNE, createRequestBody(null), CancelFineTuneResponse.class, fineTuneId); + } + + @Override + public ListFineTuneEventResponse listFineTuneEvents(String fineTuneId) { + return execute(LIST_FINE_TUNE_EVENTS, null, ListFineTuneEventResponse.class, fineTuneId); + } + + @Override + public DeleteFineTuneModelResponse deleteFineTuneModel(String model) { + return execute(DELETE_FINE_TUNE_EVENTS, null, DeleteFineTuneModelResponse.class, model); + } + + @Override + public CreateAudioResponse createTranscription(Resource fileResource, CreateAudioRequest request) { + HttpEntity audioBody = createAudioBody(fileResource, request); + return execute(CREATE_TRANSCRIPTION, audioBody, CreateAudioResponse.class); + } + + @Override + public CreateAudioResponse createTranslation(Resource fileResource, CreateAudioRequest request) { + HttpEntity multipartBody = createAudioBody(fileResource, request); + return execute(CREATE_TRANSLATION, multipartBody, CreateAudioResponse.class); + } + + @Override + public CreateImageResponse createImage(CreateImageRequest request) { + return execute(CREATE_IMAGE, createRequestBody(request), CreateImageResponse.class); + } + + @Override + @SneakyThrows + public CreateImageResponse createImageEdit(Resource image, Resource mask, CreateImageRequest request) { + HttpEntity imageBody = createImageBody(image, mask, request); + return execute(CREATE_TRANSCRIPTION, imageBody, CreateImageResponse.class); + } + + @Override + @SneakyThrows + public CreateImageResponse createImageVariation(Resource image, CreateImageVariationRequest request) { + + HttpEntity requestBody = buildImageFormBody(image, request); + return execute(CREATE_IMAGE_VARIATION, requestBody, CreateImageResponse.class); + } + + + @Override + public CreditGrantsResponse billingCreditGrants() { + return execute(BILLING_CREDIT_GRANTS, null, CreditGrantsResponse.class); + } + + @Override + public UserResponse users(String organizationId) { + return execute(USERS, null, UserResponse.class, organizationId); + } + + @Override + public SubscriptionResponse billingSubscription() { + return execute(BILLING_SUBSCRIPTION, null, SubscriptionResponse.class); + } + + @Override + public UsageResponse billingUsage(String startDate, String endDate) { + return execute(BILLING_USAGE, null, UsageResponse.class, startDate, endDate); + } + + @SneakyThrows + private S execute(OpenAiUrl openAiUrl, HttpEntity requestBody, Class responseType, Object... uriVariables) { + return execute(openAiUrl, new HashMap<>(), requestBody, responseType, uriVariables); + } + + @SneakyThrows + private S execute(OpenAiUrl openAiUrl, Map body, HttpEntity requestBody, Class responseType, Object... uriVariables) { + String request = doRequest(openAiUrl, body, requestBody, uriVariables); + Assert.notNull(request, "Resolve response body failed."); + return JsonUtils.parse(request, responseType); + } + + private String doRequest(OpenAiUrl openAiUrl, Map body, HttpEntity httpEntity, Object... uriVariables) { + WeightRandom weightRandom = openAiKeyWrapper.wrap(); + Map configUrls = openAiProperties.getUrls(); + + URI requestUri = getUri(configUrls, openAiUrl, uriVariables); + + Map header = Maps.newHashMap(); + header.put(HttpHeaders.AUTHORIZATION, BEARER.concat(weightRandom.next())); + + String result; + if (HttpMethod.POST.toString().equals(openAiUrl.getMethod())) { + header.put(HttpHeaders.CONTENT_TYPE, HttpUtilsConstant.JSON); + result = HttpUtils.doPost(openAiProperties.getDomain(), requestUri.getPath(), header, body, httpEntity, new BasicHttpClientResponseHandler()); + } else if (HttpMethod.GET.toString().equals(openAiUrl.getMethod())) { + result = HttpUtils.doGet(openAiProperties.getDomain(), requestUri.getPath(), header, body, new BasicHttpClientResponseHandler()); + } else { + throw new OpenAiException("不支持的请求方式"); + } + + return result; + } + + public void doRequestAsync(OpenAiUrl openAiUrl, String body, Listener callback, Object... uriVariable) { + WeightRandom weightRandom = openAiKeyWrapper.wrap(); + Map configUrls = openAiProperties.getUrls(); + URI requestUri = getUri(configUrls, openAiUrl, uriVariable); + Map header = Maps.newHashMap(); + header.put(HttpHeaders.AUTHORIZATION, BEARER.concat(weightRandom.next())); + + StringAsyncEntityProducer bodyProducer = new StringAsyncEntityProducer(body); + AsyncRequestProducer producer = AsyncHttpUtils.getProducer(openAiProperties.getDomain(), requestUri.getPath(), header, new HashMap<>(), bodyProducer, HttpMethod.GET.toString()); + + // 事件处理器 + CustomSseAsyncConsumer customSseAsyncConsumer = new CustomSseAsyncConsumer((AbstactEventFutureCallback) callback); + AsyncEngineUtils.execute(() -> AsyncHttpUtils.doAsyncRequest(producer, customSseAsyncConsumer, null)); + } + + @NotNull + private URI getUri(Map configUrls, OpenAiUrl openAiUrl, Object... uriVariables) { + String url = configUrls.get(openAiUrl); + if (!StringUtils.hasText(url)) { + url = openAiUrl.getSuffix(); + } + URI requestUri = uriTemplateHandler.expand(url, uriVariables); + return requestUri; + } + + + private HttpEntity createRequestBody(Object request) { + String jsonString = JsonUtils.toJsonString(request); + return new StringEntity(jsonString, Charset.defaultCharset()); + } + + private HttpEntity createRequestBody(Map bodies) { + MultipartEntityBuilder builder = MultipartEntityBuilder.create(); + builder.setMode(HttpMultipartMode.LEGACY); + builder.setCharset(CharsetUtil.defaultCharset()); + builder.setContentType(ContentType.MULTIPART_FORM_DATA); + if (MapUtils.isNotEmpty(bodies)) { + bodies.forEach((k, v) -> { + if (FileTools.isExists(v)) { + builder.addBinaryBody(k, IOUtils.toInputStream(v, CharsetUtil.defaultCharset())); + } else { + builder.addTextBody(k, v); + } + }); + } + return builder.build(); + } + + @SneakyThrows + private HttpEntity createImageBody(Resource image, Resource mask, CreateImageRequest request) { + boolean imageIsPng = FileNameUtil.isType(image.getFilename(), EXPECTED_IMAGE_TYPE); + boolean maskIsPng = FileNameUtil.isType(mask.getFilename(), EXPECTED_IMAGE_TYPE); + Assert.isTrue(imageIsPng, "The image must png type."); + Assert.isTrue(maskIsPng, "The mask must png type."); + + Assert.isTrue(image.contentLength() < MAX_IMAGE_SIZE, "The image must less than 4MB."); + Assert.isTrue(mask.contentLength() < MAX_IMAGE_SIZE, "The mask must less than 4MB."); + PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull(); + HashMap map = Maps.newHashMap(); + + map.put("image", image.getFile().getAbsolutePath()); + map.put("mask", mask.getFile().getAbsolutePath()); + mapper.from(request.getPrompt()).to(prompt -> map.put("prompt", prompt)); + return createRequestBody(map); + } + + @NotNull + private HttpEntity buildImageFormBody(Resource image, CreateImageVariationRequest request) throws IOException { + boolean imageIsPng = FileNameUtil.isType(image.getFilename(), EXPECTED_IMAGE_TYPE); + Assert.isTrue(imageIsPng, "The image must png type."); + Assert.isTrue(image.contentLength() < MAX_IMAGE_SIZE, "The image must less than 4MB."); + + HashMap hashMap = Maps.newHashMap(); + PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull(); + mapper.from(request.getN()).to(n -> hashMap.put("n", n.toString())); + mapper.from(request.getSize()).to(size -> hashMap.put("size", size.getValue())); + mapper.from(request.getResponseFormat()).to(obj -> hashMap.put("response_format", obj.getValue())); + mapper.from(request.getUser()).to(user -> hashMap.put("user", user)); + hashMap.put("images", image.getFile().getAbsolutePath()); + return createRequestBody(hashMap); + } + + @SneakyThrows + private HttpEntity createAudioBody(Resource fileResource, CreateAudioRequest request) { + + HashMap map = Maps.newHashMap(); + map.put("file", fileResource.getFile().getAbsolutePath()); + PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull(); + mapper.from(request.getModel()).to(model -> map.put("model", model)); + mapper.from(request.getPrompt()).to(prompt -> map.put("prompt", prompt)); + mapper.from(request.getResponseFormat()).to(format -> map.put("response_format", format)); + mapper.from(request.getTemperature()).to(obj -> map.put("temperature", obj.toString())); + mapper.from(request.getLanguage()).to(language -> map.put("language", language)); + + return createRequestBody(map); + } +} diff --git a/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java b/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java index 70fb6b9..13259ff 100644 --- a/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java +++ b/src/main/java/com/lzhpo/chatgpt/OpenAiAutoConfiguration.java @@ -5,16 +5,23 @@ import java.net.Proxy; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import com.luna.common.net.HttpUtils; +import com.luna.common.net.high.AsyncHttpUtils; import lombok.RequiredArgsConstructor; import okhttp3.Credentials; import okhttp3.Interceptor; import okhttp3.OkHttpClient; import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Scope; import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriTemplateHandler; @@ -52,7 +59,7 @@ public OkHttpClient okHttpClient(List interceptors) { @Bean @ConditionalOnMissingBean - public OpenAiClient openAiService( + public DefaultOpenAiClient openAiService( OkHttpClient okHttpClient, OpenAiKeyWrapper openAiKeyWrapper, ObjectProvider uriTemplateHandlerObjectProvider) { @@ -70,6 +77,24 @@ public OpenAiKeyWrapper openAiKeyWrapper(OpenAiKeyProvider openAiKeyProvider) { return new OpenAiKeyWrapper(openAiKeyProvider); } + @Bean(name = "httpOpenAiService") + @ConditionalOnMissingBean + public HttpOpenAiClient openAiService( + OpenAiKeyWrapper openAiKeyWrapper, + ObjectProvider uriTemplateHandlerObjectProvider) { + UriTemplateHandler uriTemplateHandler = uriTemplateHandlerObjectProvider.getIfAvailable(() -> { + DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory(); + uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.URI_COMPONENT); + return uriBuilderFactory; + }); + PropertyMapper mapper = PropertyMapper.get().alwaysApplyingWhenNonNull(); + mapper.from(openAiProperties::getReadTimeout).to(e->HttpUtils.setResponseTimeout((int) e.getSeconds())); + mapper.from(openAiProperties::getWriteTimeout).to(e->HttpUtils.setSocketTimeOut((int) e.getSeconds())); + mapper.from(openAiProperties::getConnectTimeout).to(e->HttpUtils.setConnectTimeout((int) e.getSeconds())); + Optional.ofNullable(openAiProperties.getProxy()).ifPresent(e-> AsyncHttpUtils.setProxy(e.getHost(), e.getPort())); + return new HttpOpenAiClient(openAiProperties, uriTemplateHandler, openAiKeyWrapper); + } + @Bean @ConditionalOnMissingBean public OpenAiKeyProvider openAiKeyProvider() { diff --git a/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java b/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java index 0fd7d18..656943d 100644 --- a/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java +++ b/src/main/java/com/lzhpo/chatgpt/OpenAiClient.java @@ -45,6 +45,8 @@ import javax.validation.Valid; import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotNull; + +import com.lzhpo.chatgpt.sse.Listener; import okhttp3.sse.EventSourceListener; import org.springframework.core.io.Resource; @@ -75,7 +77,7 @@ public interface OpenAiClient { * @param request {@link CompletionRequest} * @param listener {@link EventSourceListener} */ - void streamCompletions(@Valid CompletionRequest request, @NotNull EventSourceListener listener); + void streamCompletions(@Valid CompletionRequest request, @NotNull Listener listener); /** * Create edit. @@ -99,7 +101,7 @@ public interface OpenAiClient { * @param request {@link ChatCompletionRequest} * @param listener {@link EventSourceListener} */ - void streamChatCompletions(@Valid ChatCompletionRequest request, @NotNull EventSourceListener listener); + void streamChatCompletions(@Valid ChatCompletionRequest request, @NotNull Listener listener); /** * List models. diff --git a/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java b/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java new file mode 100644 index 0000000..95c4586 --- /dev/null +++ b/src/main/java/com/lzhpo/chatgpt/sse/AbstractFutureCallback.java @@ -0,0 +1,11 @@ +package com.lzhpo.chatgpt.sse; + +import com.luna.common.net.hander.AbstactEventFutureCallback; + +/** + * @author luna + * @description + * @date 2023/4/28 + */ +public class AbstractFutureCallback extends AbstactEventFutureCallback implements Listener { +} diff --git a/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java index 44f18d2..9329355 100644 --- a/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java +++ b/src/main/java/com/lzhpo/chatgpt/sse/CountDownLatchEventSourceListener.java @@ -30,7 +30,7 @@ * @author lzhpo */ @Slf4j -public class CountDownLatchEventSourceListener extends AbstractEventSourceListener { +public class CountDownLatchEventSourceListener extends AbstractEventSourceListener implements Listener{ private final CountDownLatch countDownLatch; diff --git a/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java b/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java new file mode 100644 index 0000000..558d347 --- /dev/null +++ b/src/main/java/com/lzhpo/chatgpt/sse/CustomDownLatchEventFutureCallback.java @@ -0,0 +1,42 @@ +package com.lzhpo.chatgpt.sse; + +import com.alibaba.fastjson2.JSON; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.Assert; + +import java.util.concurrent.CountDownLatch; + +/** + * @author luna + * @description + * @date 2023/4/23 + */ +@Slf4j +public class CustomDownLatchEventFutureCallback extends AbstractFutureCallback implements Listener { + + private final CountDownLatch countDownLatch; + + public CustomDownLatchEventFutureCallback(CountDownLatch countDownLatch) { + Assert.notNull(countDownLatch, "countDownLatch cannot null."); + this.countDownLatch = countDownLatch; + } + + @Override + public void completed(T result) { + log.info("completed::result = {}", JSON.toJSONString(result)); + countDownLatch.countDown(); + } + + @Override + public void failed(Exception ex) { + super.failed(ex); + cancelled(); + countDownLatch.countDown(); + } + + @Override + public void cancelled() { + super.cancelled(); + countDownLatch.countDown(); + } +} diff --git a/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java new file mode 100644 index 0000000..6383223 --- /dev/null +++ b/src/main/java/com/lzhpo/chatgpt/sse/HttpSseEventSourceListener.java @@ -0,0 +1,53 @@ +package com.lzhpo.chatgpt.sse; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; + +/** + * @author luna + * @description + * @date 2023/4/28 + */ +@Slf4j +public class HttpSseEventSourceListener extends AbstractFutureCallback { + + private SseEmitter sseEmitter; + + public HttpSseEventSourceListener(SseEmitter sseEmitter) { + this.sseEmitter = sseEmitter; + } + + @Override + public void onEvent(R result) { + try { + sseEmitter.send(result); + } catch (IOException e) { + log.error("onEvent::result = {} ", result, e); + } + } + + @Override + public void completed(T result) { + sseEmitter.complete(); + } + + @Override + public void failed(Exception ex) { + } + + @Override + public void cancelled() { + super.cancelled(); + } + + public SseEmitter getSseEmitter() { + return sseEmitter; + } + + public void setSseEmitter(SseEmitter sseEmitter) { + this.sseEmitter = sseEmitter; + } + + +} diff --git a/src/main/java/com/lzhpo/chatgpt/sse/Listener.java b/src/main/java/com/lzhpo/chatgpt/sse/Listener.java new file mode 100644 index 0000000..971cbd4 --- /dev/null +++ b/src/main/java/com/lzhpo/chatgpt/sse/Listener.java @@ -0,0 +1,9 @@ +package com.lzhpo.chatgpt.sse; + +/** + * @author luna + * @description + * @date 2023/4/23 + */ +public interface Listener { +} diff --git a/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java index 4f9d069..5e2d52d 100644 --- a/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java +++ b/src/main/java/com/lzhpo/chatgpt/sse/SseEventSourceListener.java @@ -30,7 +30,7 @@ * @author lzhpo */ @Slf4j -public class SseEventSourceListener extends AbstractEventSourceListener { +public class SseEventSourceListener extends AbstractEventSourceListener implements Listener{ private final SseEmitter sseEmitter; diff --git a/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java b/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java index eb97e05..46bfc56 100644 --- a/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java +++ b/src/main/java/com/lzhpo/chatgpt/sse/WebSocketEventSourceListener.java @@ -30,7 +30,7 @@ * @author lzhpo */ @Slf4j -public class WebSocketEventSourceListener extends AbstractEventSourceListener { +public class WebSocketEventSourceListener extends AbstractEventSourceListener implements Listener{ private final Session session; diff --git a/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java b/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java new file mode 100644 index 0000000..401b66f --- /dev/null +++ b/src/test/java/com/lzhpo/chatgpt/HttpOpenAiClientTest.java @@ -0,0 +1,420 @@ +/* + * Copyright 2023 luna + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lzhpo.chatgpt; + +import cn.hutool.core.collection.ListUtil; +import cn.hutool.core.date.DatePattern; +import cn.hutool.core.date.DateUtil; +import cn.hutool.core.lang.Console; +import com.luna.common.net.sse.Event; +import com.luna.common.net.sse.SseResponse; +import com.lzhpo.chatgpt.entity.audio.CreateAudioRequest; +import com.lzhpo.chatgpt.entity.audio.CreateAudioResponse; +import com.lzhpo.chatgpt.entity.billing.CreditGrantsResponse; +import com.lzhpo.chatgpt.entity.billing.SubscriptionResponse; +import com.lzhpo.chatgpt.entity.billing.UsageResponse; +import com.lzhpo.chatgpt.entity.chat.ChatCompletionMessage; +import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest; +import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse; +import com.lzhpo.chatgpt.entity.completions.CompletionRequest; +import com.lzhpo.chatgpt.entity.completions.CompletionResponse; +import com.lzhpo.chatgpt.entity.edit.EditRequest; +import com.lzhpo.chatgpt.entity.edit.EditResponse; +import com.lzhpo.chatgpt.entity.embeddings.EmbeddingRequest; +import com.lzhpo.chatgpt.entity.embeddings.EmbeddingResponse; +import com.lzhpo.chatgpt.entity.files.DeleteFileResponse; +import com.lzhpo.chatgpt.entity.files.ListFileResponse; +import com.lzhpo.chatgpt.entity.files.RetrieveFileResponse; +import com.lzhpo.chatgpt.entity.files.UploadFileResponse; +import com.lzhpo.chatgpt.entity.finetunes.*; +import com.lzhpo.chatgpt.entity.image.*; +import com.lzhpo.chatgpt.entity.model.ListModelsResponse; +import com.lzhpo.chatgpt.entity.model.RetrieveModelResponse; +import com.lzhpo.chatgpt.entity.moderations.ModerationRequest; +import com.lzhpo.chatgpt.entity.moderations.ModerationResponse; +import com.lzhpo.chatgpt.entity.users.UserResponse; +import com.lzhpo.chatgpt.sse.CountDownLatchEventSourceListener; +import com.lzhpo.chatgpt.sse.CustomDownLatchEventFutureCallback; +import com.lzhpo.chatgpt.utils.JsonUtils; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.core.io.FileSystemResource; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * @author lzhpo + */ +@SpringBootTest +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class HttpOpenAiClientTest { + + @Autowired + private HttpOpenAiClient openAiService; + + @MockBean + private ServerEndpointExporter serverEndpointExporter; + + @Test + @Order(1) + void moderations() { + ModerationRequest request = new ModerationRequest(); + request.setInput(ListUtil.of("I want to kill them.")); + + ModerationResponse response = openAiService.moderations(request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(2) + void completions() { + CompletionRequest request = new CompletionRequest(); + request.setModel("text-davinci-003"); + request.setPrompt("Say this is a test"); + request.setMaxTokens(7); + request.setTemperature(0); + + CompletionResponse response = openAiService.completions(request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(3) + void streamCompletions() throws InterruptedException { + CompletionRequest request = new CompletionRequest(); + request.setStream(true); + request.setModel("text-davinci-003"); + request.setPrompt("Say this is a test"); + request.setMaxTokens(7); + request.setTemperature(0); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + CustomDownLatchEventFutureCallback futureCallback = new CustomDownLatchEventFutureCallback<>(countDownLatch); + assertDoesNotThrow(() -> openAiService.streamCompletions(request, futureCallback)); + countDownLatch.await(); + } + + @Test + @Order(4) + void edits() { + EditRequest request = new EditRequest(); + request.setModel("text-davinci-edit-001"); + request.setInput("What day of the wek is it?"); + request.setInstruction("Fix the spelling mistakes"); + + EditResponse response = openAiService.edits(request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(5) + void chatCompletions() { + List messages = new ArrayList<>(); + ChatCompletionMessage message = new ChatCompletionMessage(); + message.setRole("user"); + message.setContent("Hello"); + messages.add(message); + + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setModel("gpt-3.5-turbo"); + request.setMessages(messages); + + ChatCompletionResponse response = openAiService.chatCompletions(request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(6) + void streamChatCompletions() throws InterruptedException { + List messages = new ArrayList<>(); + ChatCompletionMessage message = new ChatCompletionMessage(); + message.setRole("user"); + message.setContent("Hello"); + messages.add(message); + + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setStream(true); + request.setModel("gpt-3.5-turbo"); + request.setMessages(messages); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + CustomDownLatchEventFutureCallback eventSourceListener = new CustomDownLatchEventFutureCallback<>(countDownLatch); + assertDoesNotThrow(() -> openAiService.streamChatCompletions(request, eventSourceListener)); + countDownLatch.await(); + } + + @Test + @Order(7) + void models() { + ListModelsResponse response = openAiService.models(); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(8) + void retrieveModel() { + RetrieveModelResponse response = openAiService.retrieveModel("babbage"); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(9) + void embeddings() { + EmbeddingRequest request = new EmbeddingRequest(); + request.setModel("text-embedding-ada-002"); + request.setInput(ListUtil.of("The food was delicious and the waiter...")); + EmbeddingResponse response = openAiService.embeddings(request); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(10) + void uploadFile() { + final String path = "C:\\Users\\lzhpo\\Desktop\\xxx.txt"; + FileSystemResource fileResource = new FileSystemResource(path); + + UploadFileResponse response = openAiService.uploadFile(fileResource, "fine-tune"); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(11) + void listFiles() { + ListFileResponse response = openAiService.listFiles(); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(12) + void retrieveFile() { + final String fileId = "file-xxx"; + RetrieveFileResponse response = openAiService.retrieveFile(fileId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(13) + void deleteFile() { + final String fileId = "file-xxx"; + DeleteFileResponse response = openAiService.deleteFile(fileId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(14) + void createFineTune() { + CreateFineTuneRequest request = new CreateFineTuneRequest(); + request.setTrainingFile("file-xxx"); + + CreateFineTuneResponse response = openAiService.createFineTune(request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(15) + void listFineTunes() { + ListFineTuneResponse response = openAiService.listFineTunes(); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(16) + void retrieveFineTunes() { + final String fineTuneId = "ft-xxx"; + RetrieveFineTuneResponse response = openAiService.retrieveFineTunes(fineTuneId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(17) + void cancelFineTune() { + final String fineTuneId = "ft-xxx"; + CancelFineTuneResponse response = openAiService.cancelFineTune(fineTuneId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(18) + void listFineTuneEvents() { + final String fineTuneId = "ft-xxx"; + ListFineTuneEventResponse response = openAiService.listFineTuneEvents(fineTuneId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(19) + void deleteFineTuneModel() { + final String modelId = "curie:ft-xxx"; + DeleteFineTuneModelResponse response = openAiService.deleteFineTuneModel(modelId); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(20) + void createTranscription() { + final String path = "C:\\Users\\lzhpo\\Downloads\\xxx.mp3"; + FileSystemResource fileResource = new FileSystemResource(path); + CreateAudioRequest request = new CreateAudioRequest(); + request.setModel("whisper-1"); + + CreateAudioResponse response = openAiService.createTranscription(fileResource, request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(21) + void createTranslation() { + final String path = "C:\\Users\\lzhpo\\Downloads\\xxx.mp3"; + FileSystemResource fileResource = new FileSystemResource(path); + CreateAudioRequest request = new CreateAudioRequest(); + request.setModel("whisper-1"); + + CreateAudioResponse response = openAiService.createTranslation(fileResource, request); + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(22) + void createImage() { + CreateImageRequest request = new CreateImageRequest(); + request.setPrompt("A cute baby sea otter."); + request.setN(2); + request.setSize(CreateImageSize.X_512_512); + request.setResponseFormat(CreateImageResponseFormat.URL); + CreateImageResponse response = openAiService.createImage(request); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(23) + void createImageEdit() { + final String imagePath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png"; + FileSystemResource imageResource = new FileSystemResource(imagePath); + + final String markPath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png"; + FileSystemResource markResource = new FileSystemResource(markPath); + + CreateImageRequest request = new CreateImageRequest(); + request.setPrompt("A cute baby sea otter."); + request.setN(2); + request.setSize(CreateImageSize.X_512_512); + request.setResponseFormat(CreateImageResponseFormat.URL); + CreateImageResponse response = openAiService.createImageEdit(imageResource, markResource, request); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(24) + void createImageVariation() { + final String imagePath = "C:\\Users\\lzhpo\\Downloads\\img-xxx.png"; + FileSystemResource imageResource = new FileSystemResource(imagePath); + + CreateImageVariationRequest request = new CreateImageVariationRequest(); + request.setN(2); + request.setSize(CreateImageSize.X_512_512); + request.setResponseFormat(CreateImageResponseFormat.URL); + CreateImageResponse response = openAiService.createImageVariation(imageResource, request); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(25) + void billingCreditGrants() { + CreditGrantsResponse response = openAiService.billingCreditGrants(); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(26) + void users() { + UserResponse response = openAiService.users("org-xxx"); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(27) + void billingSubscription() { + SubscriptionResponse response = openAiService.billingSubscription(); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } + + @Test + @Order(28) + void billingUsage() { + Date nowDate = new Date(); + String startDate = DateUtil.format(DateUtil.offsetDay(nowDate, -100), DatePattern.NORM_DATE_PATTERN); + String endDate = DateUtil.format(nowDate, DatePattern.NORM_DATE_PATTERN); + UsageResponse response = openAiService.billingUsage(startDate, endDate); + + assertNotNull(response); + Console.log(JsonUtils.toJsonPrettyString(response)); + } +} diff --git a/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java b/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java index 8190637..1065b73 100644 --- a/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java +++ b/src/test/java/com/lzhpo/chatgpt/OpenAiTestApplication.java @@ -35,4 +35,5 @@ public static void main(String[] args) { public ServerEndpointExporter serverEndpointExporter() { return new ServerEndpointExporter(); } + } diff --git a/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java b/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java index 460b3bd..4b103b4 100644 --- a/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java +++ b/src/test/java/com/lzhpo/chatgpt/OpenAiTestController.java @@ -16,8 +16,10 @@ package com.lzhpo.chatgpt; +import com.luna.common.thread.AsyncEngineUtils; import com.lzhpo.chatgpt.entity.chat.ChatCompletionRequest; import com.lzhpo.chatgpt.entity.chat.ChatCompletionResponse; +import com.lzhpo.chatgpt.sse.HttpSseEventSourceListener; import com.lzhpo.chatgpt.sse.SseEventSourceListener; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -26,6 +28,10 @@ import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import java.time.LocalTime; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + /** * @author lzhpo */ @@ -35,8 +41,9 @@ @RequiredArgsConstructor public class OpenAiTestController { - private final OpenAiClient openAiClient; private final OpenAiKeyWrapper openAiKeyWrapper; + private final DefaultOpenAiClient openAiClient; + private final HttpOpenAiClient httpOpenAiClient; @GetMapping("/page/chat") public ModelAndView chatView() { @@ -69,4 +76,35 @@ public SseEmitter sseStreamChat(@RequestParam String message) { openAiClient.streamChatCompletions(request, new SseEventSourceListener(sseEmitter)); return sseEmitter; } + + @ResponseBody + @GetMapping("/chat/http5/sse") + public SseEmitter sseStreamHttpChat(@RequestParam String message) { + SseEmitter sseEmitter = new SseEmitter(); + ChatCompletionRequest request = ChatCompletionRequest.create(message); + httpOpenAiClient.streamChatCompletions(request, new HttpSseEventSourceListener<>(sseEmitter)); + return sseEmitter; + } + + @GetMapping("/stream-sse-mvc") + public SseEmitter streamSseMvc() { + SseEmitter emitter = new SseEmitter(); + ExecutorService sseMvcExecutor = Executors.newSingleThreadExecutor(); + sseMvcExecutor.execute(() -> { + try { + for (int i = 0; i < 5; i++) { + SseEmitter.SseEventBuilder event = SseEmitter.event() + .data("SSE MVC - " + LocalTime.now().toString()) + .id(String.valueOf(i)) + .name("sse event - mvc"); + emitter.send(event); + Thread.sleep(1000); + } + emitter.complete(); + } catch (Exception ex) { + emitter.completeWithError(ex); + } + }); + return emitter; + } }