From 2e4dec365db2cff503edc35507b063ab6df74cfd Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 9 Apr 2024 02:16:19 +0800 Subject: [PATCH] Compatible with unique index conflicts (#3183) --- api/core/embedding/cached_embedding.py | 28 +++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4156368e562c2c..b7e0cc0c2b2ae6 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -41,7 +41,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) + model_schema = model_type_instance.get_model_schema(self._model_instance.model, + self._model_instance.credentials) max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 for i in range(0, len(embedding_queue_texts), max_chunks): @@ -61,17 +62,20 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: except Exception as e: logging.exception('Failed transform embedding: ', e) cache_embeddings = [] - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding - hash = helper.generate_text_hash(texts[i]) - if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) - embedding_cache.set_embedding(embedding) - db.session.add(embedding_cache) - cache_embeddings.append(hash) - db.session.commit() + try: + for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = embedding + hash = helper.generate_text_hash(texts[i]) + if hash not in cache_embeddings: + embedding_cache = Embedding(model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider) + embedding_cache.set_embedding(embedding) + db.session.add(embedding_cache) + cache_embeddings.append(hash) + db.session.commit() + except IntegrityError: + db.session.rollback() except Exception as ex: db.session.rollback() logger.error('Failed to embed documents: ', ex)