basic_vllm.py
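"""Stream one completion from an AWQ-quantized Mixtral model with vLLM's
AsyncLLMEngine, printing tokens as they are generated."""
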
import asyncio
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
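
# a prequantized 4-bit AWQ checkpoint on the Hugging Face Hub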
model_path = "casperhansen/mixtral-instruct-awq"

# prompting
prompt = (
    "You're standing on the surface of the Earth. "
    "You walk one mile south, one mile west and one mile north. "
    "You end up exactly where you started. Where are you?"
)
prompt_template = "[INST] {prompt} [/INST]"  # Mixtral instruction format

# sampling params
sampling_params = SamplingParams(
    repetition_penalty=1.1,
    temperature=0.8,
    max_tokens=512,
)

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# async engine args for streaming
engine_args = AsyncEngineArgs(
    model=model_path,
    quantization="awq",  # use vLLM's AWQ kernels for the quantized weights
    dtype="float16",
    max_model_len=512,
    enforce_eager=True,  # skip CUDA graph capture for faster startup
    disable_log_requests=True,
    disable_log_stats=True,
)


async def generate(model: AsyncLLMEngine, tokenizer: PreTrainedTokenizer):
    tokens = tokenizer(prompt_template.format(prompt=prompt)).input_ids
    outputs = model.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id="1",  # unique string id identifying this request
        prompt_token_ids=tokens,
    )

    print("\n** Starting generation!\n")
    last_index = 0

    # each iteration yields the full text generated so far, so print only
    # the newly appended part to get token-by-token streaming output
    async for output in outputs:
        print(output.outputs[0].text[last_index:], end="", flush=True)
        last_index = len(output.outputs[0].text)

    print("\n\n** Finished generation!\n")


if __name__ == '__main__':
    model = AsyncLLMEngine.from_engine_args(engine_args)
    asyncio.run(generate(model, tokenizer))
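
    # the concurrent sketch above can be driven the same way, e.g.:
    # print(asyncio.run(generate_many(model, tokenizer, [prompt, prompt])))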