|
1 | 1 | from typing import Any, Dict |
2 | 2 | import torch |
| 3 | +import os |
| 4 | +import json |
| 5 | +from peft import PeftModel |
3 | 6 | from transformers import TextIteratorStreamer |
4 | 7 | from transformers.generation.logits_process import LogitsProcessorList |
5 | 8 | from threading import Thread |
|
15 | 18 | invalid_score_processor = InvalidScoreLogitsProcessor() |
16 | 19 |
|
17 | 20 |
|
| 21 | +class NoneContextManager:  # no-op stand-in for the adapter-disabling context manager |
| 22 | +    def __enter__(self): |
| 23 | +        pass |
| 24 | + |
| 25 | +    def __exit__(self, exc_type, exc_val, exc_tb): |
| 26 | +        return False  # do not suppress exceptions raised inside the with block |
| 27 | + |
| 28 | + |
18 | 29 | class HFBackend(ModelBackend): |
19 | 30 |     def __init__(self, tokenizer, model: torch.nn.Module) -> None: |
20 | 31 |         self.model = model |
21 | 32 |         self.tokenizer = tokenizer |
| 33 | +        self.lora_requests = [] |
| 34 | +        lora = os.getenv("lora", None)  # JSON object mapping adapter name -> local adapter path |
| 35 | +        if lora: |
| 36 | +            lora_dict: dict = json.loads(lora) |
| 37 | +            for i, (lora_name, lora_path) in enumerate(lora_dict.items()): |
| 38 | +                self.lora_requests.append( |
| 39 | +                    dict( |
| 40 | +                        lora_name=lora_name, |
| 41 | +                        lora_int_id=i, |
| 42 | +                        lora_local_path=lora_path, |
| 43 | +                    ) |
| 44 | +                ) |
| 45 | +                if i == 0:  # the first adapter wraps the base model in a PeftModel |
| 46 | +                    self.model = PeftModel.from_pretrained( |
| 47 | +                        model=model, |
| 48 | +                        model_id=lora_path, |
| 49 | +                        adapter_name=lora_name, |
| 50 | +                    ) |
| 51 | +                    continue |
| 52 | +                self.model.load_adapter(model_id=lora_path, adapter_name=lora_name) |
22 | 53 |
|
23 | 54 |     async def stream_chat(self, params: Dict[str, Any]): |
24 | | -        prompt = params.get("prompt","") |
| 55 | +        prompt = params.get("prompt", "") |
25 | 56 |         logger.info(prompt) |
26 | 57 |         temperature = float(params.get("temperature", 0.8)) |
27 | 58 |         top_p = float(params.get("top_p", 0.8)) |
@@ -61,8 +92,18 @@ async def stream_chat(self, params: Dict[str, Any]): |
61 | 92 |             # presence_penalty=presence_penalty, |
62 | 93 |             # frequency_penalty=frequency_penalty, |
63 | 94 |         ) |
64 | | -        thread = Thread(target=self.model.generate, kwargs=generation_kwargs) |
65 | | -        thread.start() |
| 95 | +        use_lora = False |
| 96 | +        for lora in self.lora_requests: |
| 97 | +            if params["model"] == lora["lora_name"]: |
| 98 | +                self.model.set_adapter(lora["lora_name"]) |
| 99 | +                use_lora = True |
| 100 | +                break |
| 101 | +        context_manager = NoneContextManager()  # default: leave the adapter state as is |
| 102 | +        if not use_lora and self.lora_requests: |
| 103 | +            context_manager = self.model.disable_adapter()  # base-model request while adapters are loaded |
| 104 | +        with context_manager: |
| 105 | +            thread = Thread(target=self.model.generate, kwargs=generation_kwargs) |
| 106 | +            thread.start() |
66 | 107 |         generated_text = "" |
67 | 108 |         prompt_tokens = len(input_ids.tolist()[0]) |
68 | 109 |         completion_tokens = 0 |
|
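For reference, the adapter set is configured entirely through the `lora` environment variable, which holds a JSON object mapping adapter names to local adapter paths. A minimal sketch of how a deployment might set this up before constructing `HFBackend` — the adapter names and paths below are made up for illustration:

    import json
    import os

    # Hypothetical adapters; replace names and paths with real ones.
    os.environ["lora"] = json.dumps(
        {
            "chat-lora": "/models/adapters/chat-lora",
            "sql-lora": "/models/adapters/sql-lora",
        }
    )

    # HFBackend(tokenizer, model) will then wrap the base model in a PeftModel,
    # registering "chat-lora" first and adding "sql-lora" via load_adapter().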
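At request time the adapter is selected by name: when `params["model"]` matches a configured `lora_name`, that adapter is activated with `set_adapter()`; when adapters are loaded but the name does not match, generation runs under `disable_adapter()` so the plain base weights are used; `NoneContextManager` (filling the same role as `contextlib.nullcontext()` in the standard library) covers the remaining cases as a do-nothing `with` target. A sketch of the two request shapes, assuming the hypothetical adapter names above and that the serving layer passes `params` through unchanged:

    # Served by the "sql-lora" adapter: the model name matches a configured adapter.
    lora_params = {"model": "sql-lora", "prompt": "Write a SQL query.", "temperature": 0.8, "top_p": 0.8}

    # Served by the base model: no adapter name matches, so the loaded adapters
    # are temporarily disabled for this generation.
    base_params = {"model": "base", "prompt": "Hello!", "temperature": 0.8, "top_p": 0.8}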