-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c7f0caa
commit 51f4b92
Showing
2 changed files
with
179 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package com.edgechain; | ||
|
||
import com.edgechain.lib.endpoint.impl.OpenAiEndpoint; | ||
import com.edgechain.lib.jsonnet.JsonnetArgs; | ||
import com.edgechain.lib.jsonnet.JsonnetLoader; | ||
import com.edgechain.lib.jsonnet.enums.DataType; | ||
import com.edgechain.lib.jsonnet.impl.FileJsonnetLoader; | ||
import com.edgechain.lib.openai.request.ChatMessage; | ||
import com.edgechain.lib.request.ArkRequest; | ||
import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay; | ||
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; | ||
import org.json.JSONObject; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.boot.autoconfigure.SpringBootApplication; | ||
import org.springframework.boot.builder.SpringApplicationBuilder; | ||
import org.springframework.http.ResponseEntity; | ||
import org.springframework.web.bind.annotation.PostMapping; | ||
import org.springframework.web.bind.annotation.RequestMapping; | ||
import org.springframework.web.bind.annotation.RestController; | ||
|
||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Properties; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import static com.edgechain.lib.constants.EndpointConstants.OPENAI_CHAT_COMPLETION_API; | ||
|
||
@SpringBootApplication | ||
public class ChatBot { | ||
|
||
private static final String OPENAI_AUTH_KEY = ""; // YOUR OPENAI KEY | ||
private static final String OPENAI_ORG_ID = ""; // YOUR OPENAI KEY | ||
private static final String TMDB_TOKEN = ""; // TMDB_TOKEN | ||
private static OpenAiEndpoint gpt3Endpoint; | ||
private static JsonnetLoader loader = new FileJsonnetLoader("./chatbot-planner/planner.jsonnet"); // JSONNET FILE PATH | ||
|
||
|
||
public static void main(String[] args) { | ||
System.setProperty("server.port", "8080"); | ||
|
||
|
||
Properties properties = new Properties(); | ||
|
||
properties.setProperty("postgres.db.host", ""); | ||
properties.setProperty("postgres.db.username", ""); | ||
properties.setProperty("postgres.db.password", ""); | ||
|
||
|
||
new SpringApplicationBuilder(ChatBot.class).run(args); | ||
|
||
gpt3Endpoint = new OpenAiEndpoint( | ||
OPENAI_CHAT_COMPLETION_API, | ||
OPENAI_AUTH_KEY, | ||
OPENAI_ORG_ID, | ||
"gpt-3.5-turbo", | ||
"user", | ||
0.7, | ||
new ExponentialDelay(3, 5, 2, TimeUnit.SECONDS) | ||
); | ||
} | ||
|
||
|
||
@RestController | ||
@RequestMapping | ||
public class Conversation { | ||
|
||
Logger logger = LoggerFactory.getLogger(getClass()); | ||
private List<ChatMessage> messages; | ||
|
||
public Conversation() { | ||
messages = new ArrayList<>(); | ||
// messages.add(new ChatMessage("system", "You are a helpful, polite, old English assistant. Answer the user prompt with a bit of humor.")); | ||
// messages.add( | ||
// new ChatMessage("system", "You are planner bot that plans user prompts and returns " + | ||
// "back a plan on how to accomplish the request." + | ||
// "Try to answer in natural language.") | ||
// ); | ||
} | ||
|
||
@PostMapping("/planner") | ||
public ResponseEntity<String> ask(ArkRequest arkRequest) { | ||
String query = arkRequest.getBody().getString("prompt"); | ||
logger.info("user query from POSTMAN {}", query); | ||
JSONObject jsonObject = arkRequest.getBody(); | ||
// updateMessageList("user", prompt); | ||
|
||
|
||
loader.put("prompt", new JsonnetArgs(DataType.STRING, jsonObject.getString("prompt"))) | ||
.loadOrReload(); | ||
|
||
String prompt = loader.get("prompt"); | ||
|
||
logger.info("jsonnet prompt {}", prompt); | ||
|
||
messages.add(new ChatMessage("system", loader.get("apiPlannerSelector"))); | ||
updateMessageList("user", query); | ||
|
||
|
||
|
||
String response = new EdgeChain<>(gpt3Endpoint.chatCompletion(messages, "planner", loader, arkRequest)) | ||
.get() | ||
.getChoices() | ||
.get(0) | ||
.getMessage() | ||
.getContent(); | ||
|
||
logger.info("Response from OPENAI {}", response); | ||
|
||
|
||
updateMessageList("assistant", response); | ||
|
||
return ResponseEntity.ok(response); | ||
} | ||
|
||
private void updateMessageList(String role, String content) { | ||
messages.add(new ChatMessage(role, content)); | ||
|
||
if (messages.size() > 20) { | ||
messages.remove(0); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
local maxTokens = if(payload.keepMaxTokens == 'True') then payload.maxTokens else 10000; | ||
local api_planner_selector = ||| | ||
You are a planner that plans a sequence of RESTful API calls to assist with user queries against an API. | ||
You should: | ||
1) evaluate whether the user query can be solved by the API documentated below. If no, say CAN'T_HELP_RIGHT_NOW. | ||
2) if yes, generate a plan of API calls and say what they are doing step by step. | ||
You should only use API endpoints documented below ("actual endpoints you can use"). | ||
Some user queries can be resolved using a single endpoint, but some will require several endpoints. | ||
Your selected endpoints will be passed to an API planner that can look at the detailed documentation and make an execution plan. | ||
You must always follow this format: | ||
User query: the query from the user | ||
Thought: you should always describe your thoughts | ||
Result: a comma separated list of operation title potentially relevant for the query | ||
Here are some examples: | ||
Do not use APIs that are not listed here. | ||
Fake endpoints for examples: | ||
GET /person/{person_id}/movie_credits to Get the movie credits for a person. | ||
User query: tell me about today's wheather | ||
Thought: Sorry, this API's domain is Movie, not wheather. | ||
Result: NOT_APPICABLE | ||
User query: give me the latest movie directed by Wong Kar-Wai. | ||
Thought: GET /person/{person_id}/movie_creditsto get the latest movie directed by Wong Kar-Wai (id 12453) | ||
Result: The latest movie directed by Wong Kar-Wai is The Grandmaster (id 44865) | ||
Here are endpoints you can use. Do not reference any of the endpoints above. | ||
{endpoint} | ||
Begin! Remember to first describe your thoughts and then return the result list using Result or output NOT_APPLICABLE if the query can not be solved with the given endpoints: | ||
User query: {query} | ||
Thought: | ||
|||; | ||
local query = "User query:" + payload.prompt; | ||
local context = if(payload.keepContext == "true") then payload.context else ""; | ||
local prompt = std.join("\n", [query, api_planner_selector]); | ||
{ | ||
"apiPlannerSelector": api_planner_selector, | ||
"query": query, | ||
"context": context, | ||
"maxTokens": maxTokens, | ||
"prompt": if(std.length(prompt) > xtr.parseNum(maxTokens)) then std.substr(prompt, 0, xtr.parseNum(maxTokens)) else prompt | ||
} |