From 1c360d8157d2f4d0bdbfdd708adefa9917eb3338 Mon Sep 17 00:00:00 2001 From: Miao Zhuang <1060950782@163.com> Date: Wed, 13 Dec 2023 13:56:43 +0800 Subject: [PATCH] fix: ai import bug (#710) --- app/pages/Setting/index.tsx | 20 +++------ app/stores/llm.ts | 41 ++++++++++--------- app/utils/ngql.ts | 5 ++- server/api/studio/pkg/llm/importjob.go | 14 +++++-- .../api/studio/pkg/llm/transformer/openai.go | 5 +++ 5 files changed, 46 insertions(+), 39 deletions(-) diff --git a/app/pages/Setting/index.tsx b/app/pages/Setting/index.tsx index d3f66170..8bc93539 100644 --- a/app/pages/Setting/index.tsx +++ b/app/pages/Setting/index.tsx @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useState } from 'react'; +import { useCallback, useEffect } from 'react'; import { observer } from 'mobx-react-lite'; import { Button, Col, Form, Input, InputNumber, Row, Select, Switch, message } from 'antd'; import { useI18n } from '@vesoft-inc/i18n'; @@ -15,7 +15,6 @@ const Setting = observer(() => { const { global, llm } = useStore(); const { appSetting, saveAppSetting } = global; const [form] = useForm(); - const [apiType, setApiType] = useState('openai'); useEffect(() => { initForm(); }, []); @@ -23,7 +22,6 @@ const Setting = observer(() => { const initForm = async () => { await llm.fetchConfig(); form.setFieldsValue(llm.config); - setApiType(llm.config.apiType); }; const updateAppSetting = useCallback(async (param: Partial) => { @@ -109,13 +107,7 @@ const Setting = observer(() => {
{intl.get('setting.llmImportDesc')}
- OpenAI Aliyun @@ -126,11 +118,9 @@ const Setting = observer(() => { - {apiType === 'qwen' && ( - - - - )} + + + diff --git a/app/stores/llm.ts b/app/stores/llm.ts index a3fcc12c..0e4c1705 100644 --- a/app/stores/llm.ts +++ b/app/stores/llm.ts @@ -6,7 +6,9 @@ import * as ngqlDoc from '@app/utils/ngql'; import schema from './schema'; import rootStore from '.'; -export const matchPrompt = `Use NebulaGraph match knowledge to help me answer question. +export const matchPrompt = `I want you to be a NebulaGraph database asistant. +There are below document. +---- Use only the provided relationship types and properties in the schema. Do not use any other relationship types or properties that are not provided. Schema: @@ -25,6 +27,7 @@ diff --- > MATCH (p:person)-[:directed]->(m:movie) WHERE m.movie.name == 'The Godfather' > RETURN p.person.name; +--- Question:{query_str} `; export const llmImportPrompt = `As a knowledge graph AI importer, your task is to extract useful data from the following text: @@ -46,14 +49,12 @@ The name of the nodes should be an actual object and a noun. Result: `; -export const docFinderPrompt = `Assume your are doc finder,from the following graph database book categories: +export const docFinderPrompt = `Assuming you are a document navigator, within the following categories related to graph database books: "{category_string}" -user current space is: {space_name} -find top two useful categories to solve the question:"{query_str}", -don't explain, if you can't find, return "Sorry". -just return the two combined categories, separated by ',' is:`; +please identify the most two relevant categories that could address the question: "{query_str}"., +Please just return the two categories as a comma-separated list without any other word`; -export const text2queryPrompt = `Assume you are a NebulaGraph AI chat asistant to help user write NGQL. +export const text2queryPrompt = `Assume you are a NebulaGraph database AI chat asistant to help user write NGQL with NebulaGraph. You have access to the following information: the user space schema is: ---- @@ -227,10 +228,8 @@ class LLM { console.log(prompt); await ws.runChat({ req: { - temperature: 0.5, stream: true, max_tokens: 20, - messages: [ ...historyMessages, { @@ -274,18 +273,19 @@ class LLM { let prompt = matchPrompt; // default use text2cypher if (this.mode !== 'text2cypher') { text = text.replaceAll('"', "'"); + const docPrompt = docFinderPrompt + .replace('{category_string}', ngqlDoc.NGQLCategoryString) + .replace('{query_str}', text) + .replace('{space_name}', rootStore.console.currentSpace); + console.log(docPrompt); const res = (await ws.runChat({ req: { - temperature: 0.5, stream: false, max_tokens: 20, messages: [ { role: 'user', - content: docFinderPrompt - .replace('{category_string}', ngqlDoc.NGQLCategoryString) - .replace('{query_str}', text) - .replace('{space_name}', rootStore.console.currentSpace), + content: docPrompt, }, ], }, @@ -297,19 +297,19 @@ class LLM { .replaceAll(/\s|"|\\/g, '') .split(','); console.log('select doc url:', paths); - if (ngqlDoc.ngqlMap[paths[0]]) { - let doc = ngqlDoc.ngqlMap[paths[0]].content; + if (paths[0] !== 'sorry') { + let doc = ngqlDoc.ngqlMap[paths[0]]?.content; if (!doc) { doc = ''; } - const doc2 = ngqlDoc.ngqlMap[paths[1]].content; + const doc2 = ngqlDoc.ngqlMap[paths[1]]?.content; if (doc2) { - doc += doc2; + doc = + doc.slice(0, this.config.maxContextLength / 2) + `\n` + doc2.slice(0, this.config.maxContextLength / 2); } doc = doc.replaceAll(/\n\n+/g, ''); if (doc.length) { - console.log('docString:', doc); - prompt = text2queryPrompt.replace('{doc}', doc.slice(0, this.config.maxContextLength)); + prompt = text2queryPrompt.replace('{doc}', doc); } } } @@ -327,6 +327,7 @@ class LLM { schemaPrompt += `\nuser console ngql context: ${rootStore.console.currentGQL}`; } prompt = prompt.replace('{schema}', schemaPrompt); + console.log(prompt); return prompt; } diff --git a/app/utils/ngql.ts b/app/utils/ngql.ts index 90b1c8ff..24a6b675 100644 --- a/app/utils/ngql.ts +++ b/app/utils/ngql.ts @@ -14,7 +14,10 @@ export const ngqlDoc = (ngqlJson as { url: string; content: string; title: strin if (urlTransformerMap[item.title]) { item.title = urlTransformerMap[item.title]; } - item.title = item.title.replaceAll(' ', ''); + item.title = item.title + .split(' ') + .map((word) => word[0].toUpperCase() + word.slice(1)) + .join(''); item.content = item.content.replace(/nebula>/g, ''); return item; diff --git a/server/api/studio/pkg/llm/importjob.go b/server/api/studio/pkg/llm/importjob.go index cb75bf79..1d64982e 100644 --- a/server/api/studio/pkg/llm/importjob.go +++ b/server/api/studio/pkg/llm/importjob.go @@ -475,12 +475,16 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) { if valueStr != "" { valueStr += "," } - valueStr += fmt.Sprintf(`"%v"`, value) + if strings.Contains(strings.ToLower(field.DataType), "string") { + valueStr += fmt.Sprintf(`"%v"`, value) + } else { + valueStr += fmt.Sprintf(`%v`, value) + } } gql := fmt.Sprintf("INSERT VERTEX `%s` ({props}) VALUES \"%s\":({value});", typ, name) gql = strings.ReplaceAll(gql, "{props}", propsStr) - gql = strings.ReplaceAll(gql, "{value}", propsStr) + gql = strings.ReplaceAll(gql, "{value}", valueStr) gqls = append(gqls, gql) } @@ -508,7 +512,11 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) { if propsValue != "" { propsValue += "," } - propsValue += fmt.Sprintf("\"%v\"", value) + if strings.Contains(strings.ToLower(field.DataType), "string") { + propsValue += fmt.Sprintf(`"%v"`, value) + } else { + propsValue += fmt.Sprintf(`%v`, value) + } } gql := fmt.Sprintf("INSERT EDGE `%s` (%s) VALUES \"%s\"->\"%s\":(%s);", dst.EdgeType, propsName, dst.Src, dst.Dst, propsValue) gqls = append(gqls, gql) diff --git a/server/api/studio/pkg/llm/transformer/openai.go b/server/api/studio/pkg/llm/transformer/openai.go index 8c211f8c..7aae4cdd 100644 --- a/server/api/studio/pkg/llm/transformer/openai.go +++ b/server/api/studio/pkg/llm/transformer/openai.go @@ -20,6 +20,11 @@ type OpenAI struct { } func (o *OpenAI) HandleRequest(req map[string]any, config *db.LLMConfig) (*http.Request, error) { + configs := make(map[string]any) + err := json.Unmarshal([]byte(config.Config), &configs) + if err == nil { + req["model"] = configs["model"] + } // Convert the request parameters to a JSON string reqJSON, err := json.Marshal(req) if err != nil {