🌟 [destination-pinecone]: Add support for Pinecone Serverless #37601

Closed
@@ -24,28 +24,40 @@ class DestinationPinecone(Destination):
embedder: Embedder

def _init_indexer(self, config: ConfigModel):
self.embedder = create_from_config(config.embedding, config.processing)
self.indexer = PineconeIndexer(config.indexing, self.embedder.embedding_dimensions)

def write(
self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]
) -> Iterable[AirbyteMessage]:
config_model = ConfigModel.parse_obj(config)
self._init_indexer(config_model)
writer = Writer(
config_model.processing, self.indexer, self.embedder, batch_size=BATCH_SIZE, omit_raw_text=config_model.omit_raw_text
)
yield from writer.write(configured_catalog, input_messages)
try:
self.embedder = create_from_config(config.embedding, config.processing)
self.indexer = PineconeIndexer(config.indexing, self.embedder.embedding_dimensions)
except Exception as e:
return AirbyteConnectionStatus(status=Status.FAILED, message=str(e))

def write(self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]) -> Iterable[AirbyteMessage]:
try:
config_model = ConfigModel.parse_obj(config)
self._init_indexer(config_model)
writer = Writer(config_model.processing, self.indexer, self.embedder, batch_size=BATCH_SIZE, omit_raw_text=config_model.omit_raw_text)
yield from writer.write(configured_catalog, input_messages)
except Exception as e:
yield AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.ERROR, message=str(e)))

def check(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
parsed_config = ConfigModel.parse_obj(config)
self._init_indexer(parsed_config)
checks = [self.embedder.check(), self.indexer.check(), DocumentProcessor.check_config(parsed_config.processing)]
errors = [error for error in checks if error is not None]
if len(errors) > 0:
return AirbyteConnectionStatus(status=Status.FAILED, message="\n".join(errors))
else:
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
try:
parsed_config = ConfigModel.parse_obj(config)
init_status = self._init_indexer(parsed_config)
if init_status and init_status.status == Status.FAILED:
logger.error(f"Initialization failed with message: {init_status.message}")
return init_status # Return the failure status immediately if initialization fails

checks = [self.embedder.check(), self.indexer.check(), DocumentProcessor.check_config(parsed_config.processing)]
errors = [error for error in checks if error is not None]
if len(errors) > 0:
error_message = "\n".join(errors)
logger.error(f"Configuration check failed: {error_message}")
return AirbyteConnectionStatus(status=Status.FAILED, message=error_message)
else:
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
except Exception as e:
logger.error(f"Exception during configuration check: {str(e)}")
return AirbyteConnectionStatus(status=Status.FAILED, message=str(e))

def spec(self, *args: Any, **kwargs: Any) -> ConnectorSpecification:
return ConnectorSpecification(
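For context on the change above: check() now returns a failed AirbyteConnectionStatus when initialization or any sub-check fails, and write() turns unexpected exceptions into log messages instead of crashing the sync. Below is a minimal sketch of exercising the check() failure path. The "indexing" shape mirrors the invalid-config integration test further down; the "embedding" and "processing" shapes are assumptions based on the CDK's standard vector-DB config models, not something this PR defines.

import logging

from destination_pinecone.destination import DestinationPinecone

# Placeholder config for illustration only; field names outside "indexing" are assumed.
config = {
    "embedding": {"mode": "fake"},  # assumed CDK fake embedder so no OpenAI key is needed
    "processing": {"chunk_size": 1000, "text_fields": ["str_col"]},
    "indexing": {
        "mode": "pinecone",
        "pinecone_key": "not-a-real-key",  # deliberately invalid placeholder
        "index": "testdata",
        "pinecone_environment": "us-west1-gcp",
    },
}

status = DestinationPinecone().check(logging.getLogger("airbyte"), config)
# With a bad key the new code path logs the error and returns Status.FAILED
# rather than raising, so the platform reports a clean connection-check failure.
print(status.status, status.message)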
@@ -5,12 +5,16 @@
import uuid
from typing import Optional

import pinecone
from pinecone.grpc import PineconeGRPC
from pinecone import PineconeException
import urllib3
from airbyte_cdk.destinations.vector_db_based.document_processor import METADATA_RECORD_ID_FIELD, METADATA_STREAM_FIELD
from airbyte_cdk.destinations.vector_db_based.indexer import Indexer
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, create_stream_identifier, format_exception
from airbyte_cdk.models.airbyte_protocol import ConfiguredAirbyteCatalog, DestinationSyncMode
from airbyte_cdk.models import AirbyteConnectionStatus, Status


from destination_pinecone.config import PineconeIndexingModel

# large enough to speed up processing, small enough to not hit pinecone request limits
@@ -29,29 +33,51 @@ class PineconeIndexer(Indexer):

def __init__(self, config: PineconeIndexingModel, embedding_dimensions: int):
super().__init__(config)
pinecone.init(api_key=config.pinecone_key, environment=config.pinecone_environment, threaded=True)

self.pinecone_index = pinecone.GRPCIndex(config.index)
try:
self.pc = PineconeGRPC(api_key=config.pinecone_key, threaded=True)
except PineconeException:
# __init__ cannot return an AirbyteConnectionStatus, so re-raise and let check()/write() report the failure
raise

self.pinecone_index = self.pc.Index(config.index)
self.embedding_dimensions = embedding_dimensions
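
The constructor above is the heart of the SDK migration: module-level pinecone.init(...) plus pinecone.GRPCIndex(...) is replaced by a client object that owns the API key and no longer needs an environment, which is what makes serverless indexes reachable. A minimal sketch of the new initialization, assuming a v3+ pinecone client with the grpc extra installed; key and index name are placeholders:

from pinecone.grpc import PineconeGRPC

api_key = "YOUR_PINECONE_API_KEY"  # placeholder, not a real credential
index_name = "testdata"            # placeholder index name

# New style: the client object holds the key; works for both pod-based and serverless indexes.
pc = PineconeGRPC(api_key=api_key)
index = pc.Index(index_name)
print(index.describe_index_stats())

# Old style removed by this PR, shown for contrast:
#   pinecone.init(api_key=api_key, environment="us-west1-gcp", threaded=True)
#   index = pinecone.GRPCIndex(index_name)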

def determine_spec_type(self, index_name):
description = self.pc.describe_index(index_name)
spec_keys = description.get('spec', {})
if 'pod' in spec_keys:
return 'pod'
elif 'serverless' in spec_keys:
return 'serverless'
else:
raise ValueError("Unknown index specification type.")
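
determine_spec_type() keys off the shape of the describe_index() response. As a rough illustration (values invented; the real return value is an SDK object that also supports dict-style access), the two index families look roughly like the dicts below, and the detection logic only cares about which key appears under "spec":

# Illustrative shapes only; field values are made up.
serverless_description = {
    "name": "testdata",
    "dimension": 1536,
    "spec": {"serverless": {"cloud": "aws", "region": "us-east-1"}},
}
pod_description = {
    "name": "testdata",
    "dimension": 1536,
    "spec": {"pod": {"environment": "us-west1-gcp", "pod_type": "p1.x1"}},
}

def spec_type(description: dict) -> str:
    # Same branching as determine_spec_type() above, applied to a plain dict.
    spec_keys = description.get("spec", {})
    if "pod" in spec_keys:
        return "pod"
    if "serverless" in spec_keys:
        return "serverless"
    raise ValueError("Unknown index specification type.")

assert spec_type(serverless_description) == "serverless"
assert spec_type(pod_description) == "pod"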

def pre_sync(self, catalog: ConfiguredAirbyteCatalog):
index_description = pinecone.describe_index(self.config.index)
self._pod_type = index_description.pod_type
index_description = self.pc.describe_index(self.config.index)
self._pod_type = self.determine_spec_type(self.config.index)

for stream in catalog.streams:
stream_identifier = create_stream_identifier(stream.stream)
if stream.destination_sync_mode == DestinationSyncMode.overwrite:
self.delete_vectors(
filter={METADATA_STREAM_FIELD: create_stream_identifier(stream.stream)}, namespace=stream.stream.namespace
filter={METADATA_STREAM_FIELD: stream_identifier}, namespace=stream.stream.namespace, prefix=stream_identifier
)

def post_sync(self):
return []



def delete_vectors(self, filter, namespace=None):
def delete_vectors(self, filter, namespace=None, prefix=None):
if self._pod_type == "starter":
# Starter pod types have a maximum of 100000 rows
top_k = 10000
self.delete_by_metadata(filter, top_k, namespace)
elif self._pod_type == "serverless":
if prefix is None:
raise ValueError("Prefix is required for a serverless index.")
self.delete_by_prefix(prefix=prefix, namespace=namespace)
else:
# Pod spec
self.pinecone_index.delete(filter=filter, namespace=namespace)

def delete_by_metadata(self, filter, top_k, namespace=None):
@@ -66,6 +92,10 @@ def delete_by_metadata(self, filter, top_k, namespace=None):
self.pinecone_index.delete(ids=list(batch), namespace=namespace)
query_result = self.pinecone_index.query(vector=zero_vector, filter=filter, top_k=top_k, namespace=namespace)

def delete_by_prefix(self, prefix, namespace=None):
for ids in self.pinecone_index.list(prefix=prefix, namespace=namespace):
self.pinecone_index.delete(ids=ids, namespace=namespace)

def _truncate_metadata(self, metadata: dict) -> dict:
"""
Normalize metadata to ensure it is within the size limit and doesn't contain complex objects.
@@ -85,34 +115,46 @@ def _truncate_metadata(self, metadata: dict) -> dict:

return result

def index(self, document_chunks, namespace, stream):
def index(self, document_chunks, namespace, streamName):
pinecone_docs = []
for i in range(len(document_chunks)):
chunk = document_chunks[i]
metadata = self._truncate_metadata(chunk.metadata)
if chunk.page_content is not None:
metadata["text"] = chunk.page_content
pinecone_docs.append((str(uuid.uuid4()), chunk.embedding, metadata))
metadata["text"] = chunk.page_content
prefix = streamName
pinecone_docs.append((prefix + "#" + str(uuid.uuid4()), chunk.embedding, metadata))
serial_batches = create_chunks(pinecone_docs, batch_size=PINECONE_BATCH_SIZE * PARALLELISM_LIMIT)
for batch in serial_batches:
async_results = [
self.pinecone_index.upsert(vectors=ids_vectors_chunk, async_req=True, show_progress=False, namespace=namespace)
for ids_vectors_chunk in create_chunks(batch, batch_size=PINECONE_BATCH_SIZE)
]
async_results = []
for ids_vectors_chunk in create_chunks(batch, batch_size=PINECONE_BATCH_SIZE):
async_result = self.pinecone_index.upsert(vectors=ids_vectors_chunk, async_req=True, show_progress=False)
async_results.append(async_result)
# Wait for and retrieve responses (this raises in case of error)
[async_result.result() for async_result in async_results]

def delete(self, delete_ids, namespace, stream):
filter = {METADATA_RECORD_ID_FIELD: {"$in": delete_ids}}
if len(delete_ids) > 0:
self.delete_vectors(filter={METADATA_RECORD_ID_FIELD: {"$in": delete_ids}}, namespace=namespace)
if self._pod_type == "starter":
# Starter pod types have a maximum of 100000 rows
top_k = 10000
self.delete_by_metadata(filter=filter, top_k=top_k, namespace=namespace)
elif self._pod_type == "serverless":
self.pinecone_index.delete(ids=delete_ids, namespace=namespace)
else:
# Pod spec
self.pinecone_index.delete(filter=filter, namespace=namespace)
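
The serverless code paths above fit together through the vector ID format: index() now writes IDs as "<stream>#<uuid>", so every record of a stream shares a prefix, and delete_by_prefix() plus the serverless branch of delete() rely on that prefix because serverless indexes do not support metadata-filtered deletes. A rough sketch of the idea, assuming a v3+ client where Index.list() yields batches of IDs matching a prefix:

import uuid

# Hedged sketch; "index" stands for a connected index object such as pc.Index("testdata").
stream_name = "mystream"

# IDs written by index(): a stream-name prefix followed by "#" and a random UUID.
vector_id = f"{stream_name}#{uuid.uuid4()}"

def delete_stream_vectors(index, stream_name, namespace=None):
    # Serverless indexes reject delete-by-metadata-filter, so list IDs by prefix
    # (the SDK yields them in batches) and delete each batch explicitly.
    for id_batch in index.list(prefix=stream_name, namespace=namespace):
        index.delete(ids=id_batch, namespace=namespace)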

def check(self) -> Optional[str]:
try:
indexes = pinecone.list_indexes()
if self.config.index not in indexes:
indexes_response = self.pc.list_indexes()
index_names = [index['name'] for index in indexes_response.indexes]

if self.config.index not in index_names:
return f"Index {self.config.index} does not exist in environment {self.config.pinecone_environment}."

description = pinecone.describe_index(self.config.index)
description = self.pc.describe_index(self.config.index)
actual_dimension = int(description.dimension)
if actual_dimension != self.embedding_dimensions:
return f"Your embedding configuration will produce vectors with dimension {self.embedding_dimensions:d}, but your index is configured with dimension {actual_dimension:d}. Make sure embedding and indexing configurations match."
@@ -121,7 +163,7 @@ def check(self) -> Optional[str]:
if f"Failed to resolve 'controller.{self.config.pinecone_environment}.pinecone.io'" in str(e.reason):
return f"Failed to resolve environment, please check whether {self.config.pinecone_environment} is correct."

if isinstance(e, pinecone.exceptions.UnauthorizedException):
if isinstance(e, self.pc.exceptions.UnauthorizedException):
if e.body:
return e.body

@@ -4,8 +4,11 @@

import json
import logging
import time

import pinecone
from pinecone.grpc import PineconeGRPC
from pinecone import PineconeException
from pinecone import Pinecone as PineconeREST
from airbyte_cdk.destinations.vector_db_based.embedder import OPEN_AI_VECTOR_SIZE
from airbyte_cdk.destinations.vector_db_based.test_utils import BaseIntegrationTest
from airbyte_cdk.models import DestinationSyncMode, Status
@@ -16,21 +19,39 @@

class PineconeIntegrationTest(BaseIntegrationTest):
def _init_pinecone(self):
pinecone.init(api_key=self.config["indexing"]["pinecone_key"], environment=self.config["indexing"]["pinecone_environment"])
self.pinecone_index = pinecone.Index(self.config["indexing"]["index"])

self.pc = PineconeGRPC(api_key=self.config["indexing"]["pinecone_key"])
self.pinecone_index = self.pc.Index(self.config["indexing"]["index"])
self.pc_rest = PineconeREST(api_key=self.config["indexing"]["pinecone_key"])
self.pinecone_index_rest = self.pc_rest.Index(name=self.config["indexing"]["index"])

def _wait(self):
print("Waiting for Pinecone...", end='', flush=True)
for i in range(15):
time.sleep(1)
print(".", end='', flush=True)
print() # Move to the next line after the loop

def setUp(self):
with open("secrets/config.json", "r") as f:
self.config = json.loads(f.read())
self._init_pinecone()
# self.pinecone_index.delete(delete_all=True)

def tearDown(self):
# make sure pinecone is initialized correctly before cleaning up
self._wait()
# make sure pinecone is initialized correctly before cleaning up
self._init_pinecone()
self.pinecone_index.delete(delete_all=True)
try:
self.pinecone_index.delete(delete_all=True)
except PineconeException as e:
if "Namespace not found" not in str(e):
raise e
else:
print("Nothing to delete. No data in the index/namespace.")


def test_check_valid_config(self):
outcome = DestinationPinecone().check(logging.getLogger("airbyte"), self.config)
outcome = DestinationPinecone().check(logging.getLogger("airbyte"), self.config)
assert outcome.status == Status.SUCCEEDED

def test_check_invalid_config(self):
@@ -43,10 +64,11 @@ def test_check_invalid_config(self):
"mode": "pinecone",
"pinecone_key": "mykey",
"index": "testdata",
"pinecone_environment": "asia-southeast1-gcp-free",
"pinecone_environment": "us-west1-gcp",
},
},
)

assert outcome.status == Status.FAILED

def test_write(self):
@@ -57,14 +79,21 @@ def test_write(self):
# initial sync
destination = DestinationPinecone()
list(destination.write(self.config, catalog, [*first_record_chunk, first_state_message]))


self._wait()
assert self.pinecone_index.describe_index_stats().total_vector_count == 5

# incrementally update a doc
incremental_catalog = self._get_configured_catalog(DestinationSyncMode.append_dedup)
list(destination.write(self.config, incremental_catalog, [self._record("mystream", "Cats are nice", 2), first_state_message]))

self._wait()

result = self.pinecone_index.query(
vector=[0] * OPEN_AI_VECTOR_SIZE, top_k=10, filter={"_ab_record_id": "mystream_2"}, include_metadata=True
)

assert len(result.matches) == 1
assert (
result.matches[0].metadata["text"] == "str_col: Cats are nice"
@@ -73,6 +102,6 @@ def test_write(self):
# test langchain integration
embeddings = OpenAIEmbeddings(openai_api_key=self.config["embedding"]["openai_key"])
self._init_pinecone()
vector_store = Pinecone(self.pinecone_index, embeddings.embed_query, "text")
vector_store = Pinecone(self.pinecone_index_rest, embeddings.embed_query, "text")
result = vector_store.similarity_search("feline animals", 1)
assert result[0].metadata["_ab_record_id"] == "mystream_2"