Added pinecone, moved import stmts #42

Merged 1 commit on May 7, 2024
9 changes: 0 additions & 9 deletions src/beyondllm/retrievers/hybridRetriever.py
@@ -113,15 +113,6 @@ def initialize_from_data(self):
        )
        return vector_index, keyword_index

-    # def load_index(self):
-    #     vector_index = VectorStoreIndex(
-    #         self.data, embed_model=self.embed_model
-    #     )
-    #     keyword_index = SimpleKeywordTableIndex(
-    #         self.data, service_context=ServiceContext.from_defaults(llm=None, embed_model=None)
-    #     )
-    #     return vector_index, keyword_index
-
    def as_retriever(self):
        vector_index, keyword_index = self.load_index()
3 changes: 2 additions & 1 deletion src/beyondllm/vectordb/__init__.py
@@ -1 +1,2 @@
-from .chroma import ChromaVectorDb
+from .chroma import ChromaVectorDb
+from .pinecone import PineconeVectorDb
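With this export in place, both vector stores can be imported from the package namespace; a one-line illustration (not part of the diff):

from beyondllm.vectordb import ChromaVectorDb, PineconeVectorDb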
21 changes: 11 additions & 10 deletions src/beyondllm/vectordb/chroma.py
@@ -3,16 +3,6 @@
import warnings
warnings.filterwarnings("ignore")
import subprocess, sys
-try:
-    from llama_index.vector_stores.chroma import ChromaVectorStore
-except ImportError:
-    user_agree = input("The feature you're trying to use requires an additional library: llama_index.vector_stores.chroma. Would you like to install it now? [y/N]: ")
-    if user_agree.lower() == 'y':
-        subprocess.check_call([sys.executable, "-m", "pip", "install", "llama_index.vector_stores.chroma"])
-        from llama_index.vector_stores.chroma import ChromaVectorStore
-    else:
-        raise ImportError("The required 'llama_index.vector_stores.chroma' is not installed.")
-import chromadb

@dataclass
class ChromaVectorDb:
@@ -24,6 +14,17 @@ class ChromaVectorDb:
    persist_directory: str = ""

    def __post_init__(self):
+        try:
+            from llama_index.vector_stores.chroma import ChromaVectorStore
+        except ImportError:
+            user_agree = input("The feature you're trying to use requires an additional library: llama_index.vector_stores.chroma. Would you like to install it now? [y/N]: ")
+            if user_agree.lower() == 'y':
+                subprocess.check_call([sys.executable, "-m", "pip", "install", "llama_index.vector_stores.chroma"])
+                from llama_index.vector_stores.chroma import ChromaVectorStore
+            else:
+                raise ImportError("The required 'llama_index.vector_stores.chroma' is not installed.")
+        import chromadb
+
        if self.persist_directory == "" or self.persist_directory is None:
            self.chroma_client = chromadb.EphemeralClient()
        else:
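This hunk moves the optional-dependency prompt out of module import time and into __post_init__, so importing beyondllm.vectordb no longer prompts or fails when llama_index.vector_stores.chroma is absent; the dependency is only resolved when a ChromaVectorDb is actually constructed. A minimal sketch of the deferred-import pattern, using a hypothetical optional_import helper that is not part of this PR:

# Hypothetical helper illustrating the deferred-import pattern used above.
import importlib
import subprocess
import sys

def optional_import(module: str, pip_name: str):
    # Import `module`, offering to pip-install `pip_name` if it is missing.
    try:
        return importlib.import_module(module)
    except ImportError:
        answer = input(f"'{module}' is required for this feature. Install '{pip_name}' now? [y/N]: ")
        if answer.lower() == "y":
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])
            return importlib.import_module(module)
        raise ImportError(f"The required '{pip_name}' is not installed.")

# e.g. called inside __post_init__ rather than at module level:
# chroma = optional_import("llama_index.vector_stores.chroma", "llama_index.vector_stores.chroma")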
114 changes: 114 additions & 0 deletions src/beyondllm/vectordb/pinecone.py
@@ -0,0 +1,114 @@
from beyondllm.vectordb.base import VectorDb, VectorDbConfig
from dataclasses import dataclass, field
import warnings
warnings.filterwarnings("ignore")
import os, subprocess, sys

@dataclass
class PineconeVectorDb:
    """
    from beyondllm.vectordb import PineconeVectorDb

    # to load a pre-existing index
    vectordb = PineconeVectorDb(
        api_key="<your api key>",
        index_name="quickstart",
    )

    # or create a new index
    vectordb = PineconeVectorDb(
        create=True,
        api_key="<your api key>", index_name="quickstart",
        embedding_dim=1536, metric="cosine",
        spec="serverless", cloud="aws", region="us-east-1",
    )
    """

    api_key: str
    index_name: str
    create: bool = False
    embedding_dim: int = None
    metric: str = None
    spec: str = "serverless"
    cloud: str = None
    region: str = None
    environment: str = None  # Only required for pod-type Pinecone indexes
    pod_type: str = None  # Only required for pod-type Pinecone indexes
    replicas: int = None  # Only required for pod-type Pinecone indexes
    def __post_init__(self):
        try:
            from llama_index.vector_stores.pinecone import PineconeVectorStore
            from pinecone import Pinecone, ServerlessSpec
        except ImportError:
            user_agree = input(
                "The feature you're trying to use requires additional libraries: llama_index.vector_stores.pinecone, pinecone-client. Would you like to install them now? [y/N]: "
            )
            if user_agree.lower() == "y":
                subprocess.check_call([sys.executable, "-m", "pip", "install", "llama_index.vector_stores.pinecone"])
                subprocess.check_call([sys.executable, "-m", "pip", "install", "pinecone-client"])
                from llama_index.vector_stores.pinecone import PineconeVectorStore
                from pinecone import Pinecone, ServerlessSpec
            else:
                raise ImportError("The required 'llama_index.vector_stores.pinecone' and 'pinecone-client' are not installed.")

        pinecone_client = Pinecone(api_key=self.api_key)

        if not self.create:
            # Connect to an existing index.
            self.pinecone_index = pinecone_client.Index(self.index_name)
        else:
            if self.spec == "pod-based":
                from pinecone import PodSpec

                pinecone_client.create_index(
                    name=self.index_name,
                    dimension=self.embedding_dim,
                    metric=self.metric,
                    spec=PodSpec(environment=self.environment, pod_type=self.pod_type, replicas=self.replicas),
                )
            else:
                pinecone_client.create_index(
                    name=self.index_name,
                    dimension=self.embedding_dim,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.region),
                )

            self.pinecone_index = pinecone_client.Index(self.index_name)

        self.load()

    def load(self):
        try:
            from llama_index.vector_stores.pinecone import PineconeVectorStore
        except ImportError:
            raise ImportError(
                "PineconeVectorStore library is not installed. Please install it with ``pip install llama_index.vector_stores.pinecone``."
            )

        try:
            vector_store = PineconeVectorStore(pinecone_index=self.pinecone_index)
            self.client = vector_store
        except Exception as e:
            raise Exception(f"Failed to load the Pinecone Vectorstore: {e}")

        return self.client

    def add(self, *args, **kwargs):
        client = self.client
        return client.add(*args, **kwargs)

    def stores_text(self, *args, **kwargs):
        client = self.client
        return client.stores_text(*args, **kwargs)

    def is_embedding_query(self, *args, **kwargs):
        client = self.client
        return client.is_embedding_query(*args, **kwargs)

    def query(self, *args, **kwargs):
        client = self.client
        return client.query(*args, **kwargs)

    def load_from_kwargs(self, kwargs):
        embed_config = VectorDbConfig(**kwargs)
        self.config = embed_config
        self.load()
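For reference, a usage sketch of the new class as merged; the API key, index names, and the pod environment / pod type values are placeholders, and the pod-based variant relies on the environment field defined on the dataclass:

from beyondllm.vectordb import PineconeVectorDb

# Connect to an existing index (default: create=False).
vectordb = PineconeVectorDb(
    api_key="<your api key>",
    index_name="quickstart",
)

# Create a new serverless index, then connect to it.
vectordb = PineconeVectorDb(
    create=True,
    api_key="<your api key>",
    index_name="quickstart",
    embedding_dim=1536,  # must match the embedding model's output size
    metric="cosine",
    spec="serverless",
    cloud="aws",
    region="us-east-1",
)

# Create a pod-based index instead; these values are illustrative.
vectordb = PineconeVectorDb(
    create=True,
    api_key="<your api key>",
    index_name="quickstart-pods",
    embedding_dim=1536,
    metric="cosine",
    spec="pod-based",
    environment="us-east-1-aws",
    pod_type="p1.x1",
    replicas=1,
)

# The wrapped llama-index store is exposed as vectordb.client and via the
# add/query delegate methods defined on the class.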