diff --git a/examples/awel/simple_dbschema_retriever_example.py b/examples/awel/simple_dbschema_retriever_example.py index 744d7b763..1edbd2100 100644 --- a/examples/awel/simple_dbschema_retriever_example.py +++ b/examples/awel/simple_dbschema_retriever_example.py @@ -40,6 +40,7 @@ def _create_vector_connector(): """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -47,7 +48,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), ) diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py index a2a6f961b..1f1763af1 100644 --- a/examples/awel/simple_rag_embedding_example.py +++ b/examples/awel/simple_rag_embedding_example.py @@ -28,6 +28,7 @@ def _create_vector_connector() -> VectorStoreConnector: """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -35,7 +36,7 @@ def _create_vector_connector() -> VectorStoreConnector: persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), ) diff --git a/examples/awel/simple_rag_retriever_example.py b/examples/awel/simple_rag_retriever_example.py index 1d48d5478..4fda20281 100644 --- a/examples/awel/simple_rag_retriever_example.py +++ b/examples/awel/simple_rag_retriever_example.py @@ -75,6 +75,7 @@ def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict: def _create_vector_connector(): """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -82,7 +83,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), )