From 5669cb2e2660ee2bc8f37b32eadd480fd3af9fc2 Mon Sep 17 00:00:00 2001 From: Chris Malloy Date: Thu, 5 Oct 2023 00:15:46 -0300 Subject: [PATCH] Added DALL-E --- src/main/java/jasper/component/OpenAi.java | 25 ++++ .../java/jasper/component/delta/Dalle.java | 138 ++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 src/main/java/jasper/component/delta/Dalle.java diff --git a/src/main/java/jasper/component/OpenAi.java b/src/main/java/jasper/component/OpenAi.java index 4638552f..6b7d4d2f 100644 --- a/src/main/java/jasper/component/OpenAi.java +++ b/src/main/java/jasper/component/OpenAi.java @@ -11,6 +11,8 @@ import com.theokanning.openai.completion.chat.ChatCompletionResult; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.finetune.FineTuneRequest; +import com.theokanning.openai.image.CreateImageRequest; +import com.theokanning.openai.image.ImageResult; import com.theokanning.openai.service.OpenAiService; import jasper.client.dto.RefDto; import jasper.config.Props; @@ -217,6 +219,24 @@ public ChatCompletionResult chatCompletion(String systemPrompt, String prompt) { } } + + public ImageResult dale(String prompt, DalleConfig config) { + var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec()); + if (key.isEmpty()) { + throw new NotFoundException("requires openai api key"); + } + var service = new OpenAiService(key.get(0).getComment(), Duration.ofSeconds(200)); + var imageRequest = CreateImageRequest.builder() + .prompt(prompt) + .size(config.size) + .build(); + try { + return service.createImage(imageRequest); + } catch (OpenAiHttpException e) { + throw e; + } + } + public CompletionResult fineTunedCompletion(List messages) { var key = refRepository.findAll(selector("_openai/key" + props.getLocalOrigin()).refSpec()); if (key.isEmpty()) { @@ -331,4 +351,9 @@ public static class AiConfig { public String ftId; public String fineTunedModel; } + + public static class DalleConfig { + public String size = "1024x1024"; + } + } diff --git a/src/main/java/jasper/component/delta/Dalle.java b/src/main/java/jasper/component/delta/Dalle.java new file mode 100644 index 00000000..feac8b5d --- /dev/null +++ b/src/main/java/jasper/component/delta/Dalle.java @@ -0,0 +1,138 @@ +package jasper.component.delta; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import jasper.component.Ingest; +import jasper.component.OpenAi; +import jasper.component.scheduler.Async; +import jasper.domain.Ext; +import jasper.domain.Plugin; +import jasper.domain.Ref; +import jasper.domain.Template; +import jasper.domain.User; +import jasper.errors.NotFoundException; +import jasper.repository.PluginRepository; +import jasper.repository.RefRepository; +import jasper.repository.TemplateRepository; +import lombok.Getter; +import lombok.Setter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Profile; +import org.springframework.stereotype.Component; + +import javax.annotation.PostConstruct; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +@Profile("ai") +@Component +public class Dalle implements Async.AsyncRunner { + private static final Logger logger = LoggerFactory.getLogger(Dalle.class); + + @Autowired + Async async; + + @Autowired + Ingest ingest; + + @Autowired + OpenAi openAi; + + @Autowired + RefRepository refRepository; + + @Autowired + PluginRepository pluginRepository; + + @Autowired + TemplateRepository templateRepository; + + @Autowired + ObjectMapper objectMapper; + + @Autowired + RefMapper refMapper; + + @PostConstruct + void init() { + async.addAsyncResponse("plugin/inbox/dalle", this); + } + + @Override + public String signature() { + return "+plugin/dalle"; + } + + @Override + public void run(Ref ref) throws JsonProcessingException { + logger.debug("AI replying to {} ({})", ref.getTitle(), ref.getUrl()); + var author = ref.getTags().stream().filter(User::isUser).findFirst().orElse(null); + var dallePlugin = pluginRepository.findByTagAndOrigin("+plugin/dalle", ref.getOrigin()) + .orElseThrow(() -> new NotFoundException("+plugin/dalle")); + var config = objectMapper.convertValue(dallePlugin.getConfig(), OpenAi.DalleConfig.class); + var response = new Ref(); + try { + var res = openAi.dale(ref.getTitle() + ": " + ref.getComment(), config); + response.setTitle("Re: " + ref.getTitle()); + response.setUrl(res.getData().get(0).getUrl()); + } catch (Exception e) { + response.setComment("Error invoking DALL-E. " + e.getMessage()); + response.setUrl("internal:" + UUID.randomUUID()); + } + if (ref.getTags().contains("public")) response.addTag("public"); + if (ref.getTags().contains("internal")) response.addTag("internal"); + if (ref.getTags().contains("dm")) response.addTag("dm"); + if (ref.getTags().contains("dm")) response.addTag("plugin/thread"); + if (ref.getTags().contains("plugin/email")) response.addTag("plugin/email"); + if (ref.getTags().contains("plugin/email")) response.addTag("plugin/thread"); + if (ref.getTags().contains("plugin/comment")) response.addTag("plugin/comment"); + if (ref.getTags().contains("plugin/comment")) response.addTag("plugin/thread"); + if (ref.getTags().contains("plugin/thread")) response.addTag("plugin/thread"); + response.addTag("plugin/image"); + var chat = false; + for (var t : ref.getTags()) { + if (t.startsWith("chat/") || t.equals("chat")) { + chat = true; + response.addTag(t); + } + } + if (!chat) { + if (author != null) response.addTag("plugin/inbox/" + author.substring(1)); + for (var t : ref.getTags()) { + if (t.startsWith("plugin/inbox/") || t.startsWith("plugin/outbox/")) { + response.addTag(t); + } + } + } + response.addTag("+plugin/dalle"); + response.getTags().remove("plugin/inbox/dalle"); + var sources = new ArrayList<>(List.of(ref.getUrl())); + if (response.getTags().contains("plugin/thread")) { + // Add top comment source + if (ref.getSources() != null && ref.getSources().size() > 0) { + if (ref.getSources().size() > 1) { + sources.add(ref.getSources().get(1)); + } else { + sources.add(ref.getSources().get(0)); + } + } + } + response.setSources(sources); + response.setOrigin(ref.getOrigin()); + ingest.ingest(response, false); + logger.debug("DALL-E reply sent ({})", response.getUrl()); + } + + @Getter + @Setter + private static class AiReply { + private Ref[] ref; + private Ext[] ext; + private Plugin[] plugin; + private Template[] template; + private User[] user; + } +}