Skip to content

Commit

Permalink
Create ProcessSummarizer.java
Browse files Browse the repository at this point in the history
  • Loading branch information
xdnw committed Jul 13, 2023
1 parent b828201 commit cc451b1
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package link.locutus.discord.gpt.imps;

import com.knuddels.jtokkit.api.ModelType;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import link.locutus.discord.gpt.GPTUtil;
import link.locutus.discord.gpt.ISummarizer;
import link.locutus.discord.util.StringMan;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

public class ProcessSummarizer implements ISummarizer {
private final File file;
private final String prompt;
private final int promptTokens;
private final ModelType model;

public ProcessSummarizer(File file) {
this.file = file;
this.prompt = """
Write a concise summary which preserves syntax, equations, arguments and constraints of the following:
{query}
Concise summary:""";
this.model = ModelType.GPT_4;
this.promptTokens = GPTUtil.getTokens(prompt.replace("{query}", ""), model);
}
@Override
public String summarize(String text) {
int cap = 4096 - 4;
int remaining = cap - promptTokens;
List<String> summaries = new ArrayList<>();
for (String chunk : GPTUtil.getChunks(text, model, remaining)) {
String result = summarizeChunk(chunk);
summaries.add(result);
}
return String.join("\n", summaries);
}

public String summarizeChunk(String chunk) {
String full = prompt.replace("{query}", chunk);

String encodedString = Base64.getEncoder().encodeToString(full.getBytes());
List<String> lines = new ArrayList<>();
ProcessBuilder pb = new ProcessBuilder("python", file.getAbsolutePath(), encodedString).redirectErrorStream(true);
try {
Process p = pb.start();
BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream()));
String line;
while ((line = reader.readLine()) != null) {
lines.add(line);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
return StringMan.join(lines, "\n");
}
}

0 comments on commit cc451b1

Please sign in to comment.