From de31b517820e92eaee2ead261d741a797493153b Mon Sep 17 00:00:00 2001 From: wencan Date: Sat, 6 Jul 2024 09:49:57 +0800 Subject: [PATCH] add baidu-qianfan llm (#14414) --- docs/docs/examples/llm/qianfan.ipynb | 195 ++++++++ .../llms/llama-index-llms-qianfan/.gitignore | 153 ++++++ .../llms/llama-index-llms-qianfan/BUILD | 3 + .../llms/llama-index-llms-qianfan/Makefile | 17 + .../llms/llama-index-llms-qianfan/README.md | 26 + .../llama_index/llms/qianfan/BUILD | 1 + .../llama_index/llms/qianfan/__init__.py | 3 + .../llama_index/llms/qianfan/base.py | 432 ++++++++++++++++ .../llama-index-llms-qianfan/pyproject.toml | 57 +++ .../llms/llama-index-llms-qianfan/tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_llms_qianfan.py | 464 ++++++++++++++++++ .../llama-index-utils-qianfan/.gitignore | 153 ++++++ .../llama-index-utils-qianfan/BUILD | 3 + .../llama-index-utils-qianfan/Makefile | 17 + .../llama-index-utils-qianfan/README.md | 3 + .../llama_index/utils/qianfan/BUILD | 1 + .../llama_index/utils/qianfan/__init__.py | 4 + .../llama_index/utils/qianfan/apis.py | 92 ++++ .../utils/qianfan/authorization.py | 112 +++++ .../llama_index/utils/qianfan/client.py | 179 +++++++ .../llama-index-utils-qianfan/pyproject.toml | 50 ++ .../llama-index-utils-qianfan/tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_apis.py | 91 ++++ .../tests/test_authorization.py | 20 + .../tests/test_client.py | 87 ++++ 27 files changed, 2165 insertions(+) create mode 100644 docs/docs/examples/llm/qianfan.ipynb create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/.gitignore create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/Makefile create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/README.md create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/base.py create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/pyproject.toml create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/tests/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/tests/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-qianfan/tests/test_llms_qianfan.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/.gitignore create mode 100644 llama-index-utils/llama-index-utils-qianfan/BUILD create mode 100644 llama-index-utils/llama-index-utils-qianfan/Makefile create mode 100644 llama-index-utils/llama-index-utils-qianfan/README.md create mode 100644 llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/BUILD create mode 100644 llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/__init__.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/apis.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/authorization.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/client.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/pyproject.toml create mode 100644 llama-index-utils/llama-index-utils-qianfan/tests/BUILD create mode 100644 
llama-index-utils/llama-index-utils-qianfan/tests/__init__.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/tests/test_apis.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/tests/test_authorization.py create mode 100644 llama-index-utils/llama-index-utils-qianfan/tests/test_client.py diff --git a/docs/docs/examples/llm/qianfan.ipynb b/docs/docs/examples/llm/qianfan.ipynb new file mode 100644 index 0000000000000..e7b46fa3bbf74 --- /dev/null +++ b/docs/docs/examples/llm/qianfan.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "introduction", + "metadata": {}, + "source": [ + "# Client of Baidu Intelligent Cloud's Qianfan LLM Platform\n", + "\n", + "Baidu Intelligent Cloud's Qianfan LLM Platform offers API services for all Baidu LLMs, such as ERNIE-3.5-8K and ERNIE-4.0-8K. It also provides a small number of open-source LLMs like Llama-2-70b-chat.\n", + "\n", + "Before using the chat client, you need to activate the LLM service on the Qianfan LLM Platform console's [online service](https://console.bce.baidu.com/qianfan/ais/console/onlineService) page. Then, Generate an Access Key and a Secret Key in the [Security Authentication](https://console.bce.baidu.com/iam/#/iam/accesslist) page of the console." + ] + }, + { + "cell_type": "markdown", + "id": "installation", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "Install the necessary package:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "installation-code", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-qianfan" + ] + }, + { + "cell_type": "markdown", + "id": "initialization", + "metadata": {}, + "source": [ + "## Initialization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "initialization-code", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.qianfan import Qianfan\n", + "import asyncio\n", + "\n", + "access_key = \"XXX\"\n", + "secret_key = \"XXX\"\n", + "model_name = \"ERNIE-Speed-8K\"\n", + "endpoint_url = \"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed\"\n", + "context_window = 8192\n", + "llm = Qianfan(access_key, secret_key, model_name, endpoint_url, context_window)" + ] + }, + { + "cell_type": "markdown", + "id": "sync-chat", + "metadata": {}, + "source": [ + "## Synchronous Chat\n", + "\n", + "Generate a chat response synchronously using the `chat` method:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sync-chat-code", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "messages = [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke.\"),\n", + "]\n", + "chat_response = llm.chat(messages)\n", + "print(chat_response.message.content)" + ] + }, + { + "cell_type": "markdown", + "id": "sync-stream-chat", + "metadata": {}, + "source": [ + "## Synchronous Stream Chat\n", + "\n", + "Generate a streaming chat response synchronously using the `stream_chat` method:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sync-stream-chat-code", + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " ChatMessage(role=\"system\", content=\"You are a helpful assistant.\"),\n", + " ChatMessage(role=\"user\", content=\"Tell me a story.\"),\n", + "]\n", + "content = \"\"\n", + "for chat_response in llm.stream_chat(messages):\n", + " content += chat_response.delta\n", + " 
print(chat_response.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "async-chat", + "metadata": {}, + "source": [ + "## Asynchronous Chat\n", + "\n", + "Generate a chat response asynchronously using the `achat` method:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "async-chat-code", + "metadata": {}, + "outputs": [], + "source": [ + "async def async_chat():\n", + " messages = [\n", + " ChatMessage(role=\"user\", content=\"Tell me an async joke.\"),\n", + " ]\n", + " chat_response = await llm.achat(messages)\n", + " print(chat_response.message.content)\n", + "\n", + "\n", + "asyncio.run(async_chat())" + ] + }, + { + "cell_type": "markdown", + "id": "async-stream-chat", + "metadata": {}, + "source": [ + "## Asynchronous Stream Chat\n", + "\n", + "Generate a streaming chat response asynchronously using the `astream_chat` method:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "async-stream-chat-code", + "metadata": {}, + "outputs": [], + "source": [ + "async def async_stream_chat():\n", + " messages = [\n", + " ChatMessage(role=\"system\", content=\"You are a helpful assistant.\"),\n", + " ChatMessage(role=\"user\", content=\"Tell me an async story.\"),\n", + " ]\n", + " content = \"\"\n", + " response = await llm.astream_chat(messages)\n", + " async for chat_response in response:\n", + " content += chat_response.delta\n", + " print(chat_response.delta, end=\"\")\n", + "\n", + "\n", + "asyncio.run(async_stream_chat())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/.gitignore b/llama-index-integrations/llms/llama-index-llms-qianfan/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/BUILD b/llama-index-integrations/llms/llama-index-llms-qianfan/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/Makefile b/llama-index-integrations/llms/llama-index-llms-qianfan/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/README.md b/llama-index-integrations/llms/llama-index-llms-qianfan/README.md new file mode 100644 index 0000000000000..078f813abc1ff --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/README.md @@ -0,0 +1,26 @@ +# LlamaIndex Llms Integration: Baidu Qianfan + +Baidu Intelligent Cloud's Qianfan LLM Platform offers API services for all Baidu LLMs, such as ERNIE-3.5-8K and ERNIE-4.0-8K. It also provides a small number of open-source LLMs like Llama-2-70b-chat. + +Before using the chat client, you need to activate the LLM service on the Qianfan LLM Platform console's [online service](https://console.bce.baidu.com/qianfan/ais/console/onlineService) page. Then, Generate an Access Key and a Secret Key in the [Security Authentication](https://console.bce.baidu.com/iam/#/iam/accesslist) page of the console. 
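+
+Once the keys are available, the integration can be used in two ways: by passing an explicit chat endpoint URL (see Initialization below), or by model name via `Qianfan.from_model_name`, which looks the endpoint up from the account's service list. A minimal sketch of the second form, with placeholder key values and `ERNIE-Speed-8K` as an assumed model name:
+
+```python
+from llama_index.llms.qianfan import Qianfan
+
+# Placeholder credentials; use the Access Key and Secret Key generated above.
+llm = Qianfan.from_model_name(
+    access_key="XXX",
+    secret_key="XXX",
+    model_name="ERNIE-Speed-8K",
+    context_window=8192,
+)
+```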
+ +## Installation + +Install the necessary package: + +``` +pip install llama-index-llms-qianfan +``` + +## Initialization + +```python +from llama_index.llms.qianfan import Qianfan + +access_key = "XXX" +secret_key = "XXX" +model_name = "ERNIE-Speed-8K" +endpoint_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" +context_window = 8192 +llm = Qianfan(access_key, secret_key, model_name, endpoint_url, context_window) +``` diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/BUILD b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/__init__.py b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/__init__.py new file mode 100644 index 0000000000000..542f79841cf6e --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/__init__.py @@ -0,0 +1,3 @@ +from llama_index.llms.qianfan.base import Qianfan + +__all__ = ["Qianfan"] diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/base.py b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/base.py new file mode 100644 index 0000000000000..e564fb61540ee --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/llama_index/llms/qianfan/base.py @@ -0,0 +1,432 @@ +from typing import Any, Dict, Sequence, List, Literal, Iterable, AsyncIterable + +from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + ChatResponseAsyncGen, + LLMMetadata, + MessageRole, + CompletionResponse, + CompletionResponseGen, + CompletionResponseAsyncGen, +) +from llama_index.core.llms.callbacks import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.core.base.llms.generic_utils import ( + chat_to_completion_decorator, + achat_to_completion_decorator, + stream_chat_to_completion_decorator, + astream_chat_to_completion_decorator, +) +from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.core.llms.custom import CustomLLM +from llama_index.utils.qianfan import ( + Client, + APIType, + get_service_list, + aget_service_list, +) + + +class ChatMsg(BaseModel): + """ + Chat request message, which is the message item of the chat request model. + """ + + role: Literal["user", "assistant"] + """The role that sends the message.""" + + content: str + """The content of the message.""" + + +class ChatRequest(BaseModel): + """ + Chat request model. + """ + + messages: List[ChatMsg] + """Chat message list.""" + + system: str = "" + """Prompt.""" + + stream: bool = False + """Indicate whether to respond in stream or not.""" + + +class ChatResp(BaseModel): + """ + Chat response model. + """ + + result: str + + +def build_chat_request( + stream: bool, messages: Sequence[ChatMessage], **kwargs: Any +) -> ChatRequest: + """ + Construct a ChatRequest. + + :param messages: The chat message list. + :param stream: Indicate whether to respond in stream or not. + :return: The ChatResponse object. 
+ """ + request = ChatRequest(messages=[], stream=stream) + for message in messages: + if message.role == MessageRole.USER: + msg = ChatMsg(role="user", content=message.content) + request.messages.append(msg) + elif message.role == MessageRole.ASSISTANT: + msg = ChatMsg(role="assistant", content=message.content) + request.messages.append(msg) + elif message.role == MessageRole.SYSTEM: + request.system = message.content + else: + raise NotImplementedError( + f"The message role {message.role} is not supported." + ) + return request + + +def parse_chat_response(resp_dict: Dict) -> ChatResponse: + """ + Parse chat response. + + :param resp_dict: Response body in dict form. + :return: The ChatResponse object. + """ + resp = ChatResp(**resp_dict) + return ChatResponse( + message=ChatMessage(role=MessageRole.ASSISTANT, content=resp.result) + ) + + +def parse_stream_chat_response( + resp_dict_iter: Iterable[Dict], +) -> Iterable[ChatResponse]: + """ + Parse streaming chat response. + + :param resp_dict_iter: Iterator of the response body in dict form. + :return: Iterator of the ChatResponse object. + """ + content = "" + for resp_dict in resp_dict_iter: + resp = ChatResp(**resp_dict) + content += resp.result + yield ChatResponse( + message=ChatMessage(role=MessageRole.ASSISTANT, content=content), + delta=resp.result, + ) + + +async def aparse_stream_chat_response( + resp_dict_iter: AsyncIterable[Dict], +) -> AsyncIterable[ChatResponse]: + """ + Parse asyncio streaming chat response. + + :param resp_dict_iter: Async iterator of the response body in dict form. + :return: Async iterator of the ChatResponse object. + """ + content = "" + async for resp_dict in resp_dict_iter: + resp = ChatResp(**resp_dict) + content += resp.result + yield ChatResponse( + message=ChatMessage(role=MessageRole.ASSISTANT, content=content), + delta=resp.result, + ) + + +class Qianfan(CustomLLM): + """ + The LLM supported by Baidu Intelligent Cloud's QIANFAN LLM Platform. + """ + + access_key: str = Field( + description="The Access Key obtained from the Security Authentication Center of Baidu Intelligent Cloud Console." + ) + + secret_key: str = Field(description="The Secret Key paired with the Access Key.") + + model_name: str = Field(description="The name of the model service.") + + endpoint_url: str = Field(description="The chat endpoint URL of the model service.") + + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, description="The context window size." + ) + + llm_type: APIType = Field(default="chat", description="The LLM type.") + + _client = PrivateAttr() + + def __init__( + self, + access_key: str, + secret_key: str, + model_name: str, + endpoint_url: str, + context_window: int, + llm_type: APIType = "chat", + ) -> None: + """ + Initialize a Qianfan LLM instance. + + :param access_key: The Access Key obtained from the Security Authentication Center + of Baidu Intelligent Cloud Console. + :param secret_key: The Secret Key paired with the Access Key. + :param model_name: The name of the model service. For example: ERNIE-4.0-8K. + :param endpoint_url: The chat endpoint URL of the model service. + For example: https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro . + :param context_windows: The context window size. for example: 8192. + :param llm_type: The LLM type. Currently, only the chat type is supported. 
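+
+        Example with placeholder keys and an illustrative ERNIE-Speed-8K endpoint:
+            Qianfan("ak", "sk", "ERNIE-Speed-8K",
+                    "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", 8192)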
+ """ + if llm_type != "chat": + raise NotImplementedError("Only the chat type is supported.") + + self._client = Client(access_key, secret_key) + + super().__init__( + model_name=model_name, + endpoint_url=endpoint_url, + context_window=context_window, + access_key=access_key, + secret_key=secret_key, + llm_type=llm_type, + ) + + @classmethod + def from_model_name( + cls, + access_key: str, + secret_key: str, + model_name: str, + context_window: int, + ): + """ + Initialize a Qianfan LLM instance. Then query more parameters based on the model name. + + :param access_key: The Access Key obtained from the Security Authentication Center + of Baidu Intelligent Cloud Console. + :param secret_key: The Secret Key paired with the Access Key. + :param model_name: The name of the model service. For example: ERNIE-4.0-8K. + :param context_windows: The context window size. for example: 8192. + """ + service_list = get_service_list(access_key, secret_key, ["chat"]) + try: + service = next( + service for service in service_list if service.name == model_name + ) + except StopIteration: + raise NameError(f"not found {model_name}") + + return cls( + access_key=access_key, + secret_key=secret_key, + model_name=model_name, + endpoint_url=service.url, + context_window=context_window, + llm_type=service.api_type, + ) + + @classmethod + async def afrom_model_name( + cls, + access_key: str, + secret_key: str, + model_name: str, + context_window: int, + ): + """ + Initialize a Qianfan LLM instance. Then asynchronously query more parameters based on the model name. + + :param access_key: The Access Key obtained from the Security Authentication Center of + Baidu Intelligent Cloud Console. + :param secret_key: The Secret Key paired with the Access Key. + :param model_name: The name of the model service. For example: ERNIE-4.0-8K. + :param context_windows: The context window size. for example: 8192. + The LLMs developed by Baidu all carry context window size in their names. + """ + service_list = await aget_service_list(access_key, secret_key, ["chat"]) + try: + service = next( + service for service in service_list if service.name == model_name + ) + except StopIteration: + raise NameError(f"not found {model_name}") + + return cls( + access_key=access_key, + secret_key=secret_key, + model_name=model_name, + endpoint_url=service.url, + context_window=context_window, + llm_type=service.api_type, + ) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "Qianfan_LLM" + + @property + def metadata(self) -> LLMMetadata: + """LLM metadata.""" + return LLMMetadata( + context_window=self.context_window, + is_chat_model=self.llm_type == "chat", + model_name=self.model_name, + ) + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + """ + Request a chat. + + :param messages: The chat message list. The last message is the current request, + and the previous messages are the historical chat information. The number of + members must be odd, and the role value of the odd-numbered messages must be + "user", while the role value of the even-numbered messages must be "assistant". + :return: The ChatResponse object. 
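+
+        Example of a valid message list (roles alternate and the list ends with a "user" message):
+            [ChatMessage(role=MessageRole.USER, content="..."),
+             ChatMessage(role=MessageRole.ASSISTANT, content="..."),
+             ChatMessage(role=MessageRole.USER, content="...")]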
+ """ + request = build_chat_request(stream=False, messages=messages, **kwargs) + resp_dict = self._client.post(self.endpoint_url, json=request.dict()) + return parse_chat_response(resp_dict) + + @llm_chat_callback() + async def achat( + self, + messages: Sequence[ChatMessage], + **kwargs: Any, + ) -> ChatResponse: + """ + Asynchronous request for a chat. + + :param messages: The chat message list. The last message is the current request, + and the previous messages are the historical chat information. The number of + members must be odd, and the role value of the odd-numbered messages must be + "user", while the role value of the even-numbered messages must be "assistant". + :return: The ChatResponse object. + """ + request = build_chat_request(stream=False, messages=messages, **kwargs) + resp_dict = await self._client.apost(self.endpoint_url, json=request.dict()) + return parse_chat_response(resp_dict) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + """ + Request a chat, and the response is returned in a stream. + + :param messages: The chat message list. The last message is the current request, + and the previous messages are the historical chat information. The number of + members must be odd, and the role value of the odd-numbered messages must be + "user", while the role value of the even-numbered messages must be "assistant". + :return: A ChatResponseGen object, which is a generator of ChatResponse. + """ + request = build_chat_request(stream=True, messages=messages, **kwargs) + + def gen(): + resp_dict_iter = self._client.post_reply_stream( + self.endpoint_url, json=request.dict() + ) + yield from parse_stream_chat_response(resp_dict_iter) + + return gen() + + @llm_chat_callback() + async def astream_chat( + self, + messages: Sequence[ChatMessage], + **kwargs: Any, + ) -> ChatResponseAsyncGen: + """ + Asynchronous request a chat, and the response is returned in a stream. + + :param messages: The chat message list. The last message is the current request, + and the previous messages are the historical chat information. The number of + members must be odd, and the role value of the odd-numbered messages must be + "user", while the role value of the even-numbered messages must be "assistant". + :return: A ChatResponseAsyncGen object, which is a asynchronous generator of ChatResponse. + """ + request = build_chat_request(stream=True, messages=messages, **kwargs) + + async def gen(): + resp_dict_iter = self._client.apost_reply_stream( + self.endpoint_url, json=request.dict() + ) + async for part in aparse_stream_chat_response(resp_dict_iter): + yield part + + return gen() + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Request to complete a message that begins with the specified prompt. + The LLM developed by Baidu does not support the complete function. + Here use a converter to convert the chat function to a complete function. + + :param prompt: The prompt message at the beginning of the completed content. + :return: CompletionResponse. + """ + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Asynchronous request to complete a message that begins with the specified prompt. 
+ The LLM developed by Baidu does not support the complete function. + Here use a converter to convert the chat function to a complete function. + + :param prompt: The prompt message at the beginning of the completed content. + :return: A CompletionResponse object. + """ + complete_fn = achat_to_completion_decorator(self.achat) + return await complete_fn(prompt, **kwargs) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """ + Request to complete a message that begins with the specified prompt, + and the response is returned in a stream. + The LLM developed by Baidu does not support the complete function. + Here use a converter to convert the chat function to a complete function. + + :param prompt: The prompt message at the beginning of the completed content. + :return: A CompletionResponseGen object. + """ + complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return complete_fn(prompt, **kwargs) + + @llm_completion_callback() + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + """ + Asynchronous request to complete a message that begins with the specified prompt, + and the response is returned in a stream. + The LLM developed by Baidu does not support the complete function. + Here use a converter to convert the chat function to a complete function. + + :param prompt: The prompt message at the beginning of the completed content. + :return: A CompletionResponseAsyncGen object. + """ + complete_fn = astream_chat_to_completion_decorator(self.astream_chat) + return await complete_fn(prompt, **kwargs) diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-qianfan/pyproject.toml new file mode 100644 index 0000000000000..08d75eef8aa33 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.llms.qianfan" + +[tool.llamahub.class_authors] +Qianfan = "wencan" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["wencan "] +description = "llama-index llms baidu qianfan integration" +license = "MIT" +name = "llama-index-llms-qianfan" +packages = [{include = "llama_index/"}] +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0" +llama-index-core = "^0.10.0" +llama-index-utils-qianfan = "^0.1.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = 
"4.5.5.0" +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-setuptools = "67.1.0.0" diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-qianfan/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-qianfan/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/llms/llama-index-llms-qianfan/tests/test_llms_qianfan.py b/llama-index-integrations/llms/llama-index-llms-qianfan/tests/test_llms_qianfan.py new file mode 100644 index 0000000000000..6fa28010430c7 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-qianfan/tests/test_llms_qianfan.py @@ -0,0 +1,464 @@ +import json +import asyncio +from unittest.mock import patch, MagicMock + +import httpx +from llama_index.core.base.llms.types import ( + ChatMessage, + MessageRole, +) +from llama_index.llms.qianfan import Qianfan + +# The request and response messages come from: +# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lqoklvr1 +# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t + +mock_service_list_reponse = { + "log_id": "4102908182", + "success": True, + "result": { + "common": [ + { + "name": "ERNIE-Bot 4.0", + "url": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "apiType": "chat", + "chargeStatus": "OPENED", + "versionList": [{"trainType": "ernieBot_4", "serviceStatus": "Done"}], + } + ], + "custom": [ + { + "serviceId": "123", + "serviceUuid": "svco-xxxxaaa", + "name": "conductor_liana2", + "url": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ca6zisxxxx", + "apiType": "chat", + "chargeStatus": "NOTOPEN", + "versionList": [ + { + "aiModelId": "xxx-123", + "aiModelVersionId": "xxx-456", + "trainType": "llama2-7b", + "serviceStatus": "Done", + } + ], + } + ], + }, +} + + +mock_chat_response = { + "id": "as-fg4g836x8n", + "object": "chat.completion", + "created": 1709716601, + "result": "北京,简称“京”,古称燕京、北平,中华民族的发祥地之一,是中华人民共和国首都、直辖市、国家中心城市、超大城市,也是国务院批复确定的中国政治中心、文化中心、国际交往中心、科技创新中心,中国历史文化名城和古都之一,世界一线城市。\n\n北京被世界城市研究机构评为世界一线城市,联合国报告指出北京市人类发展指数居中国城市第二位。北京市成功举办夏奥会与冬奥会,成为全世界第一个“双奥之城”。北京有着3000余年的建城史和850余年的建都史,是全球拥有世界遗产(7处)最多的城市。\n\n北京是一个充满活力和创新精神的城市,也是中国传统文化与现代文明的交汇点。在这里,你可以看到古老的四合院、传统的胡同、雄伟的长城和现代化的高楼大厦交相辉映。此外,北京还拥有丰富的美食文化,如烤鸭、炸酱面等,以及各种传统艺术表演,如京剧、相声等。\n\n总的来说,北京是一个充满魅力和活力的城市,无论你是历史爱好者、美食家还是现代都市人,都能在这里找到属于自己的乐趣和归属感。", + "is_truncated": False, + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 2, "completion_tokens": 221, "total_tokens": 223}, +} + +mock_stream_chat_response = [ + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089502, + "sentence_id": 0, + "is_end": False, + "is_truncated": False, + "result": "当然可以,", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089504, + "sentence_id": 1, + "is_end": False, + "is_truncated": False, + "result": "以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\n\n1. 
**西安-敦煌历史文化之旅**:\n\n\n\t* 路线:西安", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089506, + "sentence_id": 2, + "is_end": False, + "is_truncated": False, + "result": " - 天水 - 兰州 - 嘉峪关 - 敦煌\n\t* 特点:此路线让您领略到中国西北的丰富历史文化。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089508, + "sentence_id": 3, + "is_end": False, + "is_truncated": False, + "result": "您可以参观西安的兵马俑、大雁塔,体验兰州的黄河风情,以及在敦煌欣赏壮丽的莫高窟。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089511, + "sentence_id": 4, + "is_end": False, + "is_truncated": False, + "result": "\n2. **海南环岛热带风情游**:\n\n\n\t* 路线:海口 - 三亚 - 陵水 - 万宁 - 文昌 - 海", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089512, + "sentence_id": 5, + "is_end": False, + "is_truncated": False, + "result": "口\n\t* 特点:海南岛是中国唯一的黎族聚居区,这里有独特的热带风情、美丽的海滩和丰富的水果。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 153, "total_tokens": 158}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089513, + "sentence_id": 6, + "is_end": False, + "is_truncated": False, + "result": "您可以在三亚享受阳光沙滩,品尝当地美食,感受海南的悠闲生活。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 153, "total_tokens": 158}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089516, + "sentence_id": 7, + "is_end": False, + "is_truncated": False, + "result": "\n3. 
**穿越阿里大北线**:\n\n\n\t* 路线:成都 - 广元 - 汉中 - 西安 - 延安 - 银川 -", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 153, "total_tokens": 158}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089518, + "sentence_id": 8, + "is_end": False, + "is_truncated": False, + "result": " 阿拉善左旗 - 额济纳旗 - 嘉峪关 - 敦煌\n\t* 特点:这是一条充满挑战的自驾路线,穿越了中国", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 153, "total_tokens": 158}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089519, + "sentence_id": 9, + "is_end": False, + "is_truncated": False, + "result": "的西部。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089519, + "sentence_id": 10, + "is_end": False, + "is_truncated": False, + "result": "您将经过壮观的沙漠、神秘的戈壁和古老的丝绸之路遗址。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089520, + "sentence_id": 11, + "is_end": False, + "is_truncated": False, + "result": "此路线适合喜欢探险和寻求不同体验的旅行者。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089523, + "sentence_id": 12, + "is_end": False, + "is_truncated": False, + "result": "\n4. **寻找北方净土 - 阿尔山自驾之旅**:\n\n\n\t* 路线:北京 - 张家口 - 张北 - 太仆寺旗", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089525, + "sentence_id": 13, + "is_end": False, + "is_truncated": False, + "result": " - 锡林浩特 - 东乌珠穆沁旗 - 满都湖宝拉格 - 宝格达林场 - 五岔沟 - 阿尔山 -", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089527, + "sentence_id": 14, + "is_end": False, + "is_truncated": False, + "result": " 伊尔施 - 新巴尔虎右旗 - 满洲里 - 北京\n\t* 特点:此路线带您穿越中国北方的草原和森林,抵达", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089527, + "sentence_id": 15, + "is_end": False, + "is_truncated": False, + "result": "风景如画的阿尔山。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089528, + "sentence_id": 16, + "is_end": False, + "is_truncated": False, + "result": "您可以在这里欣赏壮丽的自然风光,体验当地的民俗文化,享受宁静的乡村生活。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089529, + "sentence_id": 17, + "is_end": False, + "is_truncated": False, + "result": "\n\n以上路线仅供参考,您可以根据自己的兴趣和时间安排进行调整。", + "need_clear_history": False, + "finish_reason": 
"normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089531, + "sentence_id": 18, + "is_end": False, + "is_truncated": False, + "result": "在规划自驾游时,请务必注意道路安全、车辆保养以及当地的天气和交通状况。", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089531, + "sentence_id": 19, + "is_end": False, + "is_truncated": False, + "result": "祝您旅途愉快!", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 239, "total_tokens": 244}, + }, + { + "id": "as-vb0m37ti8y", + "object": "chat.completion", + "created": 1709089531, + "sentence_id": 20, + "is_end": True, + "is_truncated": False, + "result": "", + "need_clear_history": False, + "finish_reason": "normal", + "usage": {"prompt_tokens": 5, "completion_tokens": 420, "total_tokens": 425}, + }, +] + + +@patch("httpx.Client") +def test_from_model_name(mock_client: httpx.Client): + mock_response = MagicMock() + mock_response.json.return_value = mock_service_list_reponse + mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + llm = Qianfan.from_model_name( + "mock_access_key", "mock_secret_key", "ERNIE-Bot 4.0", "8192" + ) + assert llm.model_name == "ERNIE-Bot 4.0" + assert ( + llm.endpoint_url + == "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + ) + assert llm.llm_type == "chat" + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_afrom_model_name(mock_client: httpx.AsyncClient): + mock_response = MagicMock() + mock_response.json.return_value = mock_service_list_reponse + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + llm = await Qianfan.afrom_model_name( + "mock_access_key", "mock_secret_key", "ERNIE-Bot 4.0", "8192" + ) + assert llm.model_name == "ERNIE-Bot 4.0" + assert ( + llm.endpoint_url + == "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + ) + assert llm.llm_type == "chat" + + asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once() + + +@patch("httpx.Client") +def test_chat(mock_client: httpx.Client): + mock_response = MagicMock() + mock_response.json.return_value = mock_chat_response + mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + llm = Qianfan( + "mock_access_key", + "mock_secret_key", + "test-model", + "https://127.0.0.1/test", + 8192, + ) + resp = llm.chat([ChatMessage(role=MessageRole.USER, content="介绍一下北京")]) + assert resp.message.content == mock_chat_response["result"] + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_achat(mock_client: httpx.AsyncClient): + mock_response = MagicMock() + mock_response.json.return_value = mock_chat_response + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + llm = Qianfan( + "mock_access_key", + "mock_secret_key", + "test-model", + "https://127.0.0.1/test", + 8192, + ) + resp = await llm.achat([ChatMessage(role=MessageRole.USER, content="介绍一下北京")]) + assert resp.message.content == mock_chat_response["result"] + + 
asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once() + + +@patch("httpx.Client") +def test_stream_chat(mock_client: httpx.Client): + reply_data = ["data: " + json.dumps(item) for item in mock_stream_chat_response] + + mock_response = MagicMock() + mock_response.iter_lines.return_value = iter(reply_data) + mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + llm = Qianfan( + "mock_access_key", + "mock_secret_key", + "test-model", + "https://127.0.0.1/test", + 8192, + ) + resp = llm.stream_chat([ChatMessage(role=MessageRole.USER, content="给我推荐一些自驾游路线")]) + last_content = "" + content = "" + for part in resp: + content += part.delta + last_content = part.message.content + assert last_content == content + assert last_content == "".join( + [mock_part["result"] for mock_part in mock_stream_chat_response] + ) + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_astream_chat(mock_client: httpx.AsyncClient): + reply_data = ["data: " + json.dumps(item) for item in mock_stream_chat_response] + + async def mock_async_gen(): + for part in reply_data: + yield part + + mock_response = MagicMock() + mock_response.aiter_lines.return_value = mock_async_gen() + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + llm = Qianfan( + "mock_access_key", + "mock_secret_key", + "test-model", + "https://127.0.0.1/test", + 8192, + ) + resp = await llm.astream_chat( + [ChatMessage(role=MessageRole.USER, content="给我推荐一些自驾游路线")] + ) + last_content = "" + content = "" + async for part in resp: + content += part.delta + last_content = part.message.content + assert last_content == content + assert last_content == "".join( + [mock_part["result"] for mock_part in mock_stream_chat_response] + ) + + asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once() diff --git a/llama-index-utils/llama-index-utils-qianfan/.gitignore b/llama-index-utils/llama-index-utils-qianfan/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-utils/llama-index-utils-qianfan/BUILD b/llama-index-utils/llama-index-utils-qianfan/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-utils/llama-index-utils-qianfan/Makefile b/llama-index-utils/llama-index-utils-qianfan/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-utils/llama-index-utils-qianfan/README.md b/llama-index-utils/llama-index-utils-qianfan/README.md new file mode 100644 index 0000000000000..8f0b808910b46 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/README.md @@ -0,0 +1,3 @@ +# LlamaIndex Utils Integration: Baidu Qianfan + +Client-side underlying components of Baidu Intelligent Cloud's Qianfan LLM Platform. Can be used to support access to model services of types such as chat, completion, and embedding. 
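+
+A minimal sketch of direct use, with placeholder credentials and the public ERNIE-Speed chat endpoint as an assumed example (each request is signed with the BCE authorization scheme and the parsed JSON body is returned):
+
+```python
+from llama_index.utils.qianfan import Client, get_service_list
+
+access_key, secret_key = "XXX", "XXX"  # placeholder credentials
+
+# List the chat services currently opened on the account.
+services = get_service_list(access_key, secret_key, ["chat"])
+print([service.name for service in services])
+
+# Call a chat endpoint directly; the request body follows the Qianfan chat API.
+client = Client(access_key, secret_key)
+resp = client.post(
+    "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed",
+    json={"messages": [{"role": "user", "content": "Hello"}]},
+)
+print(resp["result"])
+```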
diff --git a/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/BUILD b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/__init__.py b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/__init__.py new file mode 100644 index 0000000000000..361ad5f8c096f --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/__init__.py @@ -0,0 +1,4 @@ +from llama_index.utils.qianfan.client import Client +from llama_index.utils.qianfan.apis import APIType, get_service_list, aget_service_list + +__all__ = ["Client", "APIType", "get_service_list", "aget_service_list"] diff --git a/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/apis.py b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/apis.py new file mode 100644 index 0000000000000..d89b8ef8cdd20 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/apis.py @@ -0,0 +1,92 @@ +from typing import Sequence, Literal, List + +from llama_index.core.bridge.pydantic import BaseModel, Field + +from llama_index.utils.qianfan.client import Client + +APIType = Literal["chat", "completions", "embeddings", "text2image", "image2text"] + + +class ServiceItem(BaseModel): + """ + Model service item. + """ + + name: str + """model name. example: ERNIE-4.0-8K""" + + url: str + """endpoint url. example: https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro""" + + api_type: APIType = Field(..., alias="apiType") + """api type""" + + charge_status: Literal["NOTOPEN", "OPENED", "STOP", "FREE"] = Field( + ..., alias="chargeStatus" + ) + """Payment status""" + + +class ServiceListResult(BaseModel): + """ + All model service items. + """ + + common: List[ServiceItem] + """built-in model service""" + + custom: List[ServiceItem] + """custom model service""" + + +class ServiceListResponse(BaseModel): + """ + Response for Querying the List of Model Serving. + """ + + result: ServiceListResult + """All model available service items.""" + + +def get_service_list( + access_key: str, secret_key: str, api_type_filter: Sequence[APIType] = [] +): + """ + Get a list of available model services. Can be filtered by api type. + """ + url = "https://qianfan.baidubce.com/wenxinworkshop/service/list" + json = {"apiTypefilter": api_type_filter} + + client = Client(access_key, secret_key) + resp_dict = client.post(url, json=json) + resp = ServiceListResponse(**resp_dict) + + common_services = filter( + lambda service: service.charge_status == "OPENED", resp.result.common + ) + custom_services = filter( + lambda service: service.charge_status == "OPENED", resp.result.custom + ) + return list(common_services) + list(custom_services) + + +async def aget_service_list( + access_key: str, secret_key: str, api_type_filter: Sequence[APIType] = [] +): + """ + Asynchronous get a list of available model services. Can be filtered by api type. 
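+
+    Example (inside a coroutine, with placeholder keys):
+        services = await aget_service_list("ak", "sk", ["chat"])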
+ """ + url = "https://qianfan.baidubce.com/wenxinworkshop/service/list" + json = {"apiTypefilter": api_type_filter} + + client = Client(access_key, secret_key) + resp_dict = await client.apost(url, json=json) + resp = ServiceListResponse(**resp_dict) + + common_services = filter( + lambda service: service.charge_status == "OPENED", resp.result.common + ) + custom_services = filter( + lambda service: service.charge_status == "OPENED", resp.result.custom + ) + return list(common_services) + list(custom_services) diff --git a/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/authorization.py b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/authorization.py new file mode 100644 index 0000000000000..8806f43f25987 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/authorization.py @@ -0,0 +1,112 @@ +import hmac +import hashlib +import urllib.parse +from typing import List, Dict, Tuple +from datetime import datetime, timezone +import urllib.parse + + +def encode_canonical_query(query: str) -> str: + """ + Encoding the HTTP query. + """ + parsed = urllib.parse.parse_qs(query, keep_blank_values=True) + + items: List[str] = [] + for key, values in parsed.items(): + encoded_key = urllib.parse.quote_plus(key) + if key.lower() == "authorization": + continue + if len(values) > 1: # multi value + for val in values: + item = encoded_key + "=" + urllib.parse.quote_plus(val) + items.append(item) + elif len(values[0]) > 0: # single value + item = encoded_key + "=" + urllib.parse.quote_plus(values[0]) + items.append(item) + else: # just key, no value + item = encoded_key + "=" + items.append(item) + + items = sorted(items) + return "&".join(items) + + +def encode_canonical_headers(headers: Dict[str, str], host: str) -> Tuple[str, str]: + """ + Encoding the HTTTP headers. + """ + new_headers: Dict[str, str] = {} + for key, value in headers.items(): + key = key.lower() + new_headers[key] = value + headers = new_headers + + if "host" not in headers: + headers["host"] = host.strip() + + signed_headers: List[str] = [] + canonical_headers: List[str] = [] + for key, value in headers.items(): + if key.find("x-bce-") != 0 and key not in ( + "host", + "content-length", + "content-type", + "content-md5", + ): + continue + signed_headers.append(key) + + if value != "": + header = urllib.parse.quote_plus(key) + ":" + urllib.parse.quote_plus(value) + canonical_headers.append(header) + signed_headers = sorted(signed_headers) + canonical_headers = sorted(canonical_headers) + + return ";".join(signed_headers), "\n".join(canonical_headers) + + +def encode_authorization( + method: str, url: str, headers: Dict[str, str], access_key: str, secret_key: str +) -> str: + """ + Compute the signature for the API. + Document: https://cloud.baidu.com/doc/Reference/s/Njwvz1wot . + + :param method: HTTP method. + :param url: HTTP URL with query string. + :param headers: HTTP headers. + :param access_key: The Access Key obtained from the Security Authentication Center of Baidu Intelligent Cloud Console. + :param secret_key: The Secret Key paired with the Access Key. + :return: The Authorization value. 
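+
+    The returned value has the form assembled at the end of this function:
+        bce-auth-v1/{access key}/{UTC timestamp}/{expire seconds}/{signed headers}/{hex signature}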
+ """ + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + expire_in_seconds = 60 + + url_parsed = urllib.parse.urlparse(url) + + auth_string_prefix = f"bce-auth-v1/{access_key}/{timestamp}/{expire_in_seconds}" + + canonical_url = urllib.parse.quote(url_parsed.path) + canonical_query = encode_canonical_query(url_parsed.query) + signed_headers, canonical_headers = encode_canonical_headers( + headers, url_parsed.hostname + ) + canonical_request = ( + method.upper() + + "\n" + + canonical_url + + "\n" + + canonical_query + + "\n" + + canonical_headers + ) + + signing_key = hmac.new( + secret_key.encode(), auth_string_prefix.encode(), hashlib.sha256 + ).hexdigest() + signature = hmac.new( + signing_key.encode(), canonical_request.encode(), hashlib.sha256 + ).hexdigest() + + return f"bce-auth-v1/{access_key}/{timestamp}/{expire_in_seconds}/{signed_headers}/{signature}" diff --git a/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/client.py b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/client.py new file mode 100644 index 0000000000000..cc9f24e4cf6e1 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/llama_index/utils/qianfan/client.py @@ -0,0 +1,179 @@ +import json +import logging +import urllib.parse +from typing import Dict, Mapping, Any, Sequence, Union, Tuple, Iterable, AsyncIterable + +import httpx + +from llama_index.utils.qianfan.authorization import encode_authorization + + +QueryParamTypes = Union[ + Mapping[str, Union[Any, Sequence[Any]]], + Sequence[Tuple[str, Any]], +] + + +class Error(Exception): + """ + Error message returned by Baidu QIANFAN LLM Platform. + """ + + error_code: int + error_msg: str + + +logger = logging.getLogger(__name__) + + +def _rebuild_url( + url: str, params: QueryParamTypes = None +) -> Tuple[str, str, QueryParamTypes]: + """ + Rebuild url and return the full URL, the URL without the query, and the query parameters. + """ + parsed_url = urllib.parse.urlparse(url) + query_items = urllib.parse.parse_qsl(parsed_url.query) + + query = httpx.QueryParams(query_items) + if params: + query = query.merge(params) + + full_url = urllib.parse.ParseResult( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + params="", + query=str(query), + fragment=parsed_url.fragment, + ) + url_without_query = urllib.parse.ParseResult( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + params="", + query="", + fragment=parsed_url.fragment, + ) + return full_url.geturl(), url_without_query.geturl(), query.multi_items() + + +class Client: + """ + The access client for Baidu's Qianfan LLM Platform. + """ + + def __init__(self, access_key: str, secret_key: str): + """ + Initialize a Client instance. + + :param access_key: The Access Key obtained from the Security Authentication Center of Baidu Intelligent Cloud Console. + :param secret_key: The Secret Key paired with the Access Key. 
+ """ + self._access_key = access_key + self._secret_key = secret_key + + def _get_headers(self, method: str, url: str) -> Dict[str, str]: + headers = { + "Content-Type": "application/json", + } + authorization = encode_authorization( + method.upper(), url, headers, self._access_key, self._secret_key + ) + headers["Authorization"] = authorization + + return headers + + def _preprocess( + self, method: str, url: str, params: QueryParamTypes = None, json: Any = None + ) -> httpx.Request: + full_url, url_without_query, params = _rebuild_url(url, params) + + if logger.level <= logging.DEBUG: + logging.debug(f"{method} {url_without_query}, request body: {json}") + + headers = self._get_headers(method, full_url) + return httpx.Request( + method=method, + url=url_without_query, + params=params, + headers=headers, + json=json, + ) + + def _postprocess(self, r: httpx.Response) -> Dict: + if logger.level <= logging.DEBUG: + logger.debug(f"{r.request.method} {r.url} response body: {r.text}") + resp_dict = r.json() + + error_code = resp_dict.get("error_code", 0) + if error_code != 0: + raise Error(error_code, resp_dict.get("error_msg")) + + return resp_dict + + def _postprocess_stream_part(self, line: str) -> Iterable[Dict]: + if line == "": + return + + if line.startswith("{") and line.endswith("}"): # error + resp_dict = json.loads(line) + error_code = resp_dict.get("error_code", 0) + if error_code != 0: + raise Error(error_code, resp_dict.get("error_msg")) + + if line.startswith("data: "): + line = line[len("data: ") :] + resp_dict = json.loads(line) + yield resp_dict + + def post(self, url: str, params: QueryParamTypes = None, json: Any = None) -> Dict: + """ + Make an Request with POST Method. + """ + request = self._preprocess("POST", url=url, params=params, json=json) + with httpx.Client() as client: + r = client.send(request=request) + r.raise_for_status() + return self._postprocess(r) + + async def apost( + self, url: str, params: QueryParamTypes = None, json: Any = None + ) -> Dict: + """ + Make an Asynchronous Request with POST Method. + """ + response = self._preprocess("POST", url=url, params=params, json=json) + async with httpx.AsyncClient() as aclient: + r = await aclient.send(request=response) + r.raise_for_status() + return self._postprocess(r) + + def post_reply_stream( + self, url: str, params: QueryParamTypes = None, json: Any = None + ) -> Iterable[Dict]: + """ + Make an Request with POST Method and the response is returned in a stream. + """ + request = self._preprocess("POST", url=url, params=params, json=json) + + with httpx.Client() as client: + r = client.send(request=request, stream=True) + r.raise_for_status() + for line in r.iter_lines(): + yield from self._postprocess_stream_part(line) + + async def apost_reply_stream( + self, url: str, params: QueryParamTypes = None, json: Any = None + ) -> AsyncIterable[Dict]: + """ + Make an Asynchronous Request with POST Method and the response is returned in a stream. 
+ """ + request = self._preprocess("POST", url=url, params=params, json=json) + + async with httpx.AsyncClient() as aclient: + r = await aclient.send(request=request, stream=True) + r.raise_for_status() + async for line in r.aiter_lines(): + for part in self._postprocess_stream_part(line): + yield part diff --git a/llama-index-utils/llama-index-utils-qianfan/pyproject.toml b/llama-index-utils/llama-index-utils-qianfan/pyproject.toml new file mode 100644 index 0000000000000..e1106cdf846ab --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +# Feel free to un-skip examples, and experimental, you will just need to +# work through many typos (--write-changes and --interactive will help) +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.mypy] +disallow_untyped_defs = true +# Remove venv skip when integrated with pre-commit +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["wencan "] +description = "llama-index utils baidu qianfan integration" +license = "MIT" +name = "llama-index-utils-qianfan" +packages = [{include = "llama_index/"}] +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.8.1,<4.0" +llama-index-core = "^0.10.0" +httpx = "^0.27.0" + +[tool.poetry.group.dev.dependencies] +black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +codespell = {extras = ["toml"], version = ">=v2.2.6"} +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 +types-setuptools = "67.1.0.0" diff --git a/llama-index-utils/llama-index-utils-qianfan/tests/BUILD b/llama-index-utils/llama-index-utils-qianfan/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-utils/llama-index-utils-qianfan/tests/__init__.py b/llama-index-utils/llama-index-utils-qianfan/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-utils/llama-index-utils-qianfan/tests/test_apis.py b/llama-index-utils/llama-index-utils-qianfan/tests/test_apis.py new file mode 100644 index 0000000000000..47fb1fdfe4b52 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/tests/test_apis.py @@ -0,0 +1,91 @@ +import asyncio +from unittest.mock import patch, MagicMock +from typing import List + +import httpx + +from llama_index.utils.qianfan.apis import ( + get_service_list, + aget_service_list, + ServiceItem, +) + +mock_service_list_reponse = { + "log_id": "4102908182", + "success": True, + "result": { + "common": [ + { + "name": "ERNIE-Bot 4.0", + "url": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "apiType": "chat", + "chargeStatus": "OPENED", + "versionList": [{"trainType": "ernieBot_4", "serviceStatus": "Done"}], + } + ], + "custom": [ + { + "serviceId": "123", + "serviceUuid": "svco-xxxxaaa", + "name": "conductor_liana2", + "url": 
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ca6zisxxxx", + "apiType": "chat", + "chargeStatus": "NOTOPEN", + "versionList": [ + { + "aiModelId": "xxx-123", + "aiModelVersionId": "xxx-456", + "trainType": "llama2-7b", + "serviceStatus": "Done", + } + ], + } + ], + }, +} + + +@patch("httpx.Client") +def test_get_service_list(mock_client: httpx.Client): + mock_response = MagicMock() + mock_response.json.return_value = mock_service_list_reponse + mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + service_list: List[ServiceItem] = get_service_list( + "mock_access_key", "mock_secret_key", api_type_filter=["chat"] + ) + assert len(service_list) == 1 # Only return models with the status OPENED. + assert service_list[0].name == "ERNIE-Bot 4.0" + assert ( + service_list[0].url + == "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + ) + assert service_list[0].api_type == "chat" + assert service_list[0].charge_status == "OPENED" + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_aget_service_list(mock_client: httpx.AsyncClient): + mock_response = MagicMock() + mock_response.json.return_value = mock_service_list_reponse + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + service_list: List[ServiceItem] = await aget_service_list( + "mock_access_key", "mock_secret_key", api_type_filter=["chat"] + ) + # Only return models with the status OPENED. + assert len(service_list) == 1 + assert service_list[0].name == "ERNIE-Bot 4.0" + assert ( + service_list[0].url + == "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + ) + assert service_list[0].api_type == "chat" + assert service_list[0].charge_status == "OPENED" + + asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once() diff --git a/llama-index-utils/llama-index-utils-qianfan/tests/test_authorization.py b/llama-index-utils/llama-index-utils-qianfan/tests/test_authorization.py new file mode 100644 index 0000000000000..c4be66adfd497 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/tests/test_authorization.py @@ -0,0 +1,20 @@ +from llama_index.utils.qianfan.authorization import ( + encode_canonical_query, + encode_canonical_headers, +) + + +def test_encode_canonical_query() -> None: + assert ( + encode_canonical_query("text&text1=测试&text10=test") + == "text10=test&text1=%E6%B5%8B%E8%AF%95&text=" + ) + + +def test_encode_canonical_headers() -> None: + headers = {"Content-Type": "application/json"} + host = "aip.baidubce.com" + assert encode_canonical_headers(headers, host) == ( + "content-type;host", + "content-type:application%2Fjson\nhost:aip.baidubce.com", + ) diff --git a/llama-index-utils/llama-index-utils-qianfan/tests/test_client.py b/llama-index-utils/llama-index-utils-qianfan/tests/test_client.py new file mode 100644 index 0000000000000..7c2910ec8e1e1 --- /dev/null +++ b/llama-index-utils/llama-index-utils-qianfan/tests/test_client.py @@ -0,0 +1,87 @@ +import json +import asyncio +from unittest.mock import patch, MagicMock + +import httpx + +from llama_index.utils.qianfan.client import Client + + +@patch("httpx.Client") +def test_post(mock_client: httpx.Client): + content = {"content": "Hello"} + + mock_response = MagicMock() + mock_response.json.return_value = {"echo": content} + 
mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + client = Client("mock_access_key", "mock_secret_key") + resp_dict = client.post( + url="https://127.0.0.1/mock/echo", params={"param": "123"}, json=content + ) + assert resp_dict == {"echo": content} + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_apost(mock_client: httpx.AsyncClient): + content = {"content": "Hello"} + + mock_response = MagicMock() + mock_response.json.return_value = {"echo": content} + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + client = Client("mock_access_key", "mock_secret_key") + resp_dict = await client.apost( + url="https://127.0.0.1/mock/echo", params={"param": "123"}, json=content + ) + assert resp_dict == {"echo": content} + + asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once() + + +@patch("httpx.Client") +def test_post_reply_stream(mock_client: httpx.Client): + content = [{"content": "Hello"}, {"content": "world"}] + reply_data = ["data: " + json.dumps(item) for item in content] + + mock_response = MagicMock() + mock_response.iter_lines.return_value = iter(reply_data) + mock_client.return_value.__enter__.return_value.send.return_value = mock_response + + client = Client("mock_access_key", "mock_secret_key") + resp_dict_iter = client.post_reply_stream( + url="https://127.0.0.1/mock/echo", params={"param": "123"}, json=content + ) + assert list(resp_dict_iter) == content + + mock_client.return_value.__enter__.return_value.send.assert_called_once() + + +@patch("httpx.AsyncClient") +def test_apost_reply_stream(mock_client: httpx.AsyncClient): + content = [{"content": "Hello"}, {"content": "world"}] + reply_data = ["data: " + json.dumps(item) for item in content] + + async def mock_async_gen(): + for part in reply_data: + yield part + + mock_response = MagicMock() + mock_response.aiter_lines.return_value = mock_async_gen() + mock_client.return_value.__aenter__.return_value.send.return_value = mock_response + + async def async_process(): + client = Client("mock_access_key", "mock_secret_key") + resp_dict_iter = client.apost_reply_stream( + url="https://127.0.0.1/mock/echo", params={"param": "123"}, json=content + ) + assert [part async for part in resp_dict_iter] == content + + asyncio.run(async_process()) + + mock_client.return_value.__aenter__.return_value.send.assert_called_once()
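
A minimal sketch of how the service-discovery helper added above might be called, assuming valid Access Key / Secret Key credentials from the Baidu Intelligent Cloud console; it relies only on get_service_list and the ServiceItem fields exercised in the tests:

from llama_index.utils.qianfan.apis import get_service_list

# Credentials issued on the Security Authentication page of the console.
access_key = "XXX"
secret_key = "XXX"

# Only services whose charge status is OPENED are returned.
services = get_service_list(access_key, secret_key, api_type_filter=["chat"])
for service in services:
    print(service.name, service.url, service.api_type)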