Skip to content

Commit

Permalink
chore: update wechat QR code
Browse files Browse the repository at this point in the history
  • Loading branch information
csunny committed Jan 29, 2024
2 parents 2d38063 + be67188 commit f484d65
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ DB-GPT是一个开源的数据域大模型框架。目的是构建大模型领
<img src="https://contrib.rocks/image?repo=eosphoros-ai/DB-GPT&max=200" />
</a>


## Licence

The MIT License (MIT)
Expand Down
Binary file modified assets/wechat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 18 additions & 1 deletion dbgpt/core/interface/operators/message_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,18 @@ def _filter_round_messages(
... ],
... ]
# Test end rounds is zero
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=1, keep_end_rounds=0
... )
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ]
... ]
Args:
messages_by_round (List[List[BaseMessage]]):
The messages grouped by round.
Expand All @@ -425,7 +437,12 @@ def _filter_round_messages(
"""
total_rounds = len(messages_by_round)
if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
if (
self._keep_start_rounds is not None
and self._keep_end_rounds is not None
and self._keep_start_rounds > 0
and self._keep_end_rounds > 0
):
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total
# rounds
Expand Down
67 changes: 67 additions & 0 deletions dbgpt/rag/embedding/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,70 @@ def embed_query(self, text: str) -> List[float]:
Embeddings for the text.
"""
return self.embed_documents([text])[0]


class JinaEmbeddings(BaseModel, Embeddings):
    """Text embeddings via the Jina AI API.

    Sends texts to ``https://api.jina.ai/v1/embeddings`` and returns their
    embedding vectors. Requires an API key; the default model is
    "jina-embeddings-v2-base-en".
    """

    api_url: Any  #: :meta private:
    session: Any  #: :meta private:
    api_key: str
    """Your API key for the Jina AI API."""
    model_name: str = "jina-embeddings-v2-base-en"
    """The name of the model to use for text embeddings. Defaults to
    "jina-embeddings-v2-base-en"."""

    def __init__(self, **kwargs):
        """Initialize the JinaEmbeddings.

        Raises:
            ValueError: If the ``requests`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            import requests
        except ImportError as exc:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            ) from exc
        self.api_url = "https://api.jina.ai/v1/embeddings"
        # Reuse a single session so the auth header is attached to every call.
        # "identity" requests an uncompressed response body.
        self.session = requests.Session()
        self.session.headers.update(
            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.

        Raises:
            RuntimeError: If the API response contains no embedding data.
        """
        # Call Jina AI Embedding API
        resp = self.session.post(  # type: ignore
            self.api_url, json={"input": texts, "model": self.model_name}
        ).json()
        if "data" not in resp:
            # Surface the API's error message; fall back to the raw payload so
            # a missing "detail" key cannot mask the real failure with a
            # KeyError.
            raise RuntimeError(resp.get("detail", str(resp)))

        embeddings = resp["data"]

        # The API may return results out of order; sort by input index so the
        # output aligns with ``texts``.
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore

        # Return just the embedding vectors.
        return [result["embedding"] for result in sorted_embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Compute an embedding for a single query text via the Jina AI API.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
5 changes: 3 additions & 2 deletions dbgpt/storage/vector_store/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
connections.connect(
host=self.uri or "127.0.0.1",
port=self.port or "19530",
alias="default"
# secure=self.secure,
username=self.username,
password=self.password,
alias="default",
)

def init_schema_and_load(self, vector_name, documents) -> List[str]:
Expand Down
9 changes: 6 additions & 3 deletions examples/awel/simple_dbschema_retriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from pydantic import BaseModel, Field

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk
Expand Down Expand Up @@ -38,17 +39,19 @@
"""


CFG = Config()


def _create_vector_connector():
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="vector_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

Expand Down
8 changes: 5 additions & 3 deletions examples/awel/simple_rag_embedding_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from pydantic import BaseModel, Field

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
Expand All @@ -25,18 +26,19 @@
}'
"""

CFG = Config()


def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="vector_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

Expand Down
7 changes: 5 additions & 2 deletions examples/awel/simple_rag_retriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from pydantic import BaseModel, Field

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
Expand Down Expand Up @@ -43,6 +44,8 @@
}'
"""

CFG = Config()


class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query")
Expand Down Expand Up @@ -83,7 +86,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

Expand Down

0 comments on commit f484d65

Please sign in to comment.