Skip to content

Commit

Permalink
fix: move prompt to config & add addon prompt (#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
mizy authored Dec 18, 2023
1 parent cc65f0e commit ce590c2
Show file tree
Hide file tree
Showing 17 changed files with 174 additions and 160 deletions.
2 changes: 1 addition & 1 deletion app/config/locale/en-US.ts
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ export default {
filePath: 'File Path',
importGraphSpace: 'Import Graph Space',
exportNGQLFilePath: 'Export NGQL File Path',
prompt: 'Prompt',
attachPrompt: 'Attach Prompt',
next: 'Next',
url: 'URL',
previous: 'Previous',
Expand Down
2 changes: 1 addition & 1 deletion app/config/locale/zh-CN.ts
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ export default {
filePath: '文件路径',
importGraphSpace: '导入图空间',
exportNGQLFilePath: '导出 NGQL 文件路径',
prompt: '提示',
attachPrompt: '附加提示',
next: '下一步',
previous: '上一步',
start: '开始',
Expand Down
9 changes: 3 additions & 6 deletions app/pages/Import/AIImport/Create.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { Button, Form, Input, Modal, Radio, Select, message } from 'antd';
import { observer } from 'mobx-react-lite';
import Icon from '@app/components/Icon';
import { useEffect, useMemo, useState } from 'react';
import { llmImportPrompt } from '@app/stores/llm';
import { getByteLength } from '@app/utils/function';
import { post } from '@app/utils/http';
import styles from './index.module.less';
Expand All @@ -30,7 +29,7 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {
form.resetFields();
form.setFieldsValue({
type: 'file',
promptTemplate: llmImportPrompt,
userPrompt: '',
});
setTokens(null);
}, [props.visible]);
Expand Down Expand Up @@ -63,11 +62,9 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {

const onConfirm = async () => {
const values = form.getFieldsValue();
const schema = await llm.getSpaceSchema(space);
post('/api/llm/import/job')({
type,
...values,
spaceSchemaString: schema,
}).then((res) => {
if (res.code === 0) {
message.success(intl.get('common.success'));
Expand Down Expand Up @@ -152,8 +149,8 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {
<Form.Item required label={intl.get('llm.exportNGQLFilePath')}>
<Input disabled value={llm.config.gqlPath} />
</Form.Item>
<Form.Item required={true} label={intl.get('llm.prompt')} name="promptTemplate">
<Input.TextArea style={{ height: 200 }} />
<Form.Item label={intl.get('llm.attachPrompt')} name="userPrompt">
<Input.TextArea />
</Form.Item>
</Form>

Expand Down
2 changes: 1 addition & 1 deletion app/pages/Import/AIImport/index.module.less
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

.tokenNum {
position: absolute;
top: 200px;
top: 180px;
right: 20px;
display: flex;
align-items: center;
Expand Down
2 changes: 1 addition & 1 deletion app/pages/LLMBot/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Chat() {
const newMessages = [
...messages,
{ role: 'user', content: currentInput },
{ role: 'assistant', content: '', status: 'pending' },
{ role: 'assistant', content: '', status: 'pending' }, // assistant can't be changed
];
llm.update({
currentInput: '',
Expand Down
81 changes: 28 additions & 53 deletions app/stores/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,13 @@ diff
---
Question:{query_str}
`;
export const llmImportPrompt = `As a knowledge graph AI importer, your task is to extract useful data from the following text:
----text
{text}
----
the knowledge graph has following schema and node name must be a real :
----graph schema
{spaceSchema}
----
Return the results directly, without explain and comment. The results should be in the following JSON format:
{
"nodes":[{ "name":string,"type":string,"props":object }],
"edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
}
The name of the nodes should be an actual object and a noun.
Result:
`;

export const docFinderPrompt = `The task is to identify the two best words from "{category_string}"\n that answer the question "{query_str}" for NebulaGraph database.The output should be a a comma-separated list of these two words.Don't explain anything.`;
export const docFinderPrompt = `The task is to identify the top2 effectively categories from
\`\`\`categories
{category_string}
\`\`\`
that answer the question "{query_str}" with the user's history ask is:"{history_str}" for NebulaGraph database.
The output should be a comma-separated list like "category1,category2" and don't explain anything`;

export const text2queryPrompt = `Assuming you are an NebulaGraph database AI assistant, your role is to assist users in crafting NGQL queries with NebulaGraph. You have access to the following details:
the user space schema is:
Expand Down Expand Up @@ -137,45 +124,28 @@ class LLM {
}

async getSpaceSchema(space: string) {
let finalPrompt: any = {
currentUsedSpaceName: space,
};
const finalPrompt = `The user's current graph space is: ${space} \n`;
if (this.config.features.includes('spaceSchema')) {
await schema.switchSpace(space);
await schema.getTagList();
await schema.getEdgeList();
const tagList = schema.tagList;
const edgeList = schema.edgeList;
finalPrompt = {
...finalPrompt,
vidType: schema.spaceVidType,
nodeTypes: tagList.map((item) => {
return {
type: item.name,
props: item.fields.map((item) => {
return {
name: item.Field,
dataType: item.Type,
nullable: (item as any).Null === 'YES',
};
}),
};
}),
edgeTypes: edgeList.map((item) => {
return {
type: item.name,
props: item.fields.map((item) => {
return {
name: item.Field,
dataType: item.Type,
nullable: (item as any).Null === 'YES',
};
}),
};
}),
};
let nodeSchemaString = '';
const edgeSchemaString = '';
tagList.forEach((item) => {
nodeSchemaString += `NodeType ${item.name} (${item.fields
.map((field) => `${field.Field}:${field.Type}`)
.join(' ')})\n`;
});
edgeList.forEach((item) => {
nodeSchemaString += `EdgeType ${item.name} (${item.fields
.map((field) => `${field.Field}:${field.Type}`)
.join(' ')})\n`;
});
return finalPrompt + nodeSchemaString + edgeSchemaString;
}
return JSON.stringify(finalPrompt);
return finalPrompt;
}

async getAgentPrompt(query_str: string, historyMessages: any, callback: (res: any) => void) {
Expand Down Expand Up @@ -270,17 +240,22 @@ class LLM {
let prompt = this.mode === 'text2cypher' ? matchPrompt : text2queryPrompt;
if (this.mode !== 'text2cypher') {
text = text.replaceAll('"', "'");
const history = historyMessages
.filter((item) => item.role === 'user')
.map((item) => item.content)
.join(',');
const docPrompt = docFinderPrompt
.replace('{category_string}', ngqlDoc.NGQLCategoryString)
.replace('{query_str}', text)
.replace('{history_str}', history)
.replace('{space_name}', rootStore.console.currentSpace);
console.log(docPrompt);
const res = (await ws.runChat({
req: {
stream: false,
max_tokens: 20,
max_tokens: 40,
top_p: 0.8,
messages: [
...historyMessages,
{
role: 'user',
content: docPrompt,
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"@vesoft-inc/force-graph": "2.0.7",
"@vesoft-inc/i18n": "^1.0.1",
"@vesoft-inc/icons": "^1.2.0",
"@vesoft-inc/nebula-explain-graph": "^1.0.2-beta.2",
"@vesoft-inc/nebula-explain-graph": "^1.0.2-beta.6",
"@vesoft-inc/veditor": "^4.4.12",
"antd": "^5.8.4",
"axios": "^0.23.0",
Expand Down
31 changes: 16 additions & 15 deletions server/api/studio/cmd/ai_importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ import (

type Config struct {
LLMJob struct {
Space string
File string
PromptTemplate string
Space string
File string
}
Auth struct {
Address string
Expand All @@ -36,8 +35,9 @@ type Config struct {
APIType db.APIType
ContextLengthLimit int
}
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
PromptTemplate string `json:",default="`
}

func main() {
Expand All @@ -55,10 +55,9 @@ func main() {
CacheNodes: make(map[string]llm.Node),
CacheEdges: make(map[string]map[string]llm.Edge),
LLMJob: &db.LLMJob{
JobID: fmt.Sprintf("%d", time.Now().UnixNano()),
Space: c.LLMJob.Space,
File: c.LLMJob.File,
PromptTemplate: c.LLMJob.PromptTemplate,
JobID: fmt.Sprintf("%d", time.Now().UnixNano()),
Space: c.LLMJob.Space,
File: c.LLMJob.File,
},
AuthData: &auth.AuthData{
Address: c.Auth.Address,
Expand All @@ -75,13 +74,15 @@ func main() {
}
studioConfig := config.Config{
LLM: struct {
GQLPath string `json:",default=./data/llm"`
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
GQLPath string `json:",default=./data/llm"`
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
PromptTemplate string `json:",default="`
}{
GQLPath: *outputPath,
GQLBatchSize: c.GQLBatchSize,
MaxBlockSize: c.MaxBlockSize,
GQLPath: *outputPath,
GQLBatchSize: c.GQLBatchSize,
MaxBlockSize: c.MaxBlockSize,
PromptTemplate: c.PromptTemplate,
},
}
studioConfig.InitConfig()
Expand Down
28 changes: 14 additions & 14 deletions server/api/studio/etc/ai-importer.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
LLMJob:
Space: "" #space name
File: "" #file path,support pdf,txt,json,csv and other text format
PromptTemplate: |
Auth:
Address: "127.0.0.1" # nebula graphd address
Port: 9669
Username: "root"
Password: "nebula"
LLMConfig:
URL: "" # openai api url
Key: "" # openai api key
APIType: "openai"
ContextLengthLimit: 1024
MaxBlockSize: 0 # max request block num
GQLBatchSize: 100 # max gql batch size
PromptTemplate: |
As a knowledge graph AI importer, your task is to extract useful data from the following text:
----text
{text}
Expand All @@ -18,16 +30,4 @@ LLMJob:
"edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
}
The name of the nodes should be an actual object and a noun.
Result:
Auth:
Address: "127.0.0.1" # nebula graphd address
Port: 9669
Username: "root"
Password: "nebula"
LLMConfig:
URL: "" # openai api url
Key: "" # openai api key
APIType: "openai"
ContextLengthLimit: 1024
MaxBlockSize: 0 # max request block num
GQLBatchSize: 100 # max gql batch size
Result:
19 changes: 18 additions & 1 deletion server/api/studio/etc/studio-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,21 @@ DB:
LLM:
GQLPath: "./data/llm"
GQLBatchSize: 100
MaxBlockSize: 0
MaxBlockSize: 0
PromptTemplate: |
As a knowledge graph AI importer, your task is to extract useful data from the following text:
```text
{text}
```
the knowledge graph has following schema and node name must be a real :
```graph-schema
{spaceSchema}
```
{userPrompt}
Return the results directly, without explain and comment. The results should be in the following JSON format:
{
"nodes":[{ "name":string,"type":string,"props":object }],
"edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
}
The name of the nodes should be an actual object and a noun.
Result:
27 changes: 24 additions & 3 deletions server/api/studio/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ import (

var configIns *Config

var PromptTemplate = `As a knowledge graph AI importer, your task is to extract useful data from the following text:` +
"```text\n" +
`{text}` +
"\n```\n" +
`the knowledge graph has following schema and node name must be a real :` +
"```graph-schema\n" +
`{spaceSchema}` +
"\n```\n" +
`{userPrompt}
Return the results directly, without explain and comment. The results should be in the following JSON format:
{
"nodes":[{ "name":string,"type":string,"props":object }],
"edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
}
The name of the nodes should be an actual object and a noun.
Result:`

func GetConfig() *Config {
return configIns
}
Expand Down Expand Up @@ -65,9 +82,10 @@ type Config struct {
}

LLM struct {
GQLPath string `json:",default=./data/llm"`
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
GQLPath string `json:",default=./data/llm"`
GQLBatchSize int `json:",default=100"`
MaxBlockSize int `json:",default=0"`
PromptTemplate string `json:",default="`
}
}

Expand Down Expand Up @@ -117,6 +135,9 @@ func (c *Config) Complete() {
if c.LLM.MaxBlockSize == 0 {
c.LLM.MaxBlockSize = 1024 * 1024 * 1024
}
if c.LLM.PromptTemplate == "" {
c.LLM.PromptTemplate = PromptTemplate
}
}

func (c *Config) InitConfig() error {
Expand Down
2 changes: 1 addition & 1 deletion server/api/studio/internal/model/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
"time"

"github.com/pkg/errors"
"github.com/vesoft-inc/nebula-studio/server/api/studio/internal/config"
"github.com/zeromicro/go-zero/core/logx"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/vesoft-inc/nebula-studio/server/api/studio/internal/config"
dbutil "github.com/vesoft-inc/nebula-studio/server/api/studio/pkg/db"
)

Expand Down
Loading

0 comments on commit ce590c2

Please sign in to comment.