Commit dfa4676

Support LoRA in the HF backend
1 parent a32391b commit dfa4676

File tree

2 files changed (+45, -4 lines)

2 files changed

+45
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -34,7 +34,7 @@
 ## Update Log

 ```plaintext
-8-17 Added LoRA deployment for the vllm backend
+8-17 Added LoRA deployment for the vllm/hf backends
 8-14 Added support for the InternVL2 family of multimodal models
 7-28 Added dynamic batching acceleration for embedding/reranker (infinity backend, faster than onnx/tensorrt)
 7-19 Added LMDeploy PyTorch backend support for the multimodal model glm-4v-9b
````
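With this change, a LoRA adapter is selected per request by sending the adapter's name as the model. A minimal client sketch against an OpenAI-compatible endpoint, assuming gpt_server's usual OpenAI-style API; the base URL and the adapter name `sql_lora` are placeholders:

```python
from openai import OpenAI

# Placeholder endpoint; point base_url at the running gpt_server instance.
client = OpenAI(base_url="http://localhost:8082/v1", api_key="EMPTY")

# "sql_lora" is a hypothetical adapter name registered via the `lora`
# environment variable (see hf_backend.py below). A request whose model
# matches a loaded adapter name is routed through that adapter.
resp = client.chat.completions.create(
    model="sql_lora",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)
```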

gpt_server/model_backend/hf_backend.py

Lines changed: 44 additions & 3 deletions
```diff
@@ -1,5 +1,8 @@
 from typing import Any, Dict
 import torch
+import os
+import json
+from peft import PeftModel
 from transformers import TextIteratorStreamer
 from transformers.generation.logits_process import LogitsProcessorList
 from threading import Thread
```
```diff
@@ -15,13 +18,41 @@
 invalid_score_processor = InvalidScoreLogitsProcessor()


+class NoneContextManager:
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return True
+
+
 class HFBackend(ModelBackend):
     def __init__(self, tokenizer, model: torch.nn.Module) -> None:
         self.model = model
         self.tokenizer = tokenizer
+        self.lora_requests = []
+        lora = os.getenv("lora", None)
+        if lora:
+            lora_dict: dict = json.loads(lora)
+            for i, (lora_name, lora_path) in enumerate(lora_dict.items()):
+                self.lora_requests.append(
+                    dict(
+                        lora_name=lora_name,
+                        lora_int_id=i,
+                        lora_local_path=lora_path,
+                    )
+                )
+                if i == 0:
+                    self.model = PeftModel.from_pretrained(
+                        model=model,
+                        model_id=lora_path,
+                        adapter_name=lora_name,
+                    )
+                    continue
+                self.model.load_adapter(model_id=lora_path, adapter_name=lora_name)

     async def stream_chat(self, params: Dict[str, Any]):
-        prompt = params.get("prompt","")
+        prompt = params.get("prompt", "")
         logger.info(prompt)
         temperature = float(params.get("temperature", 0.8))
         top_p = float(params.get("top_p", 0.8))
```
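For reference, the constructor above reads the `lora` environment variable as a JSON object that maps adapter names to local checkpoint paths, loading the first adapter via `PeftModel.from_pretrained` and attaching the rest with `load_adapter`. A minimal sketch of what a launcher might set; the names and paths are hypothetical:

```python
import json
import os

# Hypothetical adapter names and paths. Keys become adapter names (and the
# model names clients can request); values are local PEFT checkpoint dirs.
os.environ["lora"] = json.dumps(
    {
        "sql_lora": "/data/loras/sql_lora",
        "chat_lora": "/data/loras/chat_lora",
    }
)
```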
```diff
@@ -61,8 +92,18 @@ async def stream_chat(self, params: Dict[str, Any]):
             # presence_penalty=presence_penalty,
             # frequency_penalty=frequency_penalty,
         )
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-        thread.start()
+        use_lora = False
+        for lora in self.lora_requests:
+            if params["model"] == lora["lora_name"]:
+                self.model.set_adapter(lora["lora_name"])
+                use_lora = True
+                break
+        context_manager = NoneContextManager()
+        if not use_lora and self.lora_requests:
+            context_manager = self.model.disable_adapter()
+        with context_manager:
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            thread.start()
         generated_text = ""
         prompt_tokens = len(input_ids.tolist()[0])
         completion_tokens = 0
```
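The request-time routing added above reads as a small standalone function. A sketch of the same flow using `contextlib.nullcontext` from the standard library in place of `NoneContextManager`; `model`, `lora_requests`, and `generate_kwargs` stand in for the backend's state:

```python
from contextlib import nullcontext

def generate_with_adapter(model, requested_model, lora_requests, generate_kwargs):
    """Route generation through the matching LoRA adapter, or fall back to
    the base weights when no loaded adapter matches the requested model."""
    use_lora = False
    for lora in lora_requests:
        if requested_model == lora["lora_name"]:
            model.set_adapter(lora["lora_name"])  # peft: activate this adapter
            use_lora = True
            break
    cm = nullcontext()
    if not use_lora and lora_requests:
        # peft: temporarily run with all adapters disabled (base weights only)
        cm = model.disable_adapter()
    with cm:
        return model.generate(**generate_kwargs)
```

One difference worth noting: `nullcontext` propagates exceptions, while `NoneContextManager` returns `True` from `__exit__`, which silently suppresses any exception raised inside the `with` block.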
