Feat: add gpustack model provider (#4469)
### What problem does this PR solve?

Add GPUStack as a new model provider.
[GPUStack](https://github.com/gpustack/gpustack) is an open-source GPU
cluster manager for running LLMs. Models deployed locally in GPUStack
currently cannot integrate well with RAGFlow. GPUStack provides
OpenAI-compatible APIs (Models / Chat Completions / Embeddings /
Speech2Text / TTS) as well as additional APIs such as Rerank, so this PR
adds GPUStack as a model provider in RAGFlow.

[GPUStack Docs](https://docs.gpustack.ai/latest/quickstart/)
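As a quick illustration, GPUStack's OpenAI-compatible endpoints live under the `/v1-openai` prefix and work with the standard OpenAI client. A minimal sketch; the host, API key, and model name below are placeholders:

```python
from openai import OpenAI

# Placeholders: point these at your own GPUStack deployment.
client = OpenAI(
    api_key="your-gpustack-api-key",
    base_url="http://your-gpustack-server/v1-openai",  # GPUStack's OpenAI-compatible prefix
)

resp = client.chat.completions.create(
    model="llama-3.2-1b-instruct",
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```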

Related issue: #4064.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)



### Testing Instructions
1. Install GPUStack and deploy the `llama-3.2-1b-instruct` LLM, the `bge-m3`
text-embedding model, the `bge-reranker-v2-m3` rerank model, the
`faster-whisper-medium` speech-to-text model, and the `cosyvoice-300m-sft`
TTS model in GPUStack.
2. Add GPUStack as a provider in the RAGFlow model settings.
3. Test each model type in RAGFlow (a smoke-test sketch follows below).
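A hedged smoke test for step 3, listing the models deployed in step 1 through the OpenAI-compatible Models API (server URL and key are placeholders):

```python
import requests

BASE_URL = "http://your-gpustack-server"  # placeholder
HEADERS = {"Authorization": "Bearer your-gpustack-api-key"}  # placeholder

# Should list every model deployed in step 1.
resp = requests.get(f"{BASE_URL}/v1-openai/models", headers=HEADERS)
resp.raise_for_status()
for model in resp.json().get("data", []):
    print(model.get("id"))
```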
alexcodelf authored Jan 15, 2025
1 parent e478586 commit 7944aac
Showing 12 changed files with 159 additions and 3 deletions.
4 changes: 2 additions & 2 deletions api/apps/llm_app.py
@@ -329,7 +329,7 @@ def my_llms():
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_app():
-self_deploied = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
+self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type")
try:
@@ -339,7 +339,7 @@ def list_app():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deploied
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed

llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
for o in objs:
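The rename from `self_deploied` to `self_deployed` fixes a typo along the way; the functional change is that GPUStack joins the list of self-deployed providers whose models are always marked as available.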
7 changes: 7 additions & 0 deletions conf/llm_factories.json
@@ -2543,6 +2543,13 @@
"tags": "TEXT EMBEDDING",
"status": "1",
"llm": []
},
{
"name": "GPUStack",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,TTS,SPEECH2TEXT,TEXT RE-RANK",
"status": "1",
"llm": []
}
]
}
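The new factory entry advertises GPUStack's supported model types via `tags` and ships with an empty `llm` list, since models are added by the user from their own GPUStack deployment rather than from a preset catalog.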
10 changes: 10 additions & 0 deletions rag/llm/__init__.py
@@ -42,6 +42,7 @@
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
GPUStackEmbed,
)
from .chat_model import (
GptTurbo,
@@ -80,6 +81,7 @@
AnthropicChat,
GoogleChat,
HuggingFaceChat,
GPUStackChat,
)

from .cv_model import (
@@ -116,20 +118,23 @@
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
GPUStackRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
GPUStackSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
GPUStackTTS,
)

EmbeddingModel = {
@@ -161,6 +166,7 @@
"Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed,
"VolcEngine": VolcEngineEmbed,
"GPUStack": GPUStackEmbed,
}

CvModel = {
@@ -220,6 +226,7 @@
"Anthropic": AnthropicChat,
"Google Cloud": GoogleChat,
"HuggingFace": HuggingFaceChat,
"GPUStack": GPUStackChat,
}

RerankModel = {
@@ -237,6 +244,7 @@
"BaiduYiyan": BaiduYiyanRerank,
"Voyage AI": VoyageRerank,
"Tongyi-Qianwen": QWenRerank,
"GPUStack": GPUStackRerank,
}

Seq2txtModel = {
@@ -245,6 +253,7 @@
"Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt,
"GPUStack": GPUStackSeq2txt,
}

TTSModel = {
@@ -253,4 +262,5 @@
"OpenAI": OpenAITTS,
"XunFei Spark": SparkTTS,
"Xinference": XinferenceTTS,
"GPUStack": GPUStackTTS,
}
8 changes: 8 additions & 0 deletions rag/llm/chat_model.py
@@ -1514,3 +1514,11 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + "\n**ERROR**: " + str(e)

yield response._chunks[-1].usage_metadata.total_token_count

class GPUStackChat(Base):
def __init__(self, key=None, model_name="", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")
super().__init__(key, model_name, base_url)
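A minimal usage sketch of the new chat class (key, model, and server are placeholders; the constructor appends `/v1-openai` when missing and then defers to the module's shared OpenAI-compatible `Base` class):

```python
from rag.llm.chat_model import GPUStackChat

chat = GPUStackChat(
    key="your-gpustack-api-key",             # placeholder
    model_name="llama-3.2-1b-instruct",
    base_url="http://your-gpustack-server",  # "/v1-openai" is appended automatically
)
# Assuming the chat(system, history, gen_conf) interface visible elsewhere in this module:
answer, tokens = chat.chat("You are helpful.", [{"role": "user", "content": "Hi"}], {})
```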
13 changes: 12 additions & 1 deletion rag/llm/embedding_model.py
@@ -30,7 +30,7 @@
from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
-import google.generativeai as genai
+import google.generativeai as genai
import json


@@ -799,3 +799,14 @@ def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/
ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
super().__init__(ark_api_key,model_name,base_url)

class GPUStackEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")

print(key,base_url)
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
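A corresponding sketch for the embedding class (placeholders as above; `encode` is inherited from `OpenAIEmbed`, whose return shape is assumed here):

```python
from rag.llm.embedding_model import GPUStackEmbed

embed = GPUStackEmbed(
    key="your-gpustack-api-key",             # placeholder
    model_name="bge-m3",
    base_url="http://your-gpustack-server",  # "/v1-openai" is appended if missing
)
# encode() comes from OpenAIEmbed and is assumed to return (vectors, token_count).
vectors, tokens = embed.encode(["What is GPUStack?"])
```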
52 changes: 52 additions & 0 deletions rag/llm/rerank_model.py
@@ -18,10 +18,12 @@
from urllib.parse import urljoin

import requests
import httpx
from huggingface_hub import snapshot_download
import os
from abc import ABC
import numpy as np
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
@@ -457,3 +459,53 @@ def similarity(self, query: str, texts: list):
return rank, resp.usage.total_tokens
else:
raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")

class GPUStackRerank(Base):
def __init__(
self, key, model_name, base_url
):
if not base_url:
raise ValueError("url cannot be None")

self.model_name = model_name
self.base_url = str(URL(base_url)/ "v1" / "rerank")
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}

def similarity(self, query: str, texts: list):
payload = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
}

try:
response = requests.post(
self.base_url, json=payload, headers=self.headers
)
response.raise_for_status()
response_json = response.json()

rank = np.zeros(len(texts), dtype=float)
if "results" not in response_json:
return rank, 0

token_count = 0
for t in texts:
token_count += num_tokens_from_string(t)

for result in response_json["results"]:
rank[result["index"]] = result["relevance_score"]

return (
rank,
token_count,
)

except httpx.HTTPStatusError as e:
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")

12 changes: 12 additions & 0 deletions rag/llm/sequence2txt_model.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import requests
from openai.lib.azure import AzureOpenAI
import io
@@ -191,3 +192,14 @@ def transcription(self, audio, max_retries=60, retry_interval=5):
return "**ERROR**: " + str(e), 0
except Exception as e:
return "**ERROR**: " + str(e), 0


class GPUStackSeq2txt(Base):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1-openai":
base_url = os.path.join(base_url, "v1-openai")
self.base_url = base_url
self.model_name = model_name
self.key = key
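The class stores the normalized `/v1-openai` URL, so transcription can also be exercised directly against the OpenAI-style audio endpoint that GPUStack mirrors. A hedged sketch (placeholders throughout; field names follow the OpenAI audio API):

```python
import requests

base_url = "http://your-gpustack-server/v1-openai"  # placeholder
with open("sample.wav", "rb") as audio:
    resp = requests.post(
        f"{base_url}/audio/transcriptions",
        headers={"Authorization": "Bearer your-gpustack-api-key"},  # placeholder
        files={"file": ("sample.wav", audio, "audio/wav")},
        data={"model": "faster-whisper-medium"},
    )
resp.raise_for_status()
print(resp.json().get("text"))
```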
32 changes: 32 additions & 0 deletions rag/llm/tts_model.py
@@ -355,3 +355,35 @@ def tts(self, text, voice="standard-voice"):
for chunk in response.iter_content():
if chunk:
yield chunk

class GPUStackTTS:
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}

def tts(self, text, voice="Chinese Female", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}

response = requests.post(
f"{self.base_url}/v1-openai/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")

for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk
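`GPUStackTTS.tts()` streams the synthesized audio back in 1 KiB chunks; a usage sketch that writes them to a file (placeholders as above; the audio container format depends on the deployed model):

```python
from rag.llm.tts_model import GPUStackTTS

tts = GPUStackTTS(
    key="your-gpustack-api-key",             # placeholder
    model_name="cosyvoice-300m-sft",
    base_url="http://your-gpustack-server",  # consumed via **kwargs
)
with open("speech.wav", "wb") as out:
    for chunk in tts.tts("Hello from GPUStack!", voice="Chinese Female"):
        out.write(chunk)
```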
14 changes: 14 additions & 0 deletions web/src/assets/svg/llm/gpustack.svg
1 change: 1 addition & 0 deletions web/src/constants/setting.ts
@@ -72,6 +72,7 @@ export const IconMap = {
'nomic-ai': 'nomic-ai',
jinaai: 'jina',
'sentence-transformers': 'sentence-transformers',
GPUStack: 'gpustack',
};

export const TimezoneList = [
1 change: 1 addition & 0 deletions web/src/pages/user-setting/constants.tsx
@@ -31,6 +31,7 @@ export const LocalLlmFactories = [
'Replicate',
'OpenRouter',
'HuggingFace',
'GPUStack',
];

export enum TenantRole {
8 changes: 8 additions & 0 deletions web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -29,6 +29,7 @@ const llmFactoryToUrlMap = {
OpenRouter: 'https://openrouter.ai/docs',
HuggingFace:
'https://huggingface.co/docs/text-embeddings-inference/quick_tour',
GPUStack: 'https://docs.gpustack.ai/latest/quickstart',
};
type LlmFactory = keyof typeof llmFactoryToUrlMap;

@@ -76,6 +77,13 @@ const OllamaModal = ({
{ value: 'speech2text', label: 'sequence2text' },
{ value: 'tts', label: 'tts' },
],
GPUStack: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },
{ value: 'rerank', label: 'rerank' },
{ value: 'speech2text', label: 'sequence2text' },
{ value: 'tts', label: 'tts' },
],
Default: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },