Skip to content

Commit

Permalink
Updated to DALL-E 3
Browse files Browse the repository at this point in the history
  • Loading branch information
cjmalloy committed Nov 13, 2023
1 parent ad79888 commit 427ba4d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 198 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
<varv.version>0.10.4</varv.version>
<jsoup.version>1.16.1</jsoup.version>
<scim2-client.version>2.3.8</scim2-client.version>
<openai-gpt.version>0.16.1</openai-gpt.version>
<openai-gpt.version>0.17.0</openai-gpt.version>
</properties>

<dependencies>
Expand Down
204 changes: 20 additions & 184 deletions src/main/java/jasper/component/OpenAi.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,89 +60,6 @@ public class OpenAi {
@Autowired
ObjectMapper objectMapper;

private AiConfig getConfig() {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
if (key.isEmpty()) {
throw new NotFoundException("requires openai api key");
}
var aiPlugin = pluginRepository.findByTagAndOrigin("+plugin/openai", props.getLocalOrigin())
.orElseThrow(() -> new NotFoundException("+plugin/openai"));
var config = objectMapper.convertValue(aiPlugin.getConfig(), AiConfig.class);
if (isBlank(config.fineTuning)) return config;

ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(key.get(0).getComment(), Duration.ofSeconds(200));
Retrofit retrofit = defaultRetrofit(client, mapper);
var api = retrofit.create(OpenAiApi.class);
var service = new OpenAiService(api);
var fileId = config.fileId;
// if (fileId != null) {
// try {
// service.deleteFile(fileId);
// } catch (Exception e) {
// logger.warn("Deleting file {} failed.", fileId);
// }
// fileId = null;
// }
if (fileId == null) {
RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, "fine-tune");
RequestBody fileBody = RequestBody.create(config.fineTuning, TEXT);
MultipartBody.Part body = MultipartBody.Part.createFormData("file", "fine-tune", fileBody);
fileId = execute(api.uploadFile(purposeBody, body)).getId();
config.fileId = fileId;
}
// for (var f : service.listFiles()) {
// if (!f.getId().equals(fileId)) {
// try {
// service.deleteFile(f.getId());
// } catch (Exception e) {
// logger.warn("Deleting file {} failed.", f.getId());
// }
// }
// }
var ftId = config.ftId;
// if (ftId != null) {
// try {
// service.deleteFineTune(ftId);
// } catch (Exception e) {
// logger.warn("Deleting fine tune {} failed.", ftId);
// }
// ftId = null;
// }
if (ftId == null) {
var fineTuneRequest = FineTuneRequest.builder()
.model("davinci")
.trainingFile(fileId)
.build();
ftId = service.createFineTune(fineTuneRequest).getId();
config.ftId = ftId;
}
// for (var ft : service.listFineTunes()) {
// if (!ft.getStatus().equals("cancelled") && !ft.getId().equals(ftId)) {
// try {
// service.deleteFineTune(ft.getId());
// } catch (Exception e) {
// logger.warn("Deleting fine tune {} failed.", ft.getId());
// }
// }
// }
var fineTunedModel = config.fineTunedModel;
if (fineTunedModel == null) {
var res = service.retrieveFineTune(ftId);
if (res.getStatus().equals("cancelled")) {
config.ftId = null;
} else {
if (res.getStatus().equals("succeeded")) {
fineTunedModel = res.getFineTunedModel();
config.fineTunedModel = fineTunedModel;
}
}
}
aiPlugin.setConfig(objectMapper.convertValue(config, JsonNode.class));
pluginRepository.save(aiPlugin);
return config;
}


public CompletionResult completion(String systemPrompt, String prompt) {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
Expand Down Expand Up @@ -181,45 +98,24 @@ public CompletionResult completion(String systemPrompt, String prompt) {
}
}


