Skip to content

Commit

Permalink
Compatible with unique index conflicts (langgenius#3183)
Browse files: browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Apr 8, 2024
1 parent ca3e2e6 commit 2e4dec3
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions api/core/embedding/cached_embedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,7 +41,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
model_schema = model_type_instance.get_model_schema(self._model_instance.model,
self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
for i in range(0, len(embedding_queue_texts), max_chunks):
Expand All @@ -61,17 +62,20 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
except Exception as e:
logging.exception('Failed transform embedding: ', e)
cache_embeddings = []
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error('Failed to embed documents: ', ex)
Expand Down

0 comments on commit 2e4dec3

Please sign in to comment.