
Google Java Format
github-actions committed Sep 11, 2023
1 parent 6425143 commit 9fa2052
Showing 14 changed files with 353 additions and 314 deletions.
88 changes: 46 additions & 42 deletions Examples/pinecone/PineconeExample.java
@@ -66,7 +66,7 @@ public static void main(String[] args) {

// Redis Configuration
properties.setProperty("redis.url", "");
properties.setProperty("redis.port","12285");
properties.setProperty("redis.port", "12285");
properties.setProperty("redis.username", "default");
properties.setProperty("redis.password", "");
properties.setProperty("redis.ttl", "3600");
@@ -99,34 +99,39 @@ public static void main(String[] args) {
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));

gpt3StreamEndpoint =
new OpenAiEndpoint(
OPENAI_CHAT_COMPLETION_API,
OPENAI_AUTH_KEY,
OPENAI_ORG_ID,
"gpt-3.5-turbo",
"user",
0.85,
true,
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));
new OpenAiEndpoint(
OPENAI_CHAT_COMPLETION_API,
OPENAI_AUTH_KEY,
OPENAI_ORG_ID,
"gpt-3.5-turbo",
"user",
0.85,
true,
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));

upsertPineconeEndpoint =
new PineconeEndpoint(
PINECONE_UPSERT_API,
PINECONE_AUTH_KEY,
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can pass empty string
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can
// pass empty string
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));

queryPineconeEndpoint =
new PineconeEndpoint(
PINECONE_QUERY_API, PINECONE_AUTH_KEY,
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can pass empty string
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));
PINECONE_QUERY_API,
PINECONE_AUTH_KEY,
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can
// pass empty string
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));

deletePineconeEndpoint =
new PineconeEndpoint(
PINECONE_DELETE, PINECONE_AUTH_KEY,
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can pass empty string
new FixedDelay(4, 5, TimeUnit.SECONDS));
PINECONE_DELETE,
PINECONE_AUTH_KEY,
"machine-learning", // Passing namespace; read more on Pinecone documentation. You can
// pass empty string
new FixedDelay(4, 5, TimeUnit.SECONDS));

contextEndpoint =
new RedisHistoryContextEndpoint(new ExponentialDelay(2, 2, 2, TimeUnit.SECONDS));
@@ -227,7 +232,6 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) {
String query = arkRequest.getBody().getString("query");
boolean stream = arkRequest.getBooleanHeader("stream");


// Get HistoryContext
HistoryContext historyContext = contextEndpoint.get(contextId);

@@ -272,18 +276,18 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) {

// Chain 5 ==> Pass the Prompt To Gpt3
EdgeChain<ChatCompletionResponse> gpt3Chain =
new EdgeChain<>(
gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest));
new EdgeChain<>(
gpt3Endpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest));

// Chain 6
EdgeChain<ChatCompletionResponse> historyUpdatedChain =
gpt3Chain.doOnNext(
chatResponse ->
contextEndpoint.put(
historyContext.getId(),
query
+ chatResponse.getChoices().get(0).getMessage().getContent()
+ historyContext.getResponse()));
gpt3Chain.doOnNext(
chatResponse ->
contextEndpoint.put(
historyContext.getId(),
query
+ chatResponse.getChoices().get(0).getMessage().getContent()
+ historyContext.getResponse()));

return historyUpdatedChain.getArkResponse();
}
@@ -293,8 +297,8 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) {

// Chain 5 ==> Pass the Prompt To Gpt3
EdgeChain<ChatCompletionResponse> gpt3Chain =
new EdgeChain<>(
gpt3StreamEndpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest));
new EdgeChain<>(
gpt3StreamEndpoint.chatCompletion(promptChain.get(), "RedisChatChain", arkRequest));

/* As the response is in stream, so we will use StringBuilder to append the response
and once GPT chain indicates that it is finished, we will save the following into Redis
Expand All @@ -305,19 +309,19 @@ public ArkResponse chatWithPinecone(ArkRequest arkRequest) {

// Chain 7
EdgeChain<ChatCompletionResponse> streamingOutputChain =
gpt3Chain.doOnNext(
chatResponse -> {
if (Objects.isNull(chatResponse.getChoices().get(0).getFinishReason())) {
stringBuilder.append(
chatResponse.getChoices().get(0).getMessage().getContent());
}
// Now the streaming response is ended. Save it to DB i.e. HistoryContext
else {
contextEndpoint.put(
historyContext.getId(),
query + stringBuilder + historyContext.getResponse());
}
});
gpt3Chain.doOnNext(
chatResponse -> {
if (Objects.isNull(chatResponse.getChoices().get(0).getFinishReason())) {
stringBuilder.append(
chatResponse.getChoices().get(0).getMessage().getContent());
}
// Now the streaming response is ended. Save it to DB i.e. HistoryContext
else {
contextEndpoint.put(
historyContext.getId(),
query + stringBuilder + historyContext.getResponse());
}
});

return streamingOutputChain.getArkStreamResponse();
}
41 changes: 25 additions & 16 deletions Examples/postgresql/PostgreSQLExample.java
@@ -37,7 +37,7 @@ public class PostgreSQLExample {

private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI AUTH KEY
private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI ORG ID

private static OpenAiEndpoint ada002Embedding;
private static OpenAiEndpoint gpt3Endpoint;
private static OpenAiEndpoint gpt3StreamEndpoint;
@@ -88,19 +88,20 @@ public static void main(String[] args) {
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));

gpt3StreamEndpoint =
new OpenAiEndpoint(
OPENAI_CHAT_COMPLETION_API,
OPENAI_AUTH_KEY,
OPENAI_ORG_ID,
"gpt-3.5-turbo",
"user",
0.85,
true,
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));
new OpenAiEndpoint(
OPENAI_CHAT_COMPLETION_API,
OPENAI_AUTH_KEY,
OPENAI_ORG_ID,
"gpt-3.5-turbo",
"user",
0.85,
true,
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS));

// Defining tablename and namespace...
postgresEndpoint =
new PostgresEndpoint("pg_vectors", "machine-learning", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS));
new PostgresEndpoint(
"pg_vectors", "machine-learning", new ExponentialDelay(5, 5, 2, TimeUnit.SECONDS));
contextEndpoint = new PostgreSQLHistoryContextEndpoint(new FixedDelay(2, 3, TimeUnit.SECONDS));
}

@@ -157,7 +158,14 @@ public void upsert(ArkRequest arkRequest) throws IOException {
String[] arr = pdfReader.readByChunkSize(file, 512);

PostgresRetrieval retrieval =
new PostgresRetrieval(arr, ada002Embedding, postgresEndpoint, 1536, filename, PostgresLanguage.ENGLISH, arkRequest);
new PostgresRetrieval(
arr,
ada002Embedding,
postgresEndpoint,
1536,
filename,
PostgresLanguage.ENGLISH,
arkRequest);

// retrieval.setBatchSize(50); // Modifying batchSize....(Default is 30)

@@ -250,8 +258,8 @@ public ArkResponse chat(ArkRequest arkRequest) {

// Chain 5 ==> Pass the Prompt To Gpt3
EdgeChain<ChatCompletionResponse> gpt3Chain =
new EdgeChain<>(
gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest));
new EdgeChain<>(
gpt3Endpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest));

// Chain 6
EdgeChain<ChatCompletionResponse> historyUpdatedChain =
@@ -271,8 +279,9 @@ public ArkResponse chat(ArkRequest arkRequest) {

// Chain 5 ==> Pass the Prompt To Gpt3
EdgeChain<ChatCompletionResponse> gpt3Chain =
new EdgeChain<>(
gpt3StreamEndpoint.chatCompletion(promptChain.get(), "PostgresChatChain", arkRequest));
new EdgeChain<>(
gpt3StreamEndpoint.chatCompletion(
promptChain.get(), "PostgresChatChain", arkRequest));

/* As the response is in stream, so we will use StringBuilder to append the response
and once GPT chain indicates that it is finished, we will save the following into Postgres
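For readers skimming the upsert(...) hunk above, the sketch below restates the reformatted PostgresRetrieval call with one comment per argument. It is illustrative only and not part of this commit; the argument roles are assumptions inferred from the surrounding example code, except that 1536 is the documented vector dimension of OpenAI's text-embedding-ada-002 model.

// Hypothetical annotated restatement; argument roles are assumptions, not confirmed API documentation.
PostgresRetrieval retrieval =
    new PostgresRetrieval(
        arr,                      // String[] of 512-character chunks from pdfReader.readByChunkSize(file, 512)
        ada002Embedding,          // OpenAiEndpoint used to embed each chunk
        postgresEndpoint,         // PostgresEndpoint bound to the "pg_vectors" table and "machine-learning" namespace
        1536,                     // vector dimension of text-embedding-ada-002
        filename,                 // name of the uploaded file, stored alongside the vectors
        PostgresLanguage.ENGLISH, // language used for Postgres full-text indexing
        arkRequest);
// retrieval.setBatchSize(50);   // optional override; the comment in the diff notes the default batch size is 30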
@@ -47,8 +47,9 @@ public class PostgresEndpoint extends Endpoint {
private List<String> metadataList;
private String documentDate;

/** RRF **/
/** RRF * */
private RRFWeight textWeight;

private RRFWeight similarityWeight;
private RRFWeight dateWeight;

@@ -238,7 +239,8 @@ public StringResponse createMetadataTable(String metadataTableName) {
return this.postgresService.createMetadataTable(this).blockingGet();
}

public List<StringResponse> upsert(List<WordEmbeddings> wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) {
public List<StringResponse> upsert(
List<WordEmbeddings> wordEmbeddingsList, String filename, PostgresLanguage postgresLanguage) {
this.wordEmbeddingsList = wordEmbeddingsList;
this.filename = filename;
this.postgresLanguage = postgresLanguage;
@@ -283,17 +285,30 @@ public Observable<List<PostgresWordEmbeddings>> query(
this.probes = probes;
return Observable.fromSingle(this.postgresService.query(this));
}

public Observable<List<PostgresWordEmbeddings>> query(
List<WordEmbeddings> wordEmbeddingsList, PostgresDistanceMetric metric, int topK, int probes) {
List<WordEmbeddings> wordEmbeddingsList,
PostgresDistanceMetric metric,
int topK,
int probes) {
this.wordEmbeddingsList = wordEmbeddingsList;
this.metric = metric;
this.probes = probes;
this.topK = topK;
return Observable.fromSingle(this.postgresService.query(this));
}

public Observable<List<PostgresWordEmbeddings>> queryRRF
(String metadataTable, WordEmbeddings wordEmbedding, RRFWeight textWeight, RRFWeight similarityWeight, RRFWeight dateWeight, OrderRRFBy orderRRFBy, String searchQuery,PostgresLanguage postgresLanguage, PostgresDistanceMetric metric, int topK) {
public Observable<List<PostgresWordEmbeddings>> queryRRF(
String metadataTable,
WordEmbeddings wordEmbedding,
RRFWeight textWeight,
RRFWeight similarityWeight,
RRFWeight dateWeight,
OrderRRFBy orderRRFBy,
String searchQuery,
PostgresLanguage postgresLanguage,
PostgresDistanceMetric metric,
int topK) {
this.metadataTableNames = List.of(metadataTable);
this.wordEmbedding = wordEmbedding;
this.textWeight = textWeight;
@@ -22,7 +22,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.postgresql.util.PGobject;
import org.slf4j.Logger;
@@ -291,12 +290,17 @@ public EdgeChain<List<PostgresWordEmbeddings>> queryRRF(PostgresEndpoint postgre
: null;
val.setScore(bigDecimal.doubleValue());

if(postgresEndpoint.getMetadataTableNames().get(0).contains("title")) {
val.setTitleMetadata(Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null);
}else {
val.setMetadata(Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null);
if (postgresEndpoint.getMetadataTableNames().get(0).contains("title")) {
val.setTitleMetadata(
Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null);
} else {
val.setMetadata(
Objects.nonNull(row.get("metadata")) ? (String) row.get("metadata") : null);
}
Date documentDate = Objects.nonNull(row.get("document_date")) ? (Date) row.get("document_date"): null;
Date documentDate =
Objects.nonNull(row.get("document_date"))
? (Date) row.get("document_date")
: null;
val.setDocumentDate(documentDate.toString());

wordEmbeddingsList.add(val);
@@ -6,41 +6,40 @@

public class RRFWeight {

private BaseWeight baseWeight = BaseWeight.W1_0;
private double fineTuneWeight = 0.5;

public RRFWeight() {}

public RRFWeight(BaseWeight baseWeight, double fineTuneWeight) {
this.baseWeight = baseWeight;
this.fineTuneWeight = fineTuneWeight;

if(fineTuneWeight < 0 || fineTuneWeight > 1.0)
throw new IllegalArgumentException("Weights must be between 0 and 1.");

}

public void setBaseWeight(BaseWeight baseWeight) {
this.baseWeight = baseWeight;
}

public void setFineTuneWeight(double fineTuneWeight) {
this.fineTuneWeight = fineTuneWeight;
}

public BaseWeight getBaseWeight() {
return baseWeight;
}

public double getFineTuneWeight() {
return fineTuneWeight;
}

@Override
public String toString() {
return new StringJoiner(", ", RRFWeight.class.getSimpleName() + "[", "]")
.add("baseWeight=" + baseWeight)
.add("fineTuneWeight=" + fineTuneWeight)
.toString();
}
private BaseWeight baseWeight = BaseWeight.W1_0;
private double fineTuneWeight = 0.5;

public RRFWeight() {}

public RRFWeight(BaseWeight baseWeight, double fineTuneWeight) {
this.baseWeight = baseWeight;
this.fineTuneWeight = fineTuneWeight;

if (fineTuneWeight < 0 || fineTuneWeight > 1.0)
throw new IllegalArgumentException("Weights must be between 0 and 1.");
}

public void setBaseWeight(BaseWeight baseWeight) {
this.baseWeight = baseWeight;
}

public void setFineTuneWeight(double fineTuneWeight) {
this.fineTuneWeight = fineTuneWeight;
}

public BaseWeight getBaseWeight() {
return baseWeight;
}

public double getFineTuneWeight() {
return fineTuneWeight;
}

@Override
public String toString() {
return new StringJoiner(", ", RRFWeight.class.getSimpleName() + "[", "]")
.add("baseWeight=" + baseWeight)
.add("fineTuneWeight=" + fineTuneWeight)
.toString();
}
}
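To show how the reformatted RRFWeight class above plugs into the queryRRF(...) signature from the PostgresEndpoint hunk earlier in this diff, here is a hedged usage sketch. It is not part of this commit: the metadata table name, the query embedding, and the orderRRFBy and metric variables are placeholders, since the OrderRRFBy and PostgresDistanceMetric constants are not shown in this excerpt.

// Hypothetical usage sketch; only RRFWeight, BaseWeight.W1_0, PostgresLanguage.ENGLISH and the
// queryRRF parameter order are taken from this diff. Everything else is assumed for illustration.
RRFWeight textWeight = new RRFWeight(BaseWeight.W1_0, 0.5);       // weight for the full-text rank
RRFWeight similarityWeight = new RRFWeight(BaseWeight.W1_0, 0.7); // weight for the vector-similarity rank
RRFWeight dateWeight = new RRFWeight(BaseWeight.W1_0, 0.3);       // weight for the document-date rank

Observable<List<PostgresWordEmbeddings>> results =
    postgresEndpoint.queryRRF(
        "title_metadata",        // assumed name of a metadata table
        queryEmbedding,          // WordEmbeddings of the user query (assumed to exist)
        textWeight,
        similarityWeight,
        dateWeight,
        orderRRFBy,              // some OrderRRFBy value; its constants are not visible in this excerpt
        "machine learning",      // full-text search string
        PostgresLanguage.ENGLISH,
        metric,                  // some PostgresDistanceMetric value; constants not visible here
        10);                     // topK

Each fineTuneWeight passed to the RRFWeight constructor must stay within [0, 1]; the reformatted constructor above enforces this with an IllegalArgumentException.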