public ChatCompletionResult chatCompletion(String systemPrompt, String prompt) {
public ChatCompletionResult chatCompletion(String prompt, AiConfig config) {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
if (key.isEmpty()) {
throw new NotFoundException("requires openai api key");
}
var service = new OpenAiService(key.get(0).getComment(), Duration.ofSeconds(200));
var completionRequest = ChatCompletionRequest.builder()
.model("gpt-4-1106-preview")
.maxTokens(64_000)
var completionRequest = ChatCompletionRequest
.builder()
.model(config.model)
.maxTokens(config.maxTokens)
.messages(List.of(
cm("system", systemPrompt),
cm("system", config.systemPrompt),
cm("user", prompt)
))
.build();
try {
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException e) {
if ("context_length_exceeded".equals(e.code)) {
completionRequest.setMaxTokens(400);
try {
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException second) {
if ("context_length_exceeded".equals(second.code)) {
completionRequest.setMaxTokens(20);
try {
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException third) {
throw e;
}
}
throw e;
}
}
throw e;
}
return service.createChatCompletion(completionRequest);
}


public ImageResult dale(String prompt, DalleConfig config) {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
if (key.isEmpty()) {
Expand All @@ -228,85 +124,26 @@ public ImageResult dale(String prompt, DalleConfig config) {
var service = new OpenAiService(key.get(0).getComment(), Duration.ofSeconds(200));
var imageRequest = CreateImageRequest.builder()
.prompt(prompt)
.model(config.model)
.size(config.size)
.quality(config.quality)
.build();
try {
return service.createImage(imageRequest);
} catch (OpenAiHttpException e) {
throw e;
}
return service.createImage(imageRequest);
}

public CompletionResult fineTunedCompletion(List<ChatMessage> messages) {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
if (key.isEmpty()) {
throw new NotFoundException("requires openai api key");
}
var service = new OpenAiService(key.get(0).getComment(), Duration.ofSeconds(200));
var completionRequest = CompletionRequest.builder()
.maxTokens(1024)
.prompt(messages.stream().map(ChatMessage::getContent).collect(Collectors.joining("\n")))
.model("text-davinci-003")
.stop(List.of("Prompt:", "Reply:"))
.build();
try {
return service.createCompletion(completionRequest);
} catch (OpenAiHttpException e) {
if ("context_length_exceeded".equals(e.code)) {
completionRequest.setMaxTokens(400);
try {
return service.createCompletion(completionRequest);
} catch (OpenAiHttpException second) {
if ("context_length_exceeded".equals(second.code)) {
completionRequest.setMaxTokens(20);
try {
return service.createCompletion(completionRequest);
} catch (OpenAiHttpException third) {
throw e;
}
}
throw e;
}
}
throw e;
}
}

public ChatCompletionResult chat(String model, List<ChatMessage> messages) throws JsonProcessingException {
public ChatCompletionResult chat(List<ChatMessage> messages, AiConfig config) {
var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec());
if (key.isEmpty()) {
throw new NotFoundException("requires openai api key");
}
OpenAiService service = new OpenAiService(key.get(0).getComment(), Duration.ofSeconds(200));
ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
.maxTokens(4096)
ChatCompletionRequest completionRequest = ChatCompletionRequest
.builder()
.model(config.model)
.maxTokens(config.maxTokens)
.messages(messages)
.model(model)
.build();
try {
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException e) {
logger.error("context_length_exceeded {}", 4096);
if ("context_length_exceeded".equals(e.code)) {
try {
completionRequest.setMaxTokens(400);
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException second) {
logger.error("context_length_exceeded {}", 400);
if ("context_length_exceeded".equals(second.code)) {
try {
completionRequest.setMaxTokens(20);
return service.createChatCompletion(completionRequest);
} catch (OpenAiHttpException third) {
logger.error("context_length_exceeded {}", 20);
throw e;
}
}
throw e;
}
}
throw e;
}
return service.createChatCompletion(completionRequest);
}

public static ChatMessage cm(String origin, String role, String title, String content, ObjectMapper om) {
Expand Down Expand Up @@ -344,16 +181,15 @@ public static String ref(String origin, String role, String title, String conten
}

public static class AiConfig {
public String model;
public String model = "gpt-4-1106-preview";
public int maxTokens = 4096;
public String systemPrompt;
public String fineTuning;
public String fileId;
public String ftId;
public String fineTunedModel;
}

public static class DalleConfig {
public String size = "1024x1024";
public String model = "dall-e-3";
public String quality = "hd";
}

}
8 changes: 6 additions & 2 deletions src/main/java/jasper/component/delta/Ai.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jasper.repository.TemplateRepository;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -34,6 +35,7 @@
import java.util.UUID;
import java.util.stream.Collectors;

import static com.rometools.utils.Strings.isBlank;
import static jasper.component.OpenAi.cm;
import static jasper.repository.spec.RefSpec.hasInternalResponse;
import static jasper.repository.spec.RefSpec.hasResponse;
Expand Down Expand Up @@ -167,7 +169,7 @@ You may only use public tags (starting with a lowercase letter or number) and yo
}
messages.add(cm("user", objectMapper.writeValueAsString(sample)));
messages.add(cm(ref.getOrigin(), "system", "Output format instructions", instructions, objectMapper));
var res = openAi.chat(config.model, messages);
var res = openAi.chat(messages, config);
var reply = res.getChoices().stream().map(ChatCompletionChoice::getMessage).map(ChatMessage::getContent).collect(Collectors.joining("\n\n"));
response.setUrl("ai:" + res.getId());
response.setPlugin("+plugin/openai", objectMapper.convertValue(res.getUsage(), JsonNode.class));
Expand Down Expand Up @@ -246,7 +248,9 @@ You may only use public tags (starting with a lowercase letter or number) and yo
t -> t.matches(Tag.REGEX) && (t.equals("+plugin/openai") || !t.startsWith("+") && !t.startsWith("_"))
).collect(Collectors.toList()));
}

