From 18038b529ca1efc6df3f9b26a337d72b4ff9e010 Mon Sep 17 00:00:00 2001 From: Andris Reinman Date: Sat, 23 Sep 2023 17:55:53 +0300 Subject: [PATCH] Store embeddings in ElasticSearch --- lib/es.js | 40 ++++++++++ package.json | 10 +-- workers/documents.js | 178 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 221 insertions(+), 7 deletions(-) diff --git a/lib/es.js b/lib/es.js index 139b320f..5f829f92 100644 --- a/lib/es.js +++ b/lib/es.js @@ -512,6 +512,36 @@ const threadTemplateSettings = { number_of_replicas: 1 }; +const embeddingsMappings = { + account: { + type: 'keyword', + ignore_above: 256 + }, + messageId: { + type: 'keyword', + ignore_above: 998 + }, + embeddings: { + type: 'dense_vector', + dims: 1536, + index: true, + similarity: 'cosine' + }, + chunk: { + type: 'text', + index: false + }, + chunkNr: { + type: 'integer' + }, + chunks: { + type: 'integer' + }, + created: { + type: 'date' + } +}; + /** * Function to either create or update an index to match the definition * @param {Object} client ElasticSearch client object @@ -704,6 +734,16 @@ module.exports = { await ensureThreadIndex(client, index); + try { + // try to create vector index, might fail with older ES versions (<8.8) + let embeddingsIndexResult = await ensureIndex(client, `${index}.embeddings`, { mappings: embeddingsMappings }); + if (embeddingsIndexResult) { + await redis.hset(`${REDIS_PREFIX}settings`, `embeddings:index`, JSON.stringify({ updated: Date.now() })); + } + } catch (err) { + await redis.hset(`${REDIS_PREFIX}settings`, `embeddings:index`, JSON.stringify({ error: err.message })); + } + return indexResult; }, diff --git a/package.json b/package.json index 3c991c3b..cc33d00b 100644 --- a/package.json +++ b/package.json @@ -50,14 +50,14 @@ "@hapi/vision": "7.0.3", "@phc/pbkdf2": "1.1.14", "@postalsys/certs": "1.0.5", - "@postalsys/email-ai-tools": "1.4.0", + "@postalsys/email-ai-tools": "1.4.1", "@postalsys/email-text-tools": "2.1.1", "@postalsys/hecks": "3.0.0-fork.3", "@postalsys/templates": "1.0.5", - "ace-builds": "1.27.0", + "ace-builds": "1.28.0", "base32.js": "0.1.0", "bull-arena": "4.0.1", - "bullmq": "4.11.2", + "bullmq": "4.11.4", "compare-versions": "6.1.0", "dotenv": "16.3.1", "encoding-japanese": "2.0.0", @@ -104,7 +104,7 @@ "speakeasy": "2.0.0", "startbootstrap-sb-admin-2": "3.3.7", "timezones-list": "3.0.2", - "undici": "5.25.1", + "undici": "5.25.2", "uuid": "9.0.1", "wild-config": "1.7.0", "xml2js": "0.6.2" @@ -112,7 +112,7 @@ "devDependencies": { "chai": "4.3.8", "eerawlog": "1.5.1", - "eslint": "8.49.0", + "eslint": "8.50.0", "eslint-config-nodemailer": "1.2.0", "eslint-config-prettier": "9.0.0", "grunt": "1.6.1", diff --git a/workers/documents.js b/workers/documents.js index c8b469c6..43005742 100644 --- a/workers/documents.js +++ b/workers/documents.js @@ -5,8 +5,10 @@ const { parentPort } = require('worker_threads'); const packageData = require('../package.json'); const logger = require('../lib/logger'); const { preProcess } = require('../lib/pre-process'); +const settings = require('../lib/settings'); +const crypto = require('crypto'); -const { readEnvValue, threadStats } = require('../lib/tools'); +const { readEnvValue, threadStats, getDuration } = require('../lib/tools'); const Bugsnag = require('@bugsnag/js'); if (readEnvValue('BUGSNAG_API_KEY')) { @@ -46,6 +48,48 @@ const { REDIS_PREFIX } = require('../lib/consts'); +const config = require('wild-config'); +config.service = config.service || {}; + +const DEFAULT_EENGINE_TIMEOUT = 10 * 1000; + +const EENGINE_TIMEOUT = getDuration(readEnvValue('EENGINE_TIMEOUT') || config.service.commandTimeout) || DEFAULT_EENGINE_TIMEOUT; + +let callQueue = new Map(); +let mids = 0; + +async function call(message, transferList) { + return new Promise((resolve, reject) => { + let mid = `${Date.now()}:${++mids}`; + + let ttl = Math.max(message.timeout || 0, EENGINE_TIMEOUT || 0); + let timer = setTimeout(() => { + let err = new Error('Timeout waiting for command response [T4]'); + err.statusCode = 504; + err.code = 'Timeout'; + err.ttl = ttl; + reject(err); + }, ttl); + + callQueue.set(mid, { resolve, reject, timer }); + + try { + parentPort.postMessage( + { + cmd: 'call', + mid, + message + }, + transferList + ); + } catch (err) { + clearTimeout(timer); + callQueue.delete(mid); + return reject(err); + } + }); +} + async function metrics(logger, key, method, ...args) { try { parentPort.postMessage({ @@ -320,6 +364,8 @@ const documentsWorker = new Worker( } // Skip embeddings if set for document store (nested dense cosine vectors can not be indexed, must be separate documents) + + let embeddings = messageData.embeddings; delete messageData.embeddings; let emailDocument = await preProcess.run(messageData); @@ -379,12 +425,140 @@ const documentsWorker = new Worker( indexResult }); + let storedEmbeddings; + + if ((await settings.get('documentStoreGenerateEmbeddings')) && messageData.messageId) { + let embeddingsQuery = { + bool: { + must: [ + { + term: { + account: job.data.account + } + }, + { + term: { + messageId: messageData.messageId + } + } + ] + } + }; + + let embeddingsIndex = `${index}.embeddings`; + + let existingResult; + + try { + existingResult = await client.search({ + index: embeddingsIndex, + size: 1, + query: embeddingsQuery, + _source: false + }); + if (!existingResult || !existingResult.hits) { + logger.error({ + msg: 'Failed to check for existing embeddings', + account: job.data.account, + messageId: messageData.messageId, + existingResult + }); + storedEmbeddings = false; + } + } catch (err) { + logger.error({ + msg: 'Failed to check for existing embeddings', + account: job.data.account, + messageId: messageData.messageId, + err + }); + storedEmbeddings = false; + } + + if (existingResult?.hits?.total?.value === 0) { + if (!embeddings) { + try { + embeddings = await call({ + cmd: 'generateEmbeddings', + data: { + message: { + headers: Object.keys(messageData.headers || {}).map(key => ({ + key, + value: [].concat(messageData.headers[key] || []) + })), + attachments: messageData.attachments, + from: messageData.from, + subject: messageData.subject, + text: messageData.text.plain, + html: messageData.text.html + } + }, + timeout: 2 * 60 * 1000 + }); + } catch (err) { + logger.error({ msg: 'Failed to fetch embeddings', account: job.data.account, messageId: messageData.messageId, err }); + storedEmbeddings = false; + } + } + + if (embeddings?.embeddings?.length) { + let messageIdHash = crypto.createHash('sha256').update(messageData.messageId).digest('hex'); + let dataset = embeddings.embeddings.map((entry, i) => ({ + account: job.data.account, + messageId: messageData.messageId, + embeddings: entry.embedding, + chunk: entry.chunk, + chunkNr: i, + chunks: embeddings.embeddings.length, + created: new Date() + })); + + const operations = dataset.flatMap(doc => [ + { index: { _index: embeddingsIndex, _id: `${job.data.account}:${messageIdHash}:${doc.chunkNr}` } }, + doc + ]); + + try { + const bulkResponse = await client.bulk({ refresh: true, operations }); + if (bulkResponse?.errors !== false) { + logger.error({ + msg: 'Failed to store embeddings', + account: job.data.account, + messageId: messageData.messageId, + bulkResponse + }); + storedEmbeddings = false; + } else { + logger.info({ + msg: 'Stored embeddings for a message', + messageId: messageData.messageId, + items: bulkResponse.items?.length + }); + storedEmbeddings = true; + } + } catch (err) { + logger.error({ + msg: 'Failed to store embeddings', + account: job.data.account, + messageId: messageData.messageId, + err + }); + storedEmbeddings = false; + } + } + } else { + logger.info({ msg: 'Skipped embeddings, already exist', account: job.data.account, messageId: messageData.messageId }); + storedEmbeddings = false; + } + } + return { index: indexResult._index, id: indexResult._id, documentVersion: indexResult._version, threadId: messageData.threadId, - result: indexResult.result + result: indexResult.result, + storedEmbeddings }; }