From e78acc441a2058cc95a78b23317a7ced2e23c425 Mon Sep 17 00:00:00 2001 From: xiuzhu Date: Wed, 24 Jan 2024 09:45:03 +0800 Subject: [PATCH] fix: Fix examples/awel get model_name from model_config --- examples/awel/simple_dbschema_retriever_example.py | 9 ++++++--- examples/awel/simple_rag_embedding_example.py | 7 ++++--- examples/awel/simple_rag_retriever_example.py | 6 ++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/awel/simple_dbschema_retriever_example.py b/examples/awel/simple_dbschema_retriever_example.py index e9119fdb4..72c2dfcd3 100644 --- a/examples/awel/simple_dbschema_retriever_example.py +++ b/examples/awel/simple_dbschema_retriever_example.py @@ -3,7 +3,8 @@ from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.rag.chunk import Chunk @@ -38,9 +39,11 @@ """ +CFG = Config() + + def _create_vector_connector(): """Create vector connector.""" - model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -48,7 +51,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), ) diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py index a49475acc..9c41e5c05 100644 --- a/examples/awel/simple_rag_embedding_example.py +++ b/examples/awel/simple_rag_embedding_example.py @@ -3,7 +3,8 @@ from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, MapOperator from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.knowledge.base import KnowledgeType @@ -25,10 +26,10 @@ }' """ +CFG = Config() def _create_vector_connector() -> VectorStoreConnector: """Create vector connector.""" - model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -36,7 +37,7 @@ def _create_vector_connector() -> VectorStoreConnector: persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), ) diff --git a/examples/awel/simple_rag_retriever_example.py b/examples/awel/simple_rag_retriever_example.py index 5470fca36..69068711b 100644 --- a/examples/awel/simple_rag_retriever_example.py +++ b/examples/awel/simple_rag_retriever_example.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.model.proxy import OpenAILLMClient from dbgpt.rag.chunk import Chunk @@ -43,6 +44,7 @@ }' """ +CFG = Config() class TriggerReqBody(BaseModel): query: str = Field(..., description="User query") @@ -83,7 +85,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), )