if (isBlank(aiReply.getUrl())) {
aiReply.setUrl("ai:" + UUID.randomUUID());
}
ingest.ingest(aiReply, false);
logger.debug("AI reply sent ({})", aiReply.getUrl());
}
Expand Down
18 changes: 7 additions & 11 deletions src/main/java/jasper/component/delta/Summary.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@
import com.theokanning.openai.completion.chat.ChatMessage;
import jasper.component.Ingest;
import jasper.component.OpenAi;
import jasper.component.OpenAi.AiConfig;
import jasper.component.scheduler.Async;
import jasper.domain.Ref;
import jasper.domain.User;
import jasper.errors.NotFoundException;
import jasper.repository.PluginRepository;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;

@Profile("ai")
@Component
Expand Down Expand Up @@ -64,10 +63,10 @@ public void run(Ref ref) {
var config = objectMapper.convertValue(summaryPlugin.getConfig(), SummaryConfig.class);
var response = new Ref();
try {
var res = openAi.chatCompletion(config.getSystemPrompt(), String.join("\n\n",
var res = openAi.chatCompletion(String.join("\n\n",
"Title: " + ref.getTitle(),
"Tags: " + String.join(", ", ref.getTags()),
ref.getComment()));
ref.getComment()), config);
response.setComment(res.getChoices().stream()
.map(ChatCompletionChoice::getMessage)
.map(ChatMessage::getContent)
Expand All @@ -79,7 +78,7 @@ public void run(Ref ref) {
response.setUrl("internal:" + UUID.randomUUID());
}
var title = ref.getTitle();
if (!title.startsWith(config.getTitlePrefix())) title = config.titlePrefix + title;
if (!title.startsWith(config.titlePrefix)) title = config.titlePrefix + title;
response.setTitle(title);
response.setOrigin(ref.getOrigin());
var tags = new ArrayList<String>();
Expand Down Expand Up @@ -117,10 +116,7 @@ public void run(Ref ref) {
ingest.ingest(response, false);
}

@Getter
@Setter
private static class SummaryConfig {
private String titlePrefix;
private String systemPrompt;
private static class SummaryConfig extends AiConfig {
public String titlePrefix;
}
}

0 comments on commit 427ba4d

Please sign in to comment.