Skip to content

Commit

Permalink
Add PW gpt tool
Browse files Browse the repository at this point in the history
  • Loading branch information
xdnw committed Jul 13, 2023
1 parent cc451b1 commit 2ec647f
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ public CommandManager2 registerDefaults() {
this.commands.registerMethod(help, List.of("help"), "find_setting", "find_setting");

this.commands.registerMethod(help, List.of("help"), "moderation_check", "moderation_check");
this.commands.registerMethod(help, List.of("help"), "query", "query");

pwgptHandler.registerDefaults();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
import link.locutus.discord.commands.manager.v2.command.ParametricCallable;
import link.locutus.discord.commands.manager.v2.impl.pw.CM;
import link.locutus.discord.commands.manager.v2.perm.PermissionHandler;
import link.locutus.discord.db.GuildDB;
import link.locutus.discord.db.guild.GuildSetting;
import link.locutus.discord.gpt.pwembed.CommandEmbedding;
import link.locutus.discord.gpt.imps.EmbeddingType;
import link.locutus.discord.gpt.ModerationResult;
import link.locutus.discord.gpt.pwembed.PWEmbedding;
import link.locutus.discord.gpt.pwembed.PWGPTHandler;
import link.locutus.discord.gpt.pwembed.SettingEmbedding;
import net.dv8tion.jda.api.entities.User;

import java.io.IOException;
import java.util.List;
Expand All @@ -44,6 +46,12 @@ public PWGPTHandler getGPT() {
//
// }

/**
 * Command that hands the caller's free-form input to the GPT handler and
 * returns whatever text {@code generateSolution} produces.
 *
 * @param store value store passed through to the handler
 * @param db    guild database of the invoking guild
 * @param user  user who issued the command
 * @param io    message channel of the invocation (not used by this method)
 * @param input text to generate a solution for
 * @return the string returned by {@code generateSolution}
 * @throws IOException declared for the call to the GPT handler
 */
@Command
public String query(ValueStore store, @Me GuildDB db, @Me User user, @Me IMessageIO io, String input) throws IOException {
    return getGPT().generateSolution(store, db, user, input);
}

@Command
public void moderation_check(@Me IMessageIO io, String input) throws IOException {
List<String> inputs = List.of(input);
Expand Down
71 changes: 38 additions & 33 deletions src/main/java/link/locutus/discord/db/AEmbeddingDatabase.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import javax.annotation.Nullable;
import java.io.Closeable;
import java.math.BigInteger;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -71,7 +72,7 @@ private void loadContent() {
}

@Override
public void createTables() {
public synchronized void createTables() {
// embeddings
ctx().createTableIfNotExists("embeddings_2")
.column("hash", SQLDataType.BIGINT.notNull())
Expand All @@ -90,33 +91,35 @@ public void createTables() {

// if table `embeddings` exists
try {
if (getConnection().getMetaData().getTables(null, null, "embeddings", null).next()) {
AtomicInteger inserted = new AtomicInteger();
// iterate over all rows
ctx().select().from("embeddings").fetch().forEach(r -> {
// get hash
long hash = r.get("hash", Long.class);
// get type
long type = r.get("type", Long.class);
// get id
String id = r.get("id", String.class);
// get data
byte[] data = r.get("data", byte[].class);
double[] vectors = ArrayUtil.toDoubleArray(data);
float[] downCast = new float[vectors.length];
for (int i = 0; i < vectors.length; i++) {
downCast[i] = (float) vectors[i];
}
byte[] downCastBytes = ArrayUtil.toByteArray(downCast);
addEmbedding(hash, type, id, downCastBytes);
inserted.incrementAndGet();
});
if (inserted.get() > 0) {
System.out.println("Inserted " + inserted.get() + " embeddings");
// drop old table
ctx().dropTableIfExists("embeddings").execute();
AtomicInteger inserted = new AtomicInteger();
try (ResultSet query = getConnection().getMetaData().getTables(null, null, "embeddings", null)) {
if (query.next()) {
// iterate over all rows
ctx().select().from("embeddings").fetch().forEach(r -> {
// get hash
long hash = r.get("hash", Long.class);
// get type
long type = r.get("type", Long.class);
// get id
String id = r.get("id", String.class);
// get data
byte[] data = r.get("data", byte[].class);
double[] vectors = ArrayUtil.toDoubleArray(data);
float[] downCast = new float[vectors.length];
for (int i = 0; i < vectors.length; i++) {
downCast[i] = (float) vectors[i];
}
byte[] downCastBytes = ArrayUtil.toByteArray(downCast);
addEmbedding(hash, type, id, downCastBytes);
inserted.incrementAndGet();
});
}
}
if (inserted.get() > 0) {
System.out.println("Inserted " + inserted.get() + " embeddings");
// drop old table
ctx().dropTableIfExists("embeddings").execute();
}
} catch (SQLException e) {
e.printStackTrace();
}
Expand Down Expand Up @@ -169,12 +172,14 @@ public void setEmbedding(int type, @Nullable String id2, String content, float[]
}
return;
}
// delete from database
ctx().execute("DELETE FROM `embeddings` WHERE `hash` = ?", info.contentHash);
synchronized (this) {
// delete from database
ctx().execute("DELETE FROM `embeddings` WHERE `hash` = ?", info.contentHash);

// delete content if exists
if (hashContent.remove(info.contentHash) != null) {
ctx().execute("DELETE FROM `content` WHERE `hash` = ?", info.contentHash);
// delete content if exists
if (hashContent.remove(info.contentHash) != null) {
ctx().execute("DELETE FROM `content` WHERE `hash` = ?", info.contentHash);
}
}

System.out.println("Delete different embedding");
Expand All @@ -199,12 +204,12 @@ public void setEmbedding(int type, @Nullable String id2, String content, float[]
}
}

private void updateEmbedding(long contentHash, int type, String id) {
/**
 * Rewrites the {@code type} and {@code id} columns of the {@code embeddings_2}
 * row whose {@code hash} equals {@code contentHash}. Runs while holding this
 * instance's monitor.
 *
 * @param contentHash hash identifying the row to update
 * @param type        new value for the {@code type} column
 * @param id          new value for the {@code id} column; {@code null} is stored as an empty string
 */
private synchronized void updateEmbedding(long contentHash, int type, String id) {
    String storedId = (id != null) ? id : "";
    ctx().execute("UPDATE `embeddings_2` SET `type` = ?, `id` = ? WHERE `hash` = ?", type, storedId, contentHash);
}

private void addContent(long hash, String content) {
/**
 * Inserts a {@code (hash, content)} row into the {@code content} table using
 * {@code INSERT OR IGNORE}. Runs while holding this instance's monitor.
 *
 * @param hash    value for the {@code hash} column
 * @param content value for the {@code content} column
 */
private synchronized void addContent(long hash, String content) {
    final String insertSql = "INSERT OR IGNORE INTO `content` (`hash`, `content`) VALUES (?, ?)";
    ctx().execute(insertSql, hash, content);
}

Expand Down
26 changes: 23 additions & 3 deletions src/main/java/link/locutus/discord/gpt/GptHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import link.locutus.discord.config.Settings;
import link.locutus.discord.gpt.imps.AdaEmbedding;
import link.locutus.discord.gpt.imps.GPTSummarizer;
import link.locutus.discord.gpt.imps.ProcessSummarizer;
import link.locutus.discord.gpt.imps.ProcessText2Text;
import link.locutus.discord.util.FileUtil;
import link.locutus.discord.util.math.ArrayUtil;
import org.json.JSONArray;
import org.json.JSONObject;

import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
Expand All @@ -41,6 +44,7 @@ public class GptHandler {
public final IEmbeddingDatabase embeddingDatabase;
private final ISummarizer summarizer;
private final IModerator moderator;
private final ProcessText2Text text2text;

public GptHandler() throws SQLException, ClassNotFoundException {
this.registry = Encodings.newDefaultEncodingRegistry();
Expand All @@ -50,10 +54,26 @@ public GptHandler() throws SQLException, ClassNotFoundException {

this.platform = Platform.detectPlatform("pytorch");


this.summarizer = new GPTSummarizer(registry, service);
this.embeddingDatabase = new AdaEmbedding(registry, service);
this.moderator = new GPTModerator(service);
this.embeddingDatabase = new AdaEmbedding(registry, service);
// TODO change ^ that to mini

File gpt4freePath = new File("../gpt4free/mymain.py");
File venvExe = new File("../gpt4free/venv/Scripts/python.exe");
// ensure files exist
if (!gpt4freePath.exists()) {
throw new RuntimeException("gpt4free not found: " + gpt4freePath.getAbsolutePath());
}
if (!venvExe.exists()) {
throw new RuntimeException("venv not found: " + venvExe.getAbsolutePath());
}

this.summarizer = new ProcessSummarizer(venvExe, gpt4freePath);
this.text2text = new ProcessText2Text(venvExe, gpt4freePath);
}

/**
 * @return the {@link ProcessText2Text} instance held by this handler
 */
public ProcessText2Text getText2text() {
    return this.text2text;
}

public IModerator getModerator() {
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/link/locutus/discord/gpt/imps/GPTText2Text.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package link.locutus.discord.gpt.imps;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.ModelType;
import com.theokanning.openai.OpenAiService;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import link.locutus.discord.gpt.GPTUtil;
import link.locutus.discord.gpt.IEmbeddingDatabase;

import java.util.ArrayList;
import java.util.List;

/**
 * {@link IText2Text} implementation that sends the input as the prompt of a
 * completion request through {@link OpenAiService} and returns the text of
 * the returned choices.
 */
public class GPTText2Text implements IText2Text{
    private final EncodingRegistry registry;
    private final OpenAiService service;
    private final Encoding chatEncoder;
    private final ModelType model;
    private final IEmbeddingDatabase embeddings;

    /**
     * @param registry   encoding registry; used here to look up the encoding for the model
     * @param service    client used to issue the completion request
     * @param embeddings embedding database, stored on the instance
     */
    public GPTText2Text(EncodingRegistry registry, OpenAiService service, IEmbeddingDatabase embeddings) {
        this.registry = registry;
        this.service = service;
        this.embeddings = embeddings;
        // The model must be assigned before the encoder lookup that depends on it.
        this.model = ModelType.GPT_3_5_TURBO;
        this.chatEncoder = registry.getEncodingForModel(model);
    }

    /**
     * Requests a completion for {@code text} and joins the text of every
     * returned choice with a newline.
     *
     * @param text prompt passed to the completion request
     * @return the choice texts separated by {@code "\n"}; empty if there are no choices
     */
    @Override
    public String generate(String text) {
        // NOTE(review): GPT_3_5_TURBO is used with a plain completion request -
        // confirm the endpoint behind createCompletion accepts this model.
        CompletionResult response = service.createCompletion(
                CompletionRequest.builder()
                        .prompt(text)
                        .model(this.model.getName())
                        .echo(false)
                        .build());

        StringBuilder joined = new StringBuilder();
        boolean first = true;
        for (CompletionChoice option : response.getChoices()) {
            if (!first) {
                joined.append("\n");
            }
            joined.append(option.getText());
            first = false;
        }
        return joined.toString();
    }
}
5 changes: 5 additions & 0 deletions src/main/java/link/locutus/discord/gpt/imps/IText2Text.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package link.locutus.discord.gpt.imps;

/**
 * Contract for components that turn one piece of text into another.
 */
public interface IText2Text {
/**
 * Produces output text for the given input text.
 *
 * @param text the input text
 * @return the generated text
 */
String generate(String text);
}
16 changes: 13 additions & 3 deletions src/main/java/link/locutus/discord/gpt/imps/ProcessSummarizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ public class ProcessSummarizer implements ISummarizer {
private final String prompt;
private final int promptTokens;
private final ModelType model;
private final File venvExe;

public ProcessSummarizer(File file) {
public ProcessSummarizer(File venvExe, File file) {
this.venvExe = venvExe;
this.file = file;
this.prompt = """
Write a concise summary which preserves syntax, equations, arguments and constraints of the following:
Expand Down Expand Up @@ -50,7 +52,8 @@ public String summarizeChunk(String chunk) {

String encodedString = Base64.getEncoder().encodeToString(full.getBytes());
List<String> lines = new ArrayList<>();
ProcessBuilder pb = new ProcessBuilder("python", file.getAbsolutePath(), encodedString).redirectErrorStream(true);
String command = venvExe == null ? "python" : venvExe.getAbsolutePath();
ProcessBuilder pb = new ProcessBuilder(command, file.getAbsolutePath(), encodedString).redirectErrorStream(true);
try {
Process p = pb.start();
BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream()));
Expand All @@ -61,6 +64,13 @@ public String summarizeChunk(String chunk) {
} catch (IOException e) {
throw new RuntimeException(e);
}
return StringMan.join(lines, "\n");
String result = StringMan.join(lines, "\n");
if (result.contains("result:")) {
result = result.substring(result.indexOf("result:") + 7);
return result;
} else {
System.err.println(result);
throw new IllegalArgumentException("Unknown process result (see console)");
}
}
}
55 changes: 55 additions & 0 deletions src/main/java/link/locutus/discord/gpt/imps/ProcessText2Text.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package link.locutus.discord.gpt.imps;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.ModelType;
import com.theokanning.openai.OpenAiService;
import com.theokanning.openai.completion.CompletionChoice;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import link.locutus.discord.gpt.IEmbeddingDatabase;
import link.locutus.discord.util.StringMan;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;

/**
 * {@link IText2Text} implementation that delegates generation to an external
 * script. The script is launched as a child process with the Base64-encoded
 * input as its single argument, and its combined stdout/stderr is scanned for
 * a {@code result:} marker.
 */
public class ProcessText2Text implements IText2Text{
    /** Marker the child process prints immediately before its answer. */
    private static final String RESULT_MARKER = "result:";

    private final File file;
    private final File venvExe;

    /**
     * @param venvExe interpreter executable used to run the script; when {@code null}
     *                the command {@code "python"} is used instead
     * @param file    script passed to the interpreter as its first argument
     */
    public ProcessText2Text(File venvExe, File file) {
        this.venvExe = venvExe;
        this.file = file;
    }

    /**
     * Runs the script with {@code text} (Base64-encoded) as argument and returns
     * everything the process printed after the first {@code result:} marker.
     *
     * @param text input text handed to the script
     * @return the process output following the first occurrence of {@code result:}
     * @throws RuntimeException         if starting the process or reading its output fails
     *                                  (wraps the {@link IOException})
     * @throws IllegalArgumentException if the output contains no {@code result:} marker;
     *                                  the full output is written to {@code System.err}
     */
    @Override
    public String generate(String text) {
        // Encode with an explicit charset so the bytes handed to the script do not
        // depend on the JVM's platform default.
        String encodedString = Base64.getEncoder()
                .encodeToString(text.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        List<String> lines = new ArrayList<>();
        String command = venvExe == null ? "python" : venvExe.getAbsolutePath();
        // stderr is merged into stdout so error text ends up in the captured output.
        ProcessBuilder pb = new ProcessBuilder(command, file.getAbsolutePath(), encodedString).redirectErrorStream(true);
        try {
            Process p = pb.start();
            // NOTE(review): output is decoded with the platform default charset, as before -
            // confirm which encoding the script writes.
            // try-with-resources closes the reader (and the process stdout) on every path.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream()))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    lines.add(line);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        String result = StringMan.join(lines, "\n");
        int markerIndex = result.indexOf(RESULT_MARKER);
        if (markerIndex < 0) {
            System.err.println(result);
            throw new IllegalArgumentException("Unknown process result (see console)");
        }
        return result.substring(markerIndex + RESULT_MARKER.length());
    }
}
Loading

0 comments on commit 2ec647f

Please sign in to comment.