From 2ec647f46ebbb82f991483e30bbae1b79d724562 Mon Sep 17 00:00:00 2001 From: xdnw Date: Thu, 13 Jul 2023 03:46:15 +0100 Subject: [PATCH] Add PW gpt tool --- .../manager/v2/impl/pw/CommandManager2.java | 1 + .../v2/impl/pw/commands/HelpCommands.java | 8 +++ .../discord/db/AEmbeddingDatabase.java | 71 ++++++++++--------- .../link/locutus/discord/gpt/GptHandler.java | 26 ++++++- .../discord/gpt/imps/GPTText2Text.java | 45 ++++++++++++ .../locutus/discord/gpt/imps/IText2Text.java | 5 ++ .../discord/gpt/imps/ProcessSummarizer.java | 16 ++++- .../discord/gpt/imps/ProcessText2Text.java | 55 ++++++++++++++ .../discord/gpt/pwembed/PWGPTHandler.java | 62 ++++++++++++++++ 9 files changed, 250 insertions(+), 39 deletions(-) create mode 100644 src/main/java/link/locutus/discord/gpt/imps/GPTText2Text.java create mode 100644 src/main/java/link/locutus/discord/gpt/imps/IText2Text.java create mode 100644 src/main/java/link/locutus/discord/gpt/imps/ProcessText2Text.java diff --git a/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/CommandManager2.java b/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/CommandManager2.java index ff08bf81..bce87f89 100644 --- a/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/CommandManager2.java +++ b/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/CommandManager2.java @@ -229,6 +229,7 @@ public CommandManager2 registerDefaults() { this.commands.registerMethod(help, List.of("help"), "find_setting", "find_setting"); this.commands.registerMethod(help, List.of("help"), "moderation_check", "moderation_check"); + this.commands.registerMethod(help, List.of("help"), "query", "query"); pwgptHandler.registerDefaults(); } diff --git a/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/commands/HelpCommands.java b/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/commands/HelpCommands.java index 8033c04e..bd1db3a7 100644 --- a/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/commands/HelpCommands.java +++ b/src/main/java/link/locutus/discord/commands/manager/v2/impl/pw/commands/HelpCommands.java @@ -12,6 +12,7 @@ import link.locutus.discord.commands.manager.v2.command.ParametricCallable; import link.locutus.discord.commands.manager.v2.impl.pw.CM; import link.locutus.discord.commands.manager.v2.perm.PermissionHandler; +import link.locutus.discord.db.GuildDB; import link.locutus.discord.db.guild.GuildSetting; import link.locutus.discord.gpt.pwembed.CommandEmbedding; import link.locutus.discord.gpt.imps.EmbeddingType; @@ -19,6 +20,7 @@ import link.locutus.discord.gpt.pwembed.PWEmbedding; import link.locutus.discord.gpt.pwembed.PWGPTHandler; import link.locutus.discord.gpt.pwembed.SettingEmbedding; +import net.dv8tion.jda.api.entities.User; import java.io.IOException; import java.util.List; @@ -44,6 +46,12 @@ public PWGPTHandler getGPT() { // // } + @Command + public String query(ValueStore store, @Me GuildDB db, @Me User user, @Me IMessageIO io, String input) throws IOException { + String result = getGPT().generateSolution(store, db, user, input); + return result; + } + @Command public void moderation_check(@Me IMessageIO io, String input) throws IOException { List inputs = List.of(input); diff --git a/src/main/java/link/locutus/discord/db/AEmbeddingDatabase.java b/src/main/java/link/locutus/discord/db/AEmbeddingDatabase.java index 9e613115..9ba20738 100644 --- a/src/main/java/link/locutus/discord/db/AEmbeddingDatabase.java +++ b/src/main/java/link/locutus/discord/db/AEmbeddingDatabase.java @@ -10,6 +10,7 @@ import javax.annotation.Nullable; import java.io.Closeable; import java.math.BigInteger; +import java.sql.ResultSet; import java.sql.SQLException; import java.util.concurrent.atomic.AtomicInteger; @@ -71,7 +72,7 @@ private void loadContent() { } @Override - public void createTables() { + public synchronized void createTables() { // embeddings ctx().createTableIfNotExists("embeddings_2") .column("hash", SQLDataType.BIGINT.notNull()) @@ -90,33 +91,35 @@ public void createTables() { // if table `embeddings` exists try { - if (getConnection().getMetaData().getTables(null, null, "embeddings", null).next()) { - AtomicInteger inserted = new AtomicInteger(); - // iterate over all rows - ctx().select().from("embeddings").fetch().forEach(r -> { - // get hash - long hash = r.get("hash", Long.class); - // get type - long type = r.get("type", Long.class); - // get id - String id = r.get("id", String.class); - // get data - byte[] data = r.get("data", byte[].class); - double[] vectors = ArrayUtil.toDoubleArray(data); - float[] downCast = new float[vectors.length]; - for (int i = 0; i < vectors.length; i++) { - downCast[i] = (float) vectors[i]; - } - byte[] downCastBytes = ArrayUtil.toByteArray(downCast); - addEmbedding(hash, type, id, downCastBytes); - inserted.incrementAndGet(); - }); - if (inserted.get() > 0) { - System.out.println("Inserted " + inserted.get() + " embeddings"); - // drop old table - ctx().dropTableIfExists("embeddings").execute(); + AtomicInteger inserted = new AtomicInteger(); + try (ResultSet query = getConnection().getMetaData().getTables(null, null, "embeddings", null)) { + if (query.next()) { + // iterate over all rows + ctx().select().from("embeddings").fetch().forEach(r -> { + // get hash + long hash = r.get("hash", Long.class); + // get type + long type = r.get("type", Long.class); + // get id + String id = r.get("id", String.class); + // get data + byte[] data = r.get("data", byte[].class); + double[] vectors = ArrayUtil.toDoubleArray(data); + float[] downCast = new float[vectors.length]; + for (int i = 0; i < vectors.length; i++) { + downCast[i] = (float) vectors[i]; + } + byte[] downCastBytes = ArrayUtil.toByteArray(downCast); + addEmbedding(hash, type, id, downCastBytes); + inserted.incrementAndGet(); + }); } } + if (inserted.get() > 0) { + System.out.println("Inserted " + inserted.get() + " embeddings"); + // drop old table + ctx().dropTableIfExists("embeddings").execute(); + } } catch (SQLException e) { e.printStackTrace(); } @@ -169,12 +172,14 @@ public void setEmbedding(int type, @Nullable String id2, String content, float[] } return; } - // delete from database - ctx().execute("DELETE FROM `embeddings` WHERE `hash` = ?", info.contentHash); + synchronized (this) { + // delete from database + ctx().execute("DELETE FROM `embeddings` WHERE `hash` = ?", info.contentHash); - // delete content if exists - if (hashContent.remove(info.contentHash) != null) { - ctx().execute("DELETE FROM `content` WHERE `hash` = ?", info.contentHash); + // delete content if exists + if (hashContent.remove(info.contentHash) != null) { + ctx().execute("DELETE FROM `content` WHERE `hash` = ?", info.contentHash); + } } System.out.println("Delete different embedding"); @@ -199,12 +204,12 @@ public void setEmbedding(int type, @Nullable String id2, String content, float[] } } - private void updateEmbedding(long contentHash, int type, String id) { + private synchronized void updateEmbedding(long contentHash, int type, String id) { if (id == null) id = ""; ctx().execute("UPDATE `embeddings_2` SET `type` = ?, `id` = ? WHERE `hash` = ?", type, id, contentHash); } - private void addContent(long hash, String content) { + private synchronized void addContent(long hash, String content) { ctx().execute("INSERT OR IGNORE INTO `content` (`hash`, `content`) VALUES (?, ?)", hash, content); } diff --git a/src/main/java/link/locutus/discord/gpt/GptHandler.java b/src/main/java/link/locutus/discord/gpt/GptHandler.java index 2e73f127..6e3fad9d 100644 --- a/src/main/java/link/locutus/discord/gpt/GptHandler.java +++ b/src/main/java/link/locutus/discord/gpt/GptHandler.java @@ -11,12 +11,15 @@ import link.locutus.discord.config.Settings; import link.locutus.discord.gpt.imps.AdaEmbedding; import link.locutus.discord.gpt.imps.GPTSummarizer; +import link.locutus.discord.gpt.imps.ProcessSummarizer; +import link.locutus.discord.gpt.imps.ProcessText2Text; import link.locutus.discord.util.FileUtil; import link.locutus.discord.util.math.ArrayUtil; import org.json.JSONArray; import org.json.JSONObject; import javax.annotation.Nullable; +import java.io.File; import java.io.IOException; import java.net.HttpURLConnection; import java.nio.charset.StandardCharsets; @@ -41,6 +44,7 @@ public class GptHandler { public final IEmbeddingDatabase embeddingDatabase; private final ISummarizer summarizer; private final IModerator moderator; + private final ProcessText2Text text2text; public GptHandler() throws SQLException, ClassNotFoundException { this.registry = Encodings.newDefaultEncodingRegistry(); @@ -50,10 +54,26 @@ public GptHandler() throws SQLException, ClassNotFoundException { this.platform = Platform.detectPlatform("pytorch"); - - this.summarizer = new GPTSummarizer(registry, service); - this.embeddingDatabase = new AdaEmbedding(registry, service); this.moderator = new GPTModerator(service); + this.embeddingDatabase = new AdaEmbedding(registry, service); + // TODO change ^ that to mini + + File gpt4freePath = new File("../gpt4free/mymain.py"); + File venvExe = new File("../gpt4free/venv/Scripts/python.exe"); + // ensure files exist + if (!gpt4freePath.exists()) { + throw new RuntimeException("gpt4free not found: " + gpt4freePath.getAbsolutePath()); + } + if (!venvExe.exists()) { + throw new RuntimeException("venv not found: " + venvExe.getAbsolutePath()); + } + + this.summarizer = new ProcessSummarizer(venvExe, gpt4freePath); + this.text2text = new ProcessText2Text(venvExe, gpt4freePath); + } + + public ProcessText2Text getText2text() { + return text2text; } public IModerator getModerator() { diff --git a/src/main/java/link/locutus/discord/gpt/imps/GPTText2Text.java b/src/main/java/link/locutus/discord/gpt/imps/GPTText2Text.java new file mode 100644 index 00000000..c9b6720b --- /dev/null +++ b/src/main/java/link/locutus/discord/gpt/imps/GPTText2Text.java @@ -0,0 +1,45 @@ +package link.locutus.discord.gpt.imps; + +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import com.knuddels.jtokkit.api.ModelType; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.CompletionChoice; +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.completion.CompletionResult; +import link.locutus.discord.gpt.GPTUtil; +import link.locutus.discord.gpt.IEmbeddingDatabase; + +import java.util.ArrayList; +import java.util.List; + +public class GPTText2Text implements IText2Text{ + private final EncodingRegistry registry; + private final OpenAiService service; + private final Encoding chatEncoder; + private final ModelType model; + private final IEmbeddingDatabase embeddings; + + public GPTText2Text(EncodingRegistry registry, OpenAiService service, IEmbeddingDatabase embeddings) { + this.registry = registry; + this.service = service; + this.model = ModelType.GPT_3_5_TURBO; + this.chatEncoder = registry.getEncodingForModel(model); + this.embeddings = embeddings; + } + + @Override + public String generate(String text) { + CompletionRequest completionRequest = CompletionRequest.builder() + .prompt(text) + .model(this.model.getName()) + .echo(false) + .build(); + CompletionResult completion = service.createCompletion(completionRequest); + List results = new ArrayList<>(); + for (CompletionChoice choice : completion.getChoices()) { + results.add(choice.getText()); + } + return String.join("\n", results); + } +} diff --git a/src/main/java/link/locutus/discord/gpt/imps/IText2Text.java b/src/main/java/link/locutus/discord/gpt/imps/IText2Text.java new file mode 100644 index 00000000..582f81ae --- /dev/null +++ b/src/main/java/link/locutus/discord/gpt/imps/IText2Text.java @@ -0,0 +1,5 @@ +package link.locutus.discord.gpt.imps; + +public interface IText2Text { + String generate(String text); +} diff --git a/src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java b/src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java index 3021de92..1261fbab 100644 --- a/src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java +++ b/src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java @@ -21,8 +21,10 @@ public class ProcessSummarizer implements ISummarizer { private final String prompt; private final int promptTokens; private final ModelType model; + private final File venvExe; - public ProcessSummarizer(File file) { + public ProcessSummarizer(File venvExe, File file) { + this.venvExe = venvExe; this.file = file; this.prompt = """ Write a concise summary which preserves syntax, equations, arguments and constraints of the following: @@ -50,7 +52,8 @@ public String summarizeChunk(String chunk) { String encodedString = Base64.getEncoder().encodeToString(full.getBytes()); List lines = new ArrayList<>(); - ProcessBuilder pb = new ProcessBuilder("python", file.getAbsolutePath(), encodedString).redirectErrorStream(true); + String command = venvExe == null ? "python" : venvExe.getAbsolutePath(); + ProcessBuilder pb = new ProcessBuilder(command, file.getAbsolutePath(), encodedString).redirectErrorStream(true); try { Process p = pb.start(); BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream())); @@ -61,6 +64,13 @@ public String summarizeChunk(String chunk) { } catch (IOException e) { throw new RuntimeException(e); } - return StringMan.join(lines, "\n"); + String result = StringMan.join(lines, "\n"); + if (result.contains("result:")) { + result = result.substring(result.indexOf("result:") + 7); + return result; + } else { + System.err.println(result); + throw new IllegalArgumentException("Unknown process result (see console)"); + } } } diff --git a/src/main/java/link/locutus/discord/gpt/imps/ProcessText2Text.java b/src/main/java/link/locutus/discord/gpt/imps/ProcessText2Text.java new file mode 100644 index 00000000..a57c1165 --- /dev/null +++ b/src/main/java/link/locutus/discord/gpt/imps/ProcessText2Text.java @@ -0,0 +1,55 @@ +package link.locutus.discord.gpt.imps; + +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import com.knuddels.jtokkit.api.ModelType; +import com.theokanning.openai.OpenAiService; +import com.theokanning.openai.completion.CompletionChoice; +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.completion.CompletionResult; +import link.locutus.discord.gpt.IEmbeddingDatabase; +import link.locutus.discord.util.StringMan; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; + +public class ProcessText2Text implements IText2Text{ + private final File file; + private final File venvExe; + + public ProcessText2Text(File venvExe, File file) { + this.venvExe = venvExe; + this.file = file; + } + + @Override + public String generate(String text) { + String encodedString = Base64.getEncoder().encodeToString(text.getBytes()); + List lines = new ArrayList<>(); + String command = venvExe == null ? "python" : venvExe.getAbsolutePath(); + ProcessBuilder pb = new ProcessBuilder(command, file.getAbsolutePath(), encodedString).redirectErrorStream(true); + try { + Process p = pb.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream())); + String line; + while ((line = reader.readLine()) != null) { + lines.add(line); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + String result = StringMan.join(lines, "\n"); + if (result.contains("result:")) { + result = result.substring(result.indexOf("result:") + 7); + return result; + } else { + System.err.println(result); + throw new IllegalArgumentException("Unknown process result (see console)"); + } + } +} diff --git a/src/main/java/link/locutus/discord/gpt/pwembed/PWGPTHandler.java b/src/main/java/link/locutus/discord/gpt/pwembed/PWGPTHandler.java index 1333f5a1..1bed73a6 100644 --- a/src/main/java/link/locutus/discord/gpt/pwembed/PWGPTHandler.java +++ b/src/main/java/link/locutus/discord/gpt/pwembed/PWGPTHandler.java @@ -1,18 +1,27 @@ package link.locutus.discord.gpt.pwembed; +import link.locutus.discord.Locutus; +import link.locutus.discord.commands.manager.v2.binding.Key; +import link.locutus.discord.commands.manager.v2.binding.LocalValueStore; import link.locutus.discord.commands.manager.v2.binding.ValueStore; +import link.locutus.discord.commands.manager.v2.binding.annotation.Me; import link.locutus.discord.commands.manager.v2.command.ParametricCallable; import link.locutus.discord.commands.manager.v2.impl.pw.CommandManager2; import link.locutus.discord.commands.manager.v2.impl.pw.binding.NationAttribute; +import link.locutus.discord.db.GuildDB; import link.locutus.discord.db.guild.GuildSetting; import link.locutus.discord.db.guild.GuildKey; +import link.locutus.discord.gpt.GPTUtil; +import link.locutus.discord.gpt.ModerationResult; import link.locutus.discord.gpt.imps.EmbeddingType; import link.locutus.discord.gpt.GptHandler; import link.locutus.discord.util.math.ArrayUtil; +import net.dv8tion.jda.api.entities.User; import java.lang.reflect.Method; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -47,6 +56,59 @@ public void registerDefaults() { // registerTutorialBindings("Tutorial"); } + public String generateSolution(ValueStore store, GuildDB db, User user, String userInput) { + // check moderation + List modResult = handler.getModerator().moderate(userInput); + GPTUtil.checkThrowModeration(modResult, userInput); + + // get prompt + String prompt = """ + You are `Locutus` a discord bot assistant of a player who is the leader of a nation in the game Politics And War. + Use the information below and your own knowledge to respond. + + Player conversation: + ``` + {user_input} + ``` + + Top results from searching the game database: + {search_results}"""; + + // 2000 + int promptLength = prompt.replace("{user_input}", "").replace("{search_results}", "").length(); + int userInputLength = userInput.length(); + + int max = 2000; + int remaining = max - promptLength - userInputLength; + + if (store == null) { + store = new LocalValueStore(Locutus.imp().getCommandManager().getV2().getStore()); + // set db and user + store.addProvider(Key.of(GuildDB.class, Me.class), db); + store.addProvider(Key.of(User.class, Me.class), user); + } + + // get the closest results + List embeddings = new ArrayList<>(); + HashSet allowedTypes = new HashSet<>(Arrays.asList(EmbeddingType.values())); + List> closest = this.getClosest(store, userInput, 50, allowedTypes); + for (Map.Entry entry : closest) { + PWEmbedding embedding = entry.getKey(); + String text = embedding.getType() + "." + embedding.getId() + "=" + embedding.getContent(); + if (text.length() + 1 > remaining) continue; + embeddings.add(text); + remaining -= text.length() + 1; + } + + String formatted = prompt.replace("{user_input}", userInput).replace("{search_results}", String.join("\n", embeddings)); + + System.out.println("Prompt\n```\n" + formatted + "\n```"); + + String result = this.handler.getText2text().generate(formatted); + + return result; + } + private void registerCommandEmbeddings() { Set existing = new HashSet<>(); for (ParametricCallable callable : cmdManager.getCommands().getParametricCallables(f -> true)) {