Skip to content

Commit

Permalink
Send up to 7 levels of context to ai
Browse files Browse the repository at this point in the history
  • Loading branch information
cjmalloy committed Dec 22, 2023
1 parent 146af26 commit e337cb8
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/main/java/jasper/component/delta/Ai.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static jasper.component.OpenAi.cm;
import static jasper.repository.spec.RefSpec.hasInternalResponse;
import static jasper.repository.spec.RefSpec.hasResponse;
import static jasper.repository.spec.RefSpec.isNotObsolete;
import static java.util.Optional.ofNullable;
import static java.util.stream.Stream.concat;
import static org.apache.commons.lang3.StringUtils.isBlank;
Expand Down Expand Up @@ -101,8 +102,23 @@ public void run(Ref ref) throws JsonProcessingException {
.orElseThrow(() -> new NotFoundException("+plugin/openai"));
var config = objectMapper.convertValue(aiPlugin.getConfig(), OpenAi.AiConfig.class);
// TODO: compress pages if too long
var parents = refRepository.findAll(hasResponse(ref.getUrl()).or(hasInternalResponse(ref.getUrl())), by(Ref_.PUBLISHED))
var context = new HashMap<String, RefDto>();
var parents = refRepository.findAll(isNotObsolete().and(hasResponse(ref.getUrl()).or(hasInternalResponse(ref.getUrl()))), by(Ref_.PUBLISHED))
.stream().map(refMapper::domainToDto).toList();
parents.forEach(p -> context.put(p.getUrl(), p));
for (var i = 0; i < 7; i++) {
if (parents.isEmpty()) break;
var grandParents = parents.stream().flatMap(p -> refRepository.findAll(isNotObsolete().and(hasResponse(p.getUrl()).or(hasInternalResponse(p.getUrl()))), by(Ref_.PUBLISHED)).stream())
.map(refMapper::domainToDto).toList();
var newParents = new ArrayList<RefDto>();
for (var p : grandParents) {
if (!context.containsKey(p.getUrl())) {
newParents.add(p);
context.put(p.getUrl(), p);
}
}
parents = newParents;
}
var exts = new HashMap<String, Ext>();
if (ref.getTags() != null) {
for (var t : ref.getTags()) {
Expand All @@ -111,7 +127,7 @@ public void run(Ref ref) throws JsonProcessingException {
if (ext.isPresent()) exts.put(qt, extRepository.findOneByQualifiedTag(qt).get());
}
}
for (var p : parents) {
for (var p : context.values()) {
for (var t : p.getTags()) {
var qt = t + ref.getOrigin();
var ext = extRepository.findOneByQualifiedTag(qt);
Expand All @@ -136,7 +152,7 @@ public void run(Ref ref) throws JsonProcessingException {
List<Ref> refArray = List.of(response);
for (var model : models) {
config.model = model;
var messages = getChatMessages(ref, exts.values(), plugins, templates, config, parents, author, sample);
var messages = getChatMessages(ref, exts.values(), plugins, templates, config, context.values(), author, sample);
try {
var res = openAi.chat(messages, config);
var reply = res.getChoices().stream().map(ChatCompletionChoice::getMessage).map(ChatMessage::getContent).collect(Collectors.joining("\n\n"));
Expand Down Expand Up @@ -234,7 +250,7 @@ public void run(Ref ref) throws JsonProcessingException {
}

@NotNull
private ArrayList<ChatMessage> getChatMessages(Ref ref, Collection<Ext> exts, List<Plugin> plugins, List<Template> templates, OpenAi.AiConfig config, List<RefDto> parents, String author, RefDto sample) throws JsonProcessingException {
private ArrayList<ChatMessage> getChatMessages(Ref ref, Collection<Ext> exts, Collection<Plugin> plugins, Collection<Template> templates, OpenAi.AiConfig config, Collection<RefDto> context, String author, RefDto sample) throws JsonProcessingException {
var modsPrompt = concat(
plugins.stream().map(Plugin::getConfig),
templates.stream().map(Template::getConfig)
Expand Down Expand Up @@ -389,7 +405,7 @@ You may only use public tags (starting with a lowercase letter or number) and yo
if (exts.isEmpty()) {
messages.add(cm(ref.getOrigin(), "system", "Exts", extsPrompt, objectMapper));
}
for (var p : parents) {
for (var p : context) {
p.setMetadata(null);
if (p.getTags().contains("+plugin/openai")) {
messages.add(cm("assistant", objectMapper.writeValueAsString(p)));
Expand Down

0 comments on commit e337cb8

Please sign in to comment.