diff --git a/README.md b/README.md index 06690164..3e9226aa 100644 --- a/README.md +++ b/README.md @@ -101,14 +101,13 @@ and more... - [x] OpenAI - [ ] Anthropic -- [ ] Falcon-7B ### Embedding models - [X] OpenAI - [X] TensorFlow -- [ ] HuggingFace -- [ ] Cohere +- [X] HuggingFace +- [X] Cohere ### Application diff --git a/app/ui/src/utils/embeddings.ts b/app/ui/src/utils/embeddings.ts index d7866881..00fae48e 100644 --- a/app/ui/src/utils/embeddings.ts +++ b/app/ui/src/utils/embeddings.ts @@ -1,6 +1,7 @@ export const availableEmbeddingTypes = [ { value: "openai", label: "OpenAI" }, { value: "tensorflow", label: "Tensorflow" }, - // { value: "cohere", label: "Cohere"} + { value: "cohere", label: "Cohere"}, + { value: "huggingface-api", label: "HuggingFace (Inference)"} ]; \ No newline at end of file diff --git a/docker/imp.env b/docker/imp.env index f1041504..3f3f723d 100644 --- a/docker/imp.env +++ b/docker/imp.env @@ -4,4 +4,6 @@ OPENAI_API_KEY="" # DB_SECRET_KEY is used for jwt token generation please change it to your own secret key DB_SECRET_KEY="super-secret-key" # Cohere API key -> https://dashboard.cohere.ai/api-keys -# COHERE_API_KEY="" \ No newline at end of file +COHERE_API_KEY="" +# Huggingface Hub API key -> https://huggingface.co/settings/token +HUGGINGFACEHUB_API_KEY="" diff --git a/package.json b/package.json index 2b2da936..d90c20ec 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "dialoqbase", - "version": "0.0.3", + "version": "0.0.4", "description": "Create chatbots with ease", "scripts": { "ui:dev": "pnpm run --filter ui dev", diff --git a/server/package.json b/server/package.json index b240d003..a8f40160 100644 --- a/server/package.json +++ b/server/package.json @@ -28,6 +28,7 @@ "@fastify/multipart": "^7.6.0", "@fastify/sensible": "^5.0.0", "@fastify/static": "^6.10.2", + "@huggingface/inference": "1", "@prisma/client": "4.15.0", "@tensorflow-models/universal-sentence-encoder": "^1.3.3", "@tensorflow/tfjs-backend-cpu": 
"^4.7.0", diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index c197aa17..4d14b8ef 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -1,10 +1,12 @@ generator client { provider = "prisma-client-js" + previewFeatures = ["postgresqlExtensions"] } datasource db { provider = "postgresql" url = env("DATABASE_URL") + extensions = [pgvector(map: "vector", schema: "extensions")] } model Bot { diff --git a/server/src/app.ts b/server/src/app.ts index 22aa5697..08e7c24a 100644 --- a/server/src/app.ts +++ b/server/src/app.ts @@ -4,6 +4,7 @@ import { FastifyPluginAsync } from "fastify"; import cors from "@fastify/cors"; import fastifyStatic from "@fastify/static"; import fastifyMultipart from "@fastify/multipart"; + export type AppOptions = {} & Partial; const options: AppOptions = {}; diff --git a/server/src/queue/index.ts b/server/src/queue/index.ts index c6b81d9b..49ecc3ad 100644 --- a/server/src/queue/index.ts +++ b/server/src/queue/index.ts @@ -18,145 +18,147 @@ export const queue = new Queue("vector", process.env.DB_REDIS_URL!, {}); export const queueHandler = async (job: Job, done: DoneCallback) => { const data = job.data as QSource[]; - console.log("Processing queue" ); - - for (const source of data) { - try { - if (source.type.toLowerCase() === "website") { - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "PROCESSING", - }, - }); - - const loader = new CheerioWebBaseLoader(source.content!); - const docs = await loader.load(); - - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: 1000, - chunkOverlap: 200, - }); - const chunks = await textSplitter.splitDocuments(docs); - - await DialoqbaseVectorStore.fromDocuments( - chunks, - embeddings(source.embedding), - { - botId: source.botId, - sourceId: source.id, - }, - ); - - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "FINISHED", - isPending: false, - }, - }); - } else if 
(source.type.toLowerCase() === "text") { - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "PROCESSING", - }, - }); - - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: 1000, - chunkOverlap: 200, - }); - const chunks = await textSplitter.splitDocuments([ - { - pageContent: source.content!, - metadata: { - source: `text-${source.id}`, + console.log("Processing queue"); + try { + for (const source of data) { + try { + if (source.type.toLowerCase() === "website") { + await prisma.botSource.update({ + where: { + id: source.id, }, - }, - ]); - - await DialoqbaseVectorStore.fromDocuments( - chunks, - embeddings(source.embedding), - { - botId: source.botId, - sourceId: source.id, - }, - ); + data: { + status: "PROCESSING", + }, + }); + + const loader = new CheerioWebBaseLoader(source.content!); + const docs = await loader.load(); + + const textSplitter = new RecursiveCharacterTextSplitter({ + chunkSize: 1000, + chunkOverlap: 200, + }); + const chunks = await textSplitter.splitDocuments(docs); + + await DialoqbaseVectorStore.fromDocuments( + chunks, + embeddings(source.embedding), + { + botId: source.botId, + sourceId: source.id, + }, + ); - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "FINISHED", - isPending: false, - }, - }); - } else if (source.type.toLowerCase() === "pdf") { - console.log("loading pdf"); - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "PROCESSING", - }, - }); + await prisma.botSource.update({ + where: { + id: source.id, + }, + data: { + status: "FINISHED", + isPending: false, + }, + }); + } else if (source.type.toLowerCase() === "text") { + await prisma.botSource.update({ + where: { + id: source.id, + }, + data: { + status: "PROCESSING", + }, + }); + + const textSplitter = new RecursiveCharacterTextSplitter({ + chunkSize: 1000, + chunkOverlap: 200, + }); + const chunks = await textSplitter.splitDocuments([ + { + 
pageContent: source.content!, + metadata: { + source: `text-${source.id}`, + }, + }, + ]); + + await DialoqbaseVectorStore.fromDocuments( + chunks, + embeddings(source.embedding), + { + botId: source.botId, + sourceId: source.id, + }, + ); - const location = source.location!; - const loader = new PDFLoader(location); - const docs = await loader.load(); + await prisma.botSource.update({ + where: { + id: source.id, + }, + data: { + status: "FINISHED", + isPending: false, + }, + }); + } else if (source.type.toLowerCase() === "pdf") { + console.log("loading pdf"); + await prisma.botSource.update({ + where: { + id: source.id, + }, + data: { + status: "PROCESSING", + }, + }); + + const location = source.location!; + const loader = new PDFLoader(location); + const docs = await loader.load(); + + const textSplitter = new RecursiveCharacterTextSplitter({ + chunkSize: 1000, + chunkOverlap: 200, + }); + const chunks = await textSplitter.splitDocuments(docs); + + await DialoqbaseVectorStore.fromDocuments( + chunks, + embeddings(source.embedding), + { + botId: source.botId, + sourceId: source.id, + }, + ); - const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: 1000, - chunkOverlap: 200, - }); - const chunks = await textSplitter.splitDocuments(docs); - - await DialoqbaseVectorStore.fromDocuments( - chunks, - embeddings(source.embedding), - { - botId: source.botId, - sourceId: source.id, - }, - ); + await prisma.botSource.update({ + where: { + id: source.id, + }, + data: { + status: "FINISHED", + isPending: false, + }, + }); + } + } catch (e) { + console.log(e); await prisma.botSource.update({ where: { id: source.id, }, data: { - status: "FINISHED", + status: "FAILED", isPending: false, }, }); } - } catch (e) { - console.log(e); - - await prisma.botSource.update({ - where: { - id: source.id, - }, - data: { - status: "FAILED", - isPending: false, - }, - }); } + } catch (e) { + console.log(e); + } finally { + done(); } - - done(); }; queue.process(queueHandler); - 
diff --git a/server/src/routes/api/v1/bot/handlers/index.ts b/server/src/routes/api/v1/bot/handlers/index.ts index bd432fb9..7eed1240 100644 --- a/server/src/routes/api/v1/bot/handlers/index.ts +++ b/server/src/routes/api/v1/bot/handlers/index.ts @@ -129,9 +129,10 @@ export const createBotPDFHandler = async ( ...botSource, embedding: bot.embedding, }]); - return { + + return reply.status(200).send({ id: bot.id, - }; + }); } catch (err) { return reply.status(500).send({ message: "Upload failed due to internal server error", diff --git a/server/src/routes/api/v1/bot/handlers/schema.ts b/server/src/routes/api/v1/bot/handlers/schema.ts index 6d567363..a09e5dbc 100644 --- a/server/src/routes/api/v1/bot/handlers/schema.ts +++ b/server/src/routes/api/v1/bot/handlers/schema.ts @@ -17,7 +17,7 @@ export const createBotSchema: FastifySchema = { }, embedding: { type: "string", - enum: ["tensorflow", "openai", "cohere"], + enum: ["tensorflow", "openai", "cohere", "huggingface-api"], } }, }, diff --git a/server/src/routes/bot/root.ts b/server/src/routes/bot/root.ts index 1924c827..214350d3 100644 --- a/server/src/routes/bot/root.ts +++ b/server/src/routes/bot/root.ts @@ -11,6 +11,8 @@ const root: FastifyPluginAsync = async (fastify, _): Promise => { fastify.get("/:id", async (request, reply) => { return reply.sendFile('bot.html') }); + + }; export default root; diff --git a/server/src/utils/embeddings.ts b/server/src/utils/embeddings.ts index e7a62f25..227ea86f 100644 --- a/server/src/utils/embeddings.ts +++ b/server/src/utils/embeddings.ts @@ -1,6 +1,8 @@ import "@tensorflow/tfjs-backend-cpu"; import { TensorFlowEmbeddings } from "langchain/embeddings/tensorflow"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { CohereEmbeddings } from "langchain/embeddings/cohere"; +import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf"; export const embeddings = (embeddingsType: string) => { switch (embeddingsType) { @@ -8,7 +10,11 @@ export const 
embeddings = (embeddingsType: string) => { return new TensorFlowEmbeddings(); case "openai": return new OpenAIEmbeddings(); + case "cohere": + return new CohereEmbeddings(); + case "huggingface-api": + return new HuggingFaceInferenceEmbeddings(); default: return new OpenAIEmbeddings(); } -}; \ No newline at end of file +}; diff --git a/server/src/utils/store.ts b/server/src/utils/store.ts index 53f0e7ee..c360c949 100644 --- a/server/src/utils/store.ts +++ b/server/src/utils/store.ts @@ -42,14 +42,18 @@ export class DialoqbaseVectorStore extends VectorStore { // this is bad method right ? try { chunk.forEach(async (row) => { - await prisma.$executeRawUnsafe( - 'INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES ($1, $2, $3, $4, $5)', - row.content, - row.embedding, - row.metadata, - row.botId, - row.sourceId, - ); + const vector = `[${row.embedding.join(",")}]`; + + // console.log(vector.length) + // await prisma.$executeRawUnsafe( + // 'INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES ($1, $2, $3, $4, $5)', + // row.content, + // vector, + // row.metadata, + // row.botId, + // row.sourceId, + // ); + await prisma.$executeRaw`INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES (${row.content}, ${vector}::vector, ${row.metadata}, ${row.botId}, ${row.sourceId})` }); } catch (e) { console.log(e); @@ -120,7 +124,6 @@ export class DialoqbaseVectorStore extends VectorStore { }), resp.similarity, ]); - return result; } } diff --git a/server/src/utils/validate.ts b/server/src/utils/validate.ts index 507c3bc1..806974e4 100644 --- a/server/src/utils/validate.ts +++ b/server/src/utils/validate.ts @@ -6,15 +6,20 @@ export const embeddingsValidation = (embeddingsType: string) => { return process.env.OPENAI_API_KEY ? process.env.OPENAI_API_KEY.length > 0 : false; case "cohere": return process.env.COHERE_API_KEY ? 
process.env.COHERE_API_KEY.length > 0 : false; + case "huggingface-api": + return process.env.HUGGINGFACEHUB_API_KEY ? process.env.HUGGINGFACEHUB_API_KEY.length > 0 : false; + default: + return true } } - export const embeddingsValidationMessage = (embeddingsType: string) => { switch (embeddingsType) { case "openai": return "Please add OPENAI_API_KEY to your .env file" case "cohere": return "Please add COHERE_API_KEY to your .env file" + case "huggingface-api": + return "Please add HUGGINGFACEHUB_API_KEY to your .env file" } } \ No newline at end of file diff --git a/server/yarn.lock b/server/yarn.lock index a28da357..cf888306 100644 --- a/server/yarn.lock +++ b/server/yarn.lock @@ -395,6 +395,11 @@ resolved "https://registry.yarnpkg.com/@fortaine/fetch-event-source/-/fetch-event-source-3.0.6.tgz#b8552a2ca2c5202f5699b93a92be0188d422b06e" integrity sha512-621GAuLMvKtyZQ3IA6nlDWhV1V/7PGOTNIGLUifxt0KzM+dZIweJ6F3XvQF3QnqeNfS1N7WQ0Kil1Di/lhChEw== +"@huggingface/inference@1": + version "1.8.0" + resolved "https://registry.yarnpkg.com/@huggingface/inference/-/inference-1.8.0.tgz#5b1b22c790451e8051adefe9b7f725a7135c38a8" + integrity sha512-Dkh7PiyMf6TINRocQsdceiR5LcqJiUHgWjaBMRpCUOCbs+GZA122VH9q+wodoSptj6rIQf7wIwtDsof+/gd0WA== + "@ioredis/commands@^1.1.1": version "1.2.0" resolved "https://registry.yarnpkg.com/@ioredis/commands/-/commands-1.2.0.tgz#6d61b3097470af1fdbbe622795b8921d42018e11"