From 9d9e2a21d2c455cd69eda484305c86e32d15bfcf Mon Sep 17 00:00:00 2001 From: Harsh4902 Date: Mon, 25 Sep 2023 16:09:28 +0530 Subject: [PATCH] Code added for embeddings in hyde-example --- JS/edgechains/examples/ormconfig.json | 12 + JS/edgechains/examples/package-lock.json | 140 +++++++++ JS/edgechains/examples/package.json | 2 +- JS/edgechains/examples/src/app.controller.ts | 3 +- .../examples/src/hydeExample/hydeExample.ts | 268 ++++++++++++------ JS/jsonnet-demo/app.js | 2 + 6 files changed, 334 insertions(+), 93 deletions(-) create mode 100644 JS/edgechains/examples/ormconfig.json diff --git a/JS/edgechains/examples/ormconfig.json b/JS/edgechains/examples/ormconfig.json new file mode 100644 index 000000000..04267dd89 --- /dev/null +++ b/JS/edgechains/examples/ormconfig.json @@ -0,0 +1,12 @@ +{ + "type": "postgres", + "host": "db.rmzqtepwnzoxgkkzjctt.supabase.co", + "port": 5432, + "username": "postgres", + "password": "xaX0MYcf1YiJlChK", + "database": "postgres", + "entities": ["dist/entities/**/*.js"], + "synchronize": false, + "logging": false + } + \ No newline at end of file diff --git a/JS/edgechains/examples/package-lock.json b/JS/edgechains/examples/package-lock.json index fc5d45838..fda84cc18 100644 --- a/JS/edgechains/examples/package-lock.json +++ b/JS/edgechains/examples/package-lock.json @@ -14,6 +14,7 @@ "@nestjs/core": "^10.0.0", "@nestjs/platform-express": "^10.0.0", "@nestjs/typeorm": "^10.0.0", + "pg": "^8.11.3", "reflect-metadata": "^0.1.13", "rxjs": "^7.8.1" }, @@ -3146,6 +3147,14 @@ "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==" }, + "node_modules/buffer-writer": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/buffer-writer/-/buffer-writer-2.0.0.tgz", + "integrity": "sha512-a7ZpuTZU1TRtnwyCNW3I5dc0wWNC3VR9S++Ewyk2HHZdrO3CQJqSpd+95Us590V6AL7JqUAH2IwZ/398PmNFgw==", + "engines": { + "node": ">=4" + } + }, "node_modules/bundle-name": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-3.0.0.tgz", @@ -7648,6 +7657,11 @@ "node": ">=8" } }, + "node_modules/packet-reader": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/packet-reader/-/packet-reader-1.0.0.tgz", + "integrity": "sha512-HAKu/fG3HpHFO0AA8WE8q2g+gBJaZ9MG7fcKk+IJPLTGAD6Psw4443l+9DGRbOIh3/aXr7Phy0TjilYivJo5XQ==" + }, "node_modules/parent-module": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", @@ -7788,6 +7802,89 @@ "node": ">=8" } }, + "node_modules/pg": { + "version": "8.11.3", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.3.tgz", + "integrity": "sha512-+9iuvG8QfaaUrrph+kpF24cXkH1YOOUeArRNYIxq1viYHZagBxrTno7cecY1Fa44tJeZvaoG+Djpkc3JwehN5g==", + "dependencies": { + "buffer-writer": "2.0.0", + "packet-reader": "1.0.0", + "pg-connection-string": "^2.6.2", + "pg-pool": "^3.6.1", + "pg-protocol": "^1.6.0", + "pg-types": "^2.1.0", + "pgpass": "1.x" + }, + "engines": { + "node": ">= 8.0.0" + }, + "optionalDependencies": { + "pg-cloudflare": "^1.1.1" + }, + "peerDependencies": { + "pg-native": ">=3.0.1" + }, + "peerDependenciesMeta": { + "pg-native": { + "optional": true + } + } + }, + "node_modules/pg-cloudflare": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/pg-cloudflare/-/pg-cloudflare-1.1.1.tgz", + "integrity": "sha512-xWPagP/4B6BgFO+EKz3JONXv3YDgvkbVrGw2mTo3D6tVDQRh1e7cqVGvyR3BE+eQgAvx1XhW/iEASj4/jCWl3Q==", + "optional": true + }, + "node_modules/pg-connection-string": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.6.2.tgz", + "integrity": "sha512-ch6OwaeaPYcova4kKZ15sbJ2hKb/VP48ZD2gE7i1J+L4MspCtBMAx8nMgz7bksc7IojCIIWuEhHibSMFH8m8oA==" + }, + "node_modules/pg-int8": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/pg-int8/-/pg-int8-1.0.1.tgz", + "integrity": "sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==", + "engines": { + "node": ">=4.0.0" + } + }, + "node_modules/pg-pool": { + "version": "3.6.1", + "resolved": "https://registry.npmjs.org/pg-pool/-/pg-pool-3.6.1.tgz", + "integrity": "sha512-jizsIzhkIitxCGfPRzJn1ZdcosIt3pz9Sh3V01fm1vZnbnCMgmGl5wvGGdNN2EL9Rmb0EcFoCkixH4Pu+sP9Og==", + "peerDependencies": { + "pg": ">=8.0" + } + }, + "node_modules/pg-protocol": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.6.0.tgz", + "integrity": "sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==" + }, + "node_modules/pg-types": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/pg-types/-/pg-types-2.2.0.tgz", + "integrity": "sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==", + "dependencies": { + "pg-int8": "1.0.1", + "postgres-array": "~2.0.0", + "postgres-bytea": "~1.0.0", + "postgres-date": "~1.0.4", + "postgres-interval": "^1.1.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/pgpass": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/pgpass/-/pgpass-1.0.5.tgz", + "integrity": "sha512-FdW9r/jQZhSeohs1Z3sI1yxFQNFvMcnmfuj4WBMUTxOrAyLMaTcE1aAMBiTlbMNaXvBCQuVi0R7hd8udDSP7ug==", + "dependencies": { + "split2": "^4.1.0" + } + }, "node_modules/picocolors": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", @@ -7888,6 +7985,41 @@ "node": ">=4" } }, + "node_modules/postgres-array": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/postgres-array/-/postgres-array-2.0.0.tgz", + "integrity": "sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==", + "engines": { + "node": ">=4" + } + }, + "node_modules/postgres-bytea": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/postgres-bytea/-/postgres-bytea-1.0.0.tgz", + "integrity": "sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postgres-date": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/postgres-date/-/postgres-date-1.0.7.tgz", + "integrity": "sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postgres-interval": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/postgres-interval/-/postgres-interval-1.2.0.tgz", + "integrity": "sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==", + "dependencies": { + "xtend": "^4.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/prelude-ls": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", @@ -8770,6 +8902,14 @@ "semver": "bin/semver.js" } }, + "node_modules/split2": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/split2/-/split2-4.2.0.tgz", + "integrity": "sha512-UcjcJOWknrNkF6PLX83qcHM6KHgVKNkV62Y8a5uYDVv9ydGQVwAHMKqHdJje1VTWpljG0WYpCDhrCdAOYH4TWg==", + "engines": { + "node": ">= 10.x" + } + }, "node_modules/sprintf-js": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.0.3.tgz", diff --git a/JS/edgechains/examples/package.json b/JS/edgechains/examples/package.json index fae419748..c18995574 100644 --- a/JS/edgechains/examples/package.json +++ b/JS/edgechains/examples/package.json @@ -5,7 +5,6 @@ "author": "", "private": true, "license": "UNLICENSED", - "scripts": { "build": "nest build", "clean": "rm -rf coverage dist .nyc_output", @@ -27,6 +26,7 @@ "@nestjs/core": "^10.0.0", "@nestjs/platform-express": "^10.0.0", "@nestjs/typeorm": "^10.0.0", + "pg": "^8.11.3", "reflect-metadata": "^0.1.13", "rxjs": "^7.8.1" }, diff --git a/JS/edgechains/examples/src/app.controller.ts b/JS/edgechains/examples/src/app.controller.ts index 760c428a9..0a77f824a 100644 --- a/JS/edgechains/examples/src/app.controller.ts +++ b/JS/edgechains/examples/src/app.controller.ts @@ -14,13 +14,14 @@ export class AppController { } @Post("/hyde-search") - hydeSearch(@Query() params:any, @Body() query:string){ + hydeSearch(@Query() params:any, @Body('query') query: string){ const arkRequest = { tableName : params.table, nameSpace : params.namespace, query : query, topK : params.topK } + hydeSearchAdaEmbedding(arkRequest); } diff --git a/JS/edgechains/examples/src/hydeExample/hydeExample.ts b/JS/edgechains/examples/src/hydeExample/hydeExample.ts index c43c0267b..8e106435d 100644 --- a/JS/edgechains/examples/src/hydeExample/hydeExample.ts +++ b/JS/edgechains/examples/src/hydeExample/hydeExample.ts @@ -1,19 +1,7 @@ import axios from 'axios'; import { Jsonnet } from "@hanazuki/node-jsonnet"; -import { error } from 'console'; - -interface WordEmbeddings { - // Define the properties of WordEmbeddings here -} - -interface PostgresWordEmbeddings { - // Define the properties of PostgresWordEmbeddings here -} - - -interface ChatMessage { - // Define the properties and methods of ChatMessage here -} +import * as path from 'path'; +import { createConnection,getManager } from 'typeorm'; const gpt3endpoint = { @@ -36,99 +24,96 @@ export async function hydeSearchAdaEmbedding(arkRequest){ // const jsonnet = new Jsonnet(); - // Configure PostgresEndpoint - const postgresEndpoint = { - tableName: table, - namespace: namespace, - // Add other properties here... - }; - + const promptPath = path.join(process.cwd(),'./src/hydeExample/prompts.jsonnet') + const hydePath = path.join(process.cwd(),'./src/hydeExample/hyde.jsonnet') // Load Jsonnet to extract args.. - // const promptLoader = await jsonnet - // .evaluateFile('./hyde.jsonnet'); + var promptLoader = await jsonnet + .evaluateFile(promptPath); - // // Getting ${summary} basePrompt - // const promptTemplate = JSON.parse(promptLoader).summary; - // console.log(promptTemplate); + // Getting ${summary} basePrompt + const promptTemplate = JSON.parse(promptLoader).summary; + console.log(promptTemplate); // // Getting the updated promptTemplate with query - // const hydeLoader = await jsonnet - // .extString('promptTemplate',promptTemplate) - // .extString('time',"") - // .extString('query',query) - // .evaluateFile("./hyde.jsonnet"); + var hydeLoader = await jsonnet + .extString('promptTemplate',promptTemplate) + .extString('time',"") + .extString('query',query) + .evaluateFile(hydePath); // Get concatenated prompt - const prompt = "Do not expand on abbreviations and leave them as is in the reply. Please generate 5 different responses in bullet points for the question.Please write a summary to answer the question in detail:\nQuestion: Hello How are You\nPassage:" - console.log(prompt); + const prompt = JSON.parse(hydeLoader).prompt; // Block and get the response from GPT3 - const gptResponse = await gptFn(prompt, arkRequest); + const gptResponse = await gptFn(prompt); // Chain 1 ==> Get Gpt3Response & split const gpt3Responses = gptResponse.split('\n'); // Chain 2 ==> Get Embeddings from OpenAI using Each Response - // const embeddingsListChain: Promise = Promise.all( - // gpt3Responses.map(async (resp) => { - // const embeddings = await ada002Embedding.embeddings(resp, arkRequest); - // return embeddings.getValues(); - // }) - // ); - + const embeddingsListChain: Promise = Promise.all( + gpt3Responses.map(async (resp) => { + const embedding = await embeddings(resp, arkRequest); + return embedding; + }) + ); // // Chain 4 ==> Calculate Mean from EmbeddingList & Pass to WordEmbedding Object - // const meanEmbedding = await meanFn(await embeddingsListChain, 1536); - // const wordEmbeddings = new WordEmbeddings(gptResponse, meanEmbedding); + const meanEmbedding = await meanFn(await embeddingsListChain, 1536); + const wordEmbeddings = { + id : gptResponse, + score : meanEmbedding + } // // Chain 5 ==> Query via EmbeddingChain - // const queryResult = await postgresEndpoint.query(wordEmbeddings, PostgresDistanceMetric.COSINE, topK, probes); + const queryResult = await dbQuery(wordEmbeddings, "<=>", topK, 20,table,namespace); // // Chain 6 ==> Create Prompt using Embeddings - // const retrievedDocs: string[] = []; + const retrievedDocs: string[] = []; + + for (const embeddings of queryResult) { + retrievedDocs.push( + `${embeddings.getRawText()}\n score:${embeddings.getScore()}\n filename:${embeddings.getFilename()}\n` + ); + } + + if (retrievedDocs.join('').length > 4096) { + retrievedDocs.length = 4096; + } - // for (const embeddings of queryResult) { - // retrievedDocs.push( - // `${embeddings.getRawText()}\n score:${embeddings.getScore()}\n filename:${embeddings.getFilename()}\n` - // ); - // } + const currentTime = new Date().toLocaleString(); + const formattedTime = currentTime; - // if (retrievedDocs.join('').length > 4096) { - // retrievedDocs.length = 4096; - // } + // System prompt + const ansPromptSystem = JSON.parse(promptLoader).ans_prompt_system + + hydeLoader = await jsonnet + .extString(promptTemplate,ansPromptSystem) + .extString('time',formattedTime) + .extString('qeury',retrievedDocs.join('')) + .evaluateFile(hydePath); - // const currentTime = new Date().toLocaleString(); - // const formattedTime = currentTime; + const finalPromptSystem = JSON.parse(hydeLoader).prompt; - // // System prompt - // const ansPromptSystem = jsonnet - // .extString() - // .evaluateFile('./hyde.jsonnet'); + // User prompt + const ansPromptUser = JSON.parse(promptLoader).ans_prompt_user - // promptLoader.get('ans_prompt_system'); - // hydeLoader.put('promptTemplate', new JsonnetArgs(DataType.STRING, ansPromptSystem)); - // hydeLoader.put('time', new JsonnetArgs(DataType.STRING, formattedTime)); - // hydeLoader.put('query', new JsonnetArgs(DataType.STRING, retrievedDocs.join(''))); - // await hydeLoader.loadOrReload(); - // const finalPromptSystem = hydeLoader.get('prompt'); - - // // User prompt - // const ansPromptUser = promptLoader.get('ans_prompt_user'); - // hydeLoader.put('promptTemplate', new JsonnetArgs(DataType.STRING, ansPromptUser)); - // hydeLoader.put('query', new JsonnetArgs(DataType.STRING, query)); - // await hydeLoader.loadOrReload(); - // const finalPromptUser = hydeLoader.get('prompt'); - - // const chatMessages: ChatMessage[] = [ - // { sender: 'system', message: finalPromptSystem }, - // { sender: 'user', message: finalPromptUser }, - // ]; - - // const finalAnswer = await gptFnChat(chatMessages, arkRequest); - - // const response = { - // wordEmbeddings: queryResult, - // finalAnswer: finalAnswer, - // }; + hydeLoader = await jsonnet + .extString(promptTemplate,ansPromptUser) + .extString('qeury',query) + .evaluateFile(hydePath); + const finalPromptUser = JSON.parse(hydeLoader).prompt;; + + const chatMessages = [ + { 'sender': 'system', 'message': finalPromptSystem }, + { 'sender': 'user', 'message': finalPromptUser }, + ]; + + const finalAnswer = await gptFnChat(chatMessages, arkRequest); + + const response = { + wordEmbeddings: queryResult, + finalAnswer: finalAnswer, + }; // return response; } catch (error) { @@ -138,9 +123,9 @@ export async function hydeSearchAdaEmbedding(arkRequest){ } } -async function gptFn(prompt:string, arkRequest): Promise { +async function gptFn(prompt:string) : Promise{ - const response = await axios.post('https://api.openai.com/v1/chat/completions', { + const responce = await axios.post('https://api.openai.com/v1/chat/completions', { 'model' : gpt3endpoint.model, 'messages' : [{ 'role' : gpt3endpoint.role, @@ -149,11 +134,15 @@ async function gptFn(prompt:string, arkRequest): Promise { 'temperature' : gpt3endpoint.temprature },{ headers : { - Authorization : 'Bearer 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' , + Authorization : 'Bearer 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' , 'content-type' : 'application/json' } }) - .then() + .then(function(response){ + return response.data.choices + } + + ) .catch(function (error) { if (error.response) { console.log('Server responded with status code:', error.response.status); @@ -164,10 +153,107 @@ async function gptFn(prompt:string, arkRequest): Promise { console.log('Error creating request:', error.message); } }); - - return response.data.choices; + return responce[0].message.content; } -async function gptFnChat(chatMessages:ChatMessage[],arkRequest) { +async function gptFnChat(chatMessages,arkRequest) { + const responce = await axios.post('https://api.openai.com/v1/chat/completions', { + 'model' : gpt3endpoint.model, + 'messages' : chatMessages, + 'temperature' : gpt3endpoint.temprature + },{ + headers : { + Authorization : 'Bearer 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' , + 'content-type' : 'application/json' + } + }) + .then(function(response){ + return response.data.choices + } + ) + .catch(function (error) { + if (error.response) { + console.log('Server responded with status code:', error.response.status); + console.log('Response data:', error.response.data); + } else if (error.request) { + console.log('No response received:', error.request); + } else { + console.log('Error creating request:', error.message); + } + }); +} + +async function embeddings(resp : string, arkRequest): Promise { + const responce = await axios.post('https://api.openai.com/v1/embeddings', { + "model" : "text-embedding-ada-002", + "input" : resp + },{ + headers : { + Authorization : 'Bearer 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' , + 'content-type' : 'application/json' + } + }) + .then(function(response){ + return response.data.data[0].embedding; + } + + ) + .catch(function (error) { + if (error.response) { + console.log('Server responded with status code:', error.response.status); + console.log('Response data:', error.response.data); + } else if (error.request) { + console.log('No response received:', error.request); + } else { + console.log('Error creating request:', error.message); + } + }); + + return responce; +} + +function meanFn(embeddingsList: Number[][], dimensions: number): Number[] { + const mean: Number[] = []; + + for (let i = 0; i < dimensions; i++) { + let sum = 0; + + for (let j = 0; j < embeddingsList.length; j++) { + sum = sum.valueOf() + embeddingsList[j][i].valueOf(); + } + + mean.push(sum / embeddingsList.length); + } + return mean; } + + +async function dbQuery(wordEmbeddings, metric, topK, probes,tableName,namespace:string) { + const embedding = JSON.stringify(wordEmbeddings.score) + console.log(embedding) + + const connection = await createConnection(); + const entityManager = getManager(); + try { + const query1 = `SET LOCAL ivfflat.probes = ${probes};` + await entityManager.query(query1); + + const query = ` + SELECT id, raw_text, namespace, filename, timestamp, + 1 - (embedding <=> "${embedding.toString()}") AS score + FROM ${tableName} + WHERE namespace = '${namespace}' + ORDER BY embedding ${metric} ${embedding} + LIMIT ${topK}; + `; + + const results = await entityManager.query(query); + console.log(results) + return results; + } catch (error) { + // Handle errors here + console.error(error); + throw error; + } +} \ No newline at end of file diff --git a/JS/jsonnet-demo/app.js b/JS/jsonnet-demo/app.js index 348d10932..21add6b28 100644 --- a/JS/jsonnet-demo/app.js +++ b/JS/jsonnet-demo/app.js @@ -18,6 +18,8 @@ const prompt = JSON.parse(hydeLoader).prompt; console.log(prompt); +const a = [0.0128617634,-0.0071556790580000005,0.01010643056,0.0042276748,-0.0194928978,0.0229963362,-0.0298838202,-0.0011673947880000003] +console.log(JSON.stringify(a)) // Evaluates a simple Jsonnet program into a JSON value // await jsonnet // // .extString("keepMaxToken","